U
    rh                     @  sR   d dl mZ d dlZd dlZd dlmZ d dlmZ d
ddZ	G dd	 d	eZ
dS )    )annotationsN)Base)expectmeanc                 C  s,  | j }t|dkrtd|j }|d }|d }	tj| ddd}
t| |
 }|tj|ddd }t|}d }|dkrt|}d }|d k	rtj	|tj
|tjddd}|d k	rt||kd|jtjd}n$|d k	rt||kddjtjd}t|d	kr |||	d
f}||d
f}|j d }tj||ftjd}t|D ]J}t|D ]:}|| | |krR|| || |  |  || |< qRqF|}t|d	kr||}|d k	r|| }|dkr| |  }|dkr||fS |S |dkrt|}n|dkrt|}|r(||fS |S )N   zUnsupported shaper   T)ZaxisZkeepdimsZdtypeZclip)mode      r   sum)shapelenRuntimeErrornpmaxexpr   logcopyZtakearrayZint32whereastypefloat32Zreshapezerosranger   )xtargetweight	reductionignore_indexget_log_probZinput_shapeZtarget_shapeNCZmax_xZexp_xpinplog_probZgather_weightDZneg_gather_element_inputidloss r*   S/tmp/pip-unpacked-wheel-xnis5xre/onnx/backend/test/case/node/softmaxcrossentropy.pysoftmaxcrossentropy   s`    


*






r,   c                   @  sp  e Zd ZeddddZeddddZeddddZeddd	d
ZeddddZeddddZ	eddddZ
eddddZeddddZeddddZeddddZeddddZeddddZeddddZedddd Zeddd!d"Zeddd#d$Zeddd%d&Zeddd'd(Zeddd)d*Zeddd+d,Zeddd-d.Zeddd/d0Zeddd1d2Zeddd3d4Zeddd5d6Zeddd7d8Zeddd9d:Zeddd;d<Z eddd=d>Z!eddd?d@Z"edddAdBZ#edddCdDZ$edddEdFZ%dGS )HSoftmaxCrossEntropyLossNone)returnc                  C  s   d} t jjdddgdg| d}tjd tjdd	tj}tjj	dd	d
dtj
}t||dd}t|||g|gdd d S )Nnoner-   r   yzinputsoutputsr   r   r	      r	   highsizer   Ztest_sce_noner4   r5   nameonnxhelper	make_noder   randomseedrandr   r   randintint64r,   r   r   noder   labelsscer*   r*   r+   export_softmaxcrossentropy_none^   s    z7SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_nonec                  C  s   d} t jjdddgddg| d}tjd tjd	d
tj}tjj	dd
ddtj
}t||ddd\}}t|||g||gdd d S )Nr0   r-   r   r1   r2   r%   r3   r   r	   r6   r7   r8   Tr   r    Ztest_sce_none_log_probr<   r>   r   rH   r   rI   r)   r%   r*   r*   r+   (export_softmaxcrossentropy_none_log_probv   s,       
z@SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_log_probc                  C  s   d} t jjddddgdg| d}tjd tjd	d
tj}tjj	dd
ddtj
}tjdddddgtjd}t|||dd}t||||g|gdd d S )Nr0   r-   r   r1   wr2   r3   r   r	   r6   r7   r8   ?ffffff?皙?r   r   r   Ztest_sce_none_weightsr<   r?   r@   rA   r   rB   rC   rD   r   r   rE   rF   r   r,   r   r   rH   r   rI   weightsrJ   r*   r*   r+   'export_softmaxcrossentropy_none_weights   s$    z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_weightsc                  C  s   d} t jjddddgddg| d}tjd	 tjd
dtj}tjj	d	dddtj
}tjdddddgtjd}t|||ddd\}}t||||g||gdd d S )Nr0   r-   r   r1   rO   r2   r%   r3   r   r	   r6   r7   r8   rP   rQ   rR   r   Tr   r   r    Ztest_sce_none_weights_log_probr<   rT   r   rH   r   rI   rV   r)   r%   r*   r*   r+   0export_softmaxcrossentropy_none_weights_log_prob   s0        
zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_weights_log_probc                  C  s   d} t jjdddgdg| d}tjd tjdd	tj}tjj	dd	d
dtj
}t||dd}t|||g|gdd d S )Nr   r-   r   r1   r2   r3   r   r	   r6   r7   r8   r;   Ztest_sce_sumr<   r>   rG   r*   r*   r+   export_softmaxcrossentropy_sum   s    z6SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_sumc                  C  s   d} t jjdddgddg| d}tjd tjd	d
tj}tjj	dd
ddtj
}t||ddd\}}t|||g||gdd d S )Nr   r-   r   r1   r2   r%   r3   r   r	   r6   r7   r8   TrL   Ztest_sce_sum_log_probr<   r>   rM   r*   r*   r+   'export_softmaxcrossentropy_sum_log_prob   s,       
z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_sum_log_probc                  C  s~   d} t jjdddgdg| d}tjd tjdd	tj}tjj	dd	d
dtj
}t||}t|||g|gdd d S )Nr   r-   r   r1   r2   r3   r   r	   r6   r7   r8   Ztest_sce_meanr<   r>   rG   r*   r*   r+   export_softmaxcrossentropy_mean
  s    
z7SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_meanc                  C  s   d} t jjdddgddg| d}tjd tjd	d
tj}tjj	dd
ddtj
}t||dd\}}t|||g||gdd d S )Nr   r-   r   r1   r2   r%   r3   r   r	   r6   r7   r8   Tr    Ztest_sce_mean_log_probr<   r>   rM   r*   r*   r+   (export_softmaxcrossentropy_mean_log_prob"  s"    z@SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_log_probc                  C  s   d} t jjdddgdg| d}tjd tjdd	d
tj}tjj	dd	ddtj
}t||}t|||g|gdd d S )Nr   r-   r   r1   r2   r3   r   r	   r6   r   r	   r   r8   Ztest_sce_mean_3dr<   r>   )r   rH   r   r1   rJ   r*   r*   r+   "export_softmaxcrossentropy_mean_3d?  s    
z:SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_3dc                  C  s   d} t jjdddgddg| d}tjd tjd	d
dtj}tjj	dd
ddtj
}t||dd\}}t|||g||gdd d S )Nr   r-   r   r1   r2   r%   r3   r   r	   r6   r   r`   r8   Tr^   Ztest_sce_mean_3d_log_probr<   r>   )r   rH   r   r1   r)   r%   r*   r*   r+   +export_softmaxcrossentropy_mean_3d_log_probW  s"    zCSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_3d_log_probc                  C  s   d} t jjddddgdg| d}tjd tjd	d
tj}tjj	dd
ddtj
}tjdddddgtjd}t|||d}t||||g|gdd d S )Nr   r-   r   r1   rO   r2   r3   r   r	   r6   r7   r8   rP   rQ   rR   r   )r   Ztest_sce_mean_weightr<   rT   rU   r*   r*   r+   'export_softmaxcrossentropy_mean_weightst  s$    z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weightsc                  C  s   d} t jjddddgddg| d}tjd	 tjd
dtj}tjj	d	dddtj
}tjdddddgtjd}t|||dd\}}t||||g||gdd d S )Nr   r-   r   r1   rO   r2   r%   r3   r   r	   r6   r7   r8   rP   rQ   rR   r   T)r   r    Ztest_sce_mean_weight_log_probr<   rT   rY   r*   r*   r+   0export_softmaxcrossentropy_mean_weights_log_prob  s.       
zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_log_probc                  C  s   d} t d}tjjddddgdg| |d}t jd t jd	d
t j	}t jj
dd
ddt j}t d|d< t jdddddgt j	d}t||||d}t||||g|gdd d S )Nr   r   r-   r   r1   rO   r2   r4   r5   r   r   r	   r6   r7   r8   rP   rQ   rR   r   r   r   Ztest_sce_mean_weight_iir<   r   rF   r?   r@   rA   rB   rC   rD   r   r   rE   r   r,   r   r   r   rH   r   rI   rV   rJ   r*   r*   r+   *export_softmaxcrossentropy_mean_weights_ii  s*    
	zBSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_iic                  C  s   d} t d}tjjddddgddg| |d	}t jd t jd
dt j	}t jj
ddddt j}t d|d< t jdddddgt j	d}t||||dd\}}t||||g||gdd d S )Nr   r   r-   r   r1   rO   r2   r%   re   r	   r6   r7   r8   rP   rQ   rR   r   Tr   r   r    Z test_sce_mean_weight_ii_log_probr<   rg   r   r   rH   r   rI   rV   r)   r%   r*   r*   r+   3export_softmaxcrossentropy_mean_weights_ii_log_prob  s6    
	    
zKSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
t j	}t jj
dd
ddt j}t d|d< t|||d}t|||g|gdd d S )Nr   r   r-   r   r1   r2   re   r   r	   r6   r7   r8   r   Ztest_sce_mean_no_weight_iir<   r   rF   r?   r@   rA   rB   rC   rD   r   r   rE   r,   r   r   r   rH   r   rI   rJ   r*   r*   r+   -export_softmaxcrossentropy_mean_no_weights_ii  s(    
	   zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_iic                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
dt j	}t jj
d	dddt j}t d|d	< t|||dd\}}t|||g||gdd d S )Nr   r   r-   r   r1   r2   r%   re   r   r	   r6   r7   r8   Tr   r    Z#test_sce_mean_no_weight_ii_log_probr<   rn   r   r   rH   r   rI   r)   r%   r*   r*   r+   6export_softmaxcrossentropy_mean_no_weights_ii_log_prob  s2    
	   
zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_log_probc                  C  s   d} t d}tjjddddgdg| |d}t jd	 t jd
ddt j	}t jj
d	dddt j}t d|d	 d	< t jdddddgt j	d}t||||d}t||||g|gdd d S )Nr   r   r-   r   r1   rO   r2   re   r   r	   r6   r   r`   r8   皙?333333?333333?皙?      ?r   rf   Ztest_sce_mean_weight_ii_3dr<   rg   rh   r*   r*   r+   -export_softmaxcrossentropy_mean_weights_ii_3d5  s*    
	zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_3dc                  C  s   d} t d}tjjddddgddg| |d	}t jd
 t jdddt j	}t jj
d
dddt j}t d|d
 d
< t jdddddgt j	d}t||||dd\}}t||||g||gdd d S )Nr   r   r-   r   r1   rO   r2   r%   re   r   r	   r6   r   r`   r8   rt   ru   rv   rw   rx   r   Trj   Z#test_sce_mean_weight_ii_3d_log_probr<   rg   rk   r*   r*   r+   6export_softmaxcrossentropy_mean_weights_ii_3d_log_probV  s6    
	    
zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_3d_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
dt j	}t jj
dd
ddt j}t d|d d< t|||d}t|||g|gdd d S )Nr   r   r-   r   r1   r2   re   r   r	   r6   r`   r8   rm   Ztest_sce_mean_no_weight_ii_3dr<   rn   ro   r*   r*   r+   0export_softmaxcrossentropy_mean_no_weights_ii_3dy  s(    
	zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_3dc                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
ddt j	}t jj
d	dddt j}t d|d	 d	< t|||dd\}}t|||g||gdd d S )Nr   r   r-   r   r1   r2   r%   re   r   r	   r6   r`   r8   Trq   Z&test_sce_mean_no_weight_ii_3d_log_probr<   rn   rr   r*   r*   r+   9export_softmaxcrossentropy_mean_no_weights_ii_3d_log_prob  s2    
	   
zQSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_3d_log_probc                  C  s   d} t d}tjjddddgdg| |d}t jd	 t jd
dddt j	}t jj
d	dddt j}t d|d	 d	 d	< t jdddddgt j	d}t||| ||d}t||||g|gdd d S )Nr   r   r-   r   r1   rO   r2   re   r   r	   r6      r	   r   r}   r8   rt   ru   rv   rw   rx   r   )r   r   r   Ztest_sce_mean_weight_ii_4dr<   rg   rh   r*   r*   r+   -export_softmaxcrossentropy_mean_weights_ii_4d  s6    
	    zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_4dc                  C  s   d} t d}tjjddddgddg| |d	}t jd
 t jddddt j	}t jj
d
dddt j}t d|d
 d
 d
< t jdddddgt j	d}t||| ||dd\}}t||||g||gdd d S )Nr   r   r-   r   r1   rO   r2   r%   re   r   r	   r6   r}   r~   r8   rt   ru   rv   rw   rx   r   T)r   r   r   r    Z#test_sce_mean_weight_ii_4d_log_probr<   rg   rk   r*   r*   r+   6export_softmaxcrossentropy_mean_weights_ii_4d_log_prob  s8    
	

zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_4d_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
ddt j	}t jj
dd
ddt j}t d|d d d< t||| |d}t|||g|gdd d S )Nr   r   r-   r   r1   r2   re   r   r	   r6   r}   r~   r8   r   r   Ztest_sce_mean_no_weight_ii_4dr<   rn   ro   r*   r*   r+   0export_softmaxcrossentropy_mean_no_weights_ii_4d  s2    
	   zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_4dc                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
dddt j	}t jj
d	dddt j}t d|d	 d	 d	< t||| |dd\}}t|||g||gdd d S )Nr   r   r-   r   r1   r2   r%   re   r   r	   r6   r}   r~   r8   Tr   r   r    Z&test_sce_mean_no_weight_ii_4d_log_probr<   rn   rr   r*   r*   r+   9export_softmaxcrossentropy_mean_no_weights_ii_4d_log_prob(  s4    
	    
zQSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_4d_log_probc               	   C  s   d} t jjddddgdg| d}d\}}}}}}}tjd	 tj|||||||tj}	tjj	d	|||||||fd
tj
}
tj|tj}t|	|
|| d}t||	|
|g|gdd d S )Nr   r-   r   r1   rO   r2   r3   r	   r6      r   r6   r	      r   r8   rS   Z!test_sce_NCd1d2d3d4d5_mean_weightr<   r>   )r   rH   r!   r"   dim1dim2dim3dim4dim5r   rI   r   rJ   r*   r*   r+   .export_input_shape_is_NCd1d2d3d4d5_mean_weightJ  s2       zFSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_mean_weightc               	   C  s   d} t jjddddgddg| d}d	\}}}}}}}tjd
 tj|||||||tj}	tjj	d
|||||||fdtj
}
tj|tj}t|	|
|| dd\}}t||	|
|g||gdd d S )Nr   r-   r   r1   rO   r2   r%   r3   r   r   r8   TrX   Z*test_sce_NCd1d2d3d4d5_mean_weight_log_probr<   r>   )r   rH   r!   r"   r   r   r   r   r   r   rI   r   r)   r%   r*   r*   r+   7export_input_shape_is_NCd1d2d3d4d5_mean_weight_log_probf  s>           
zOSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_mean_weight_log_probc               	   C  s   d} t jjdddgdg| d}d\}}}}}}}tjd tj|||||||tj}	tjj	d|||||||fd	tj
}
t|	|
| d
}t||	|
g|gdd d S )Nr0   r-   r   r1   r2   r3   r   r   r8   r;   Z$test_sce_NCd1d2d3d4d5_none_no_weightr<   r>   )r   rH   r!   r"   r   r   r   r   r   r   rI   rJ   r*   r*   r+   1export_input_shape_is_NCd1d2d3d4d5_none_no_weight  s0       zISoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_none_no_weightc               	   C  s   d} t jjdddgddg| d}d\}}}}}}}tjd	 tj|||||||tj}	tjj	d	|||||||fd
tj
}
t|	|
| dd\}}t||	|
g||gdd d S )Nr0   r-   r   r1   r2   r%   r3   r   r   r8   TrL   Z-test_sce_NCd1d2d3d4d5_none_no_weight_log_probr<   r>   )r   rH   r!   r"   r   r   r   r   r   r   rI   r)   r%   r*   r*   r+   :export_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob  s:          
zRSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_probc            
      C  s   d} t d}tjjddddgdg| |d}d	\}}}t jd
 t j|||t j	}t jj
d
|||fdt j}d|d
 d
< t j|t j	}t|||| |d}	t||||g|	gdd d S )Nr   r
   r-   r   r1   rO   r2   re   r	   r6   r   r   r8   r   r   r   Z%test_sce_NCd1_mean_weight_negative_iir<   rn   )
r   r   rH   r!   r"   r   r   rI   r   rJ   r*   r*   r+   2export_input_shape_is_NCd1_mean_weight_negative_ii  s8    

    zJSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1_mean_weight_negative_iic                  C  s   d} t d}tjjddddgddg| |d	}d
\}}}t jd t j|||t j	}t jj
d|||fdt j}d|d d< t j|t j	}t|||| |dd\}	}
t||||g|	|
gdd d S )Nr   r
   r-   r   r1   rO   r2   r%   re   r   r   r8   Tr   r   r   r    Z.test_sce_NCd1_mean_weight_negative_ii_log_probr<   rn   )r   r   rH   r!   r"   r   r   rI   r   r)   r%   r*   r*   r+   ;export_input_shape_is_NCd1_mean_weight_negative_ii_log_prob  s:    


	zSSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1_mean_weight_negative_ii_log_probc                  C  s   d} t d}tjjdddgdg| |d}d\}}}}}t jd	 t j|||||t j	}t jj
d	|||||fd
t j}	d|	d	 d	 d	 d	< t||	| |d}
t|||	g|
gdd d S )Nr0   r-   r   r1   r2   re   r	   r6   r   r   r6   r   r8   r   Z,test_sce_NCd1d2d3_none_no_weight_negative_iir<   rn   )r   r   rH   r!   r"   r   r   r   r   rI   rJ   r*   r*   r+   9export_input_shape_is_NCd1d2d3_none_no_weight_negative_ii  s8    
   zQSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_none_no_weight_negative_iic                  C  s   d} t d}tjjdddgddg| |d}d	\}}}}}t jd
 t j|||||t j	}t jj
d
|||||fdt j}	d|	d
 d
 d
 d
< t||	| |dd\}
}t|||	g|
|gdd d S )Nr0   r   r-   r   r1   r2   r%   re   r   r   r8   Tr   Z5test_sce_NCd1d2d3_none_no_weight_negative_ii_log_probr<   rn   )r   r   rH   r!   r"   r   r   r   r   rI   r)   r%   r*   r*   r+   Bexport_input_shape_is_NCd1d2d3_none_no_weight_negative_ii_log_prob  s:    
    
zZSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_none_no_weight_negative_ii_log_probc            	      C  s   d} t d}tjjddddgdg| |d}d	\}}t jd
 t j||t j	}t jj
d
||dt j}d|d
< t j|t j	}t|||| |d}t||||g|gdd d S )Nr   
   r-   r   r1   rO   r2   re   r	   r6   r   r8   r   Z$test_sce_NCd1d2d3_sum_weight_high_iir<   rn   )	r   r   rH   r!   r"   r   rI   r   rJ   r*   r*   r+   1export_input_shape_is_NCd1d2d3_sum_weight_high_ii?  s8    
    zISoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_sum_weight_high_iic            
      C  s   d} t d}tjjddddgddg| |d	}d
\}}t jd t j||t j	}t jj
d||dt j}d|d< t j|t j	}t|||| |dd\}}	t||||g||	gdd d S )Nr   r   r-   r   r1   rO   r2   r%   re   r   r   r8   Tr   Z-test_sce_NCd1d2d3_sum_weight_high_ii_log_probr<   rn   )
r   r   rH   r!   r"   r   rI   r   r)   r%   r*   r*   r+   :export_input_shape_is_NCd1d2d3_sum_weight_high_ii_log_prob^  s:    

	zRSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_sum_weight_high_ii_log_probN)&__name__
__module____qualname__staticmethodrK   rN   rW   rZ   r[   r\   r]   r_   ra   rb   rc   rd   ri   rl   rp   rs   ry   rz   r{   r|   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r*   r*   r*   r+   r-   ]   s    "! "!"'!!#r-   )Nr   NN)
__future__r   Znumpyr   r?   Zonnx.backend.test.case.baser   Zonnx.backend.test.case.noder   r,   r-   r*   r*   r*   r+   <module>   s          
P