U
    h                     @   sx   d dl mZ d dlmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZmZ d dlmZ eeZG dd	 d	eZdS )
    )	getLogger)TupleUnionN)Fusion)NumpyHelper)	NodeProtoTensorProtohelper)	OnnxModelc                
       s   e Zd ZdZeeeeeed fddZd"eeedddZ	d	d
 Z
d#eeeeeef dddZeeeeeeeeedf dddZeeeeeeeeedf dddZdd Zdd Zdd Zdd Zdd Zedd d!Z  ZS )$FusionAttentionUnetzB
    Fuse Attention subgraph of UNet into one Attention node.
    )modelhidden_size	num_headsis_cross_attentionenable_packed_qkvenable_packed_kvc                    sH   t  ||rdnddg || _|| _|| _|| _|| _d| _d| _d S )NMultiHeadAttention	AttentionZLayerNormalizationT)	super__init__r   r   r   r   r   num_heads_warninghidden_size_warning)selfr   r   r   r   r   r   	__class__ R/tmp/pip-unpacked-wheel-socb9apf/onnxruntime/transformers/fusion_attention_unet.pyr      s    	zFusionAttentionUnet.__init__F)	reshape_q	is_torch2returnc                 C   s   d}|rj| j |d}|r|jdkrt|jdkr| j |jd }t|tjrt	|j
dgkrt|}n:| j |jd }t|tjrt	|j
dgkrt|d }t|tr|dkr|S dS )zDetect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            int: num_heads, or 0 if not found
        r      Concat      )r   Z
get_parentop_typeleninputZget_constant_value
isinstancenpZndarraylistshapeint)r   r   r   r   Zreshape_parentZq_shape_valuer   r   r   get_num_heads*   s    	
z!FusionAttentionUnet.get_num_headsc                 C   s*   | j |jd }|r&t|jd S dS )zDetect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        r#   r   )r   get_initializerr&   r   to_arrayr*   )r   layernorm_nodeZlayernorm_biasr   r   r   get_hidden_sizeF   s    z#FusionAttentionUnet.get_hidden_size)r   r/   r   r   c                 C   s   |  ||}|dkr| j}| jdkrT|| jkrT| jrTtd| j d| d d| _| |}|dkrl| j}| jdkr|| jkr| jrtd| j d| d d| _||fS )aF  Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        r   z--num_heads is z. Detected value is z. Using detected value.Fz--hidden_size is )r,   r   r   loggerwarningr0   r   r   )r   r   r/   r   r   r   r   r   r   get_num_heads_and_hidden_sizeS   s"    
z1FusionAttentionUnet.get_num_heads_and_hidden_sizeN)q_matmulk_matmulv_matmulr   r   r&   outputr   c           $   
   C   sh  | j  }|r^|jd |ks6|jd |ks6|jd |krtd|jd |jd |jd  dS nV|jd |ks|jd |jd ks|jd |krtd|jd |jd |jd  dS |dkr|| dkrtd| d|  dS | j|jd }	| j|jd }
| j|jd }|	r*|
r*|s.dS |	jdkrHtd	 dS t|	}t|
}t|}td
|j	 d|j	 d|j	 d|  |r|j	|j	ks|j	|j	krdS |j	d }|dkr||krt
d| d| dtt|j	dd }| jrN| jd}|}|}|| }t||||||||||||g||d | }| jjddd}| j|d tj|j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d|gdd tjd|d |d g|d g|d  d}| j| j|j< | j||g | j|||g nDtj|||fdd!}d| }| jd"}| j|d# tj||g|d n| jd}| jr@|j	|j	krdS |j	d }|j	d }||kst |j	d }|j	d }|j	d }||kr||kst |}|}|| }t||||||||g||d$ | }| jjdd%d}| j|d tj|j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d$|gdd tjd|d |d g|j!d g|d  d}| j| j|j< | j||g | j||g tj"d|gtj#d&}d| } | j|d' tj| g|d |r| js||d# |d' g}!n
|d g}!n@| js|j!d |j!d |j!d |d' g}!n|j!d |j!d g}!tj|r| jsd"nd|!|g|d}"d(|"_$|"j%t&d)|g |r8| js8d*n d+'| jrHd,n| jrTd-nd.}#| (|# |"S )/  Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        r   RFor self attention, input hidden state for q and k/v shall be same. Got %s, %s, %sNXFor cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %sinput hidden size # is not a multiple of num of heads r    
   Bweights are in fp16. Please run fp16 conversion after optimizationqw= kw= vw= hidden_size=Input hidden size (,) is not same as weight dimension of q,k,v (:). Please provide a correct input hidden size or pass in 0r      MatMul
MatMul_QKVZname_prefix_weightname	data_typedimsvals_outinputsoutputsrL   _reshape_shape   FrL   rM   rN   rO   rawReshape_input_reshape)axisr   Z_qkv_weightr#   	MatMul_KVZdtype	_qkv_biascom.microsoftr   Attention (self attention)MultiHeadAttention ({})self attention with packed qkvcross attention with packed kvcross attention))r   r&   r1   debugr   r-   rM   r   r.   r*   
ValueErrorr+   r(   prodr   create_node_namedstackreshapeadd_initializerr   FLOATr	   	make_nodethis_graph_namenode_name_to_graph_namerL   INT64nodes_to_addextendnodes_to_removestackr   AssertionErrorr7   zerosfloat32domain	attributemake_attributeformatincrease_counter)$r   r4   r5   r6   r   r   r&   r7   is_self_attentionq_weightk_weightv_weightqwkwvw
qw_in_sizeqw_out_sizeattention_node_namecnh
qkv_weightmatmul_node_namematmul_nodereshape_nodeZqkv_weight_dim
kw_in_size
vw_in_sizekw_out_sizevw_out_size	kv_weightqkv_biasqkv_bias_dimattention_inputsattention_nodecounter_namer   r   r   create_attention_nodeu   sj   *0



(
. 






2

z)FusionAttentionUnet.create_attention_node)q_matmul_addk_matmul_addv_matmul_addr   r   r&   r7   r   c           F   
   C   s&  | j  }| j|dd}	| j|dd}
| j|dd}| |}|dkrNdS |\}}| |}|dkrldS |\}}| |}|dkrdS |\}}|r@|	jd |ks|
jd |ks|jd |krtd|	jd |
jd |jd  dS |jd |ks|jd |ks|jd |krtd|jd |jd |jd  dS n|	jd |ksv|
jd |jd ksv|
jd |krtd|	jd |
jd |jd  dS |jd |ks|jd |jd ks|
jd |krtd|jd |jd |jd  dS |dkr*|| dkr*td| d	|  dS | j|	jd
 }| j|
jd
 }| j|jd
 }|rr|rr|svdS |jdkrtd dS t	
|}t	
|}t	
|}td|j d|j d|j d|  |r|j|jks|j|jkrdS |jd }|dkr0||kr0td| d| dtt|jd
d }| jr| jd}|}|}|| } t|||| |||| |||| g||d |  }!| jjddd}"| j|"d tj|!jd |!jd
 g|!d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }$| j|$tjdgdd|| gdd | jjd d!d}%tjd |jd |$g|%d g|%d}&| j| j|&j< | jjd d"d}'tjd |jd |$g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |$g|)d g|)d}*| j| j|*j< | jjd$d%d}+tjd$|&jd |(jd |*jd g|+d g|+d},|,jt d&dg | j| j|,j< |,jd }-| j|-tjdgdd|d |  gdd | jjd d'd}.tjd |,jd |-g|.d g|.d}/| j| j|/j< | jjd(d)d}0tjd(|/jd |#jd g|0d g|0d}1| j| j|1j< |0d }2| j|2tjd*gdd|d| gdd tjd |1jd |2g|d+ g|0d, d}3| j| j|3j< | j!|#|&|(|*|,|/|1|3g | j"|	|
||||g ndS nz| jd}| j#
r,|j|jkrdS |jd }4|jd }5|4|5kst$|jd
 }|jd
 }6|jd
 }7||7kr0|6|7ks4t$|4}|}|6| } t|||| |||| g||d- |  }8| jjdd.d}"| j|"d tj|8jd |8jd
 g|8d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }9| j|9tjdgdd|| gdd | jjd d"d}'tjd |jd |9g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |9g|)d g|)d}*| j| j|*j< | jjd$d/d}:tjd$|(jd |*jd g|:d g|:d};|;jt d&dg | j| j|;j< |;jd }<| j|<tjdgdd|d- |  gdd | jjd d0d}=tjd |;jd |<g|=d g|=d}>| j| j|>j< | jjd(d1d}?tjd(|>jd |#jd g|?d g|?d}@| j| j|@j< |?d }2| j|2tjd*gdd|d-| gdd tjd |@jd |2g|d2 g|?d, d}3| j| j|3j< | j!|#|(|*|;|>|@|3g | j"|
|||g ndS tj%d|gtj&d3}Ad| }B| j|d4 tj|Bg|Ad |
r| j
sxdS |d+ g}Cn| j#
sdS |jd |d2 g}Ctj|
r| j
sd5nd|C|g|d}Dd6|D_'|Djt d7|g |
r| j
sd8n d9(| jrd:n| j#rd;nd<}E| )|E |DS )=r8   rG   r   Nr9   z_For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %sr:   zeFor cross attention, input hidden state for LoRA q and k/v weights shall be different. Got %s, %s, %sr;   r<   r    r=   r>   r?   r@   rA   rB   rC   rD   rE   r   rF   rH   rI   rJ   rK   rP   rQ   rT   r"   FrV   rX   ZReshape_LoRA_QZReshape_LoRA_KZReshape_LoRA_Vr!   ZConcat_LoRA_QKVr[   ZReshape_LoRA_QKVAddZAdd_Weights_QKVrU   Z
_qkv_inputrZ   r#   r\   ZConcat_LoRA_KVZReshape_LoRA_KVZAdd_Weights_KVZ	_kv_inputr]   r^   r   r_   r   r`   ra   rb   rc   rd   )*r   r   match_parentmatch_lora_pathr&   r1   re   r-   rM   r   r.   r*   rf   r+   r(   rg   r   rh   ri   rj   rk   r   rl   r	   rm   rn   ro   rL   rp   r7   ry   rr   rz   rq   rs   r   ru   rv   rw   rx   r{   r|   )Fr   r   r   r   r   r   r&   r7   r}   r4   r5   r6   Zq_lora_nodesZq_lora_last_nodeZq_lora_matmul_1Zk_lora_nodesZk_lora_last_nodeZk_lora_matmul_1Zv_lora_nodesZv_lora_last_nodeZv_lora_matmul_1r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zlora_weight_shape_tensor_nameZq_lora_reshape_node_nameZq_lora_reshape_nodeZk_lora_reshape_node_nameZk_lora_reshape_nodeZv_lora_reshape_node_nameZv_lora_reshape_nodeZqkv_lora_concat_node_nameZqkv_lora_concat_nodeZ'reshaped_lora_weights_shape_tensor_nameZqkv_lora_reshaped_node_nameZqkv_lora_reshaped_nodeZadd_weights_node_nameZadd_weights_nodeZshape_tensor_namer   r   r   r   r   r   Z kv_lora_weight_shape_tensor_nameZkv_lora_concat_node_nameZkv_lora_concat_nodeZ*reshaped_kv_lora_weights_shape_tensor_nameZkv_lora_reshaped_node_nameZkv_lora_reshaped_nodeZadd_kv_weights_node_nameZadd_kv_weights_noder   r   r   r   r   r   r   r   create_attention_node_loras  s   


*6	



(
. 


	






2

	

z.FusionAttentionUnet.create_attention_node_lorac              	   C   s  | j |dd}|d kr.| js.| j |dd}|d kr:d S |jd }|| }d }|D ]}|jdkrT|} qlqT|d krxd S | ||p| ||}	|	d k	r|	\}
}}}}}}|}| |||
\}}|dkrt	d d S | j
||||||jd |jd d}|d krd S n| ||p$| ||}	|	d kr4d S |	\}
}}}}}}|}| |||
\}}|dkrtt	d d S | j||||||jd |jd d}|d krd S | |||
\}}|dkrt	d d S | j| | j| j|j< | j||g d| _d S )Nr   r   rX   z*fuse_attention: failed to detect num_heads)r&   r7   T)r   r   r   r7   r$   match_qkv_torch1match_qkv_torch2r3   r1   re   r   match_qkv_torch1_loramatch_qkv_torch2_lorar   rq   appendrn   ro   rL   rs   rr   Zprune_graph)r   Znormalize_nodeZinput_name_to_nodesZoutput_name_to_nodeZnode_before_layernorm
root_inputZchildren_nodesskip_addnodeZ	match_qkvr   reshape_qkvtranspose_qkvr   matmul_qmatmul_kmatmul_vZattention_last_nodeZq_num_headsZq_hidden_sizeZnew_nodematmul_add_qmatmul_add_kmatmul_add_vr   r   r   fuseI  s    



	
 


	


zFusionAttentionUnet.fusec              
   C   s  |j d |krdnd}| j|ddddddg|dddddg}|dkrJdS |\}}}}}}| j|ddddgddddg}	|	dkrtd dS |	\}}}}
| j|d	d
dgdddg}|dk	r|\}}}nF| j|d	dd
dgddddg}|dk	r|\}}}}ntd dS | j|ddddgddddg}|dkrJtd dS |\}}}}| j|dddddgdddddg}|dkrtd dS |\}}}}}d||||||
fS )z.Match Q, K and V paths exported by PyTorch 1.*r   r    r   rG   rX   	TransposeN&fuse_attention: failed to match v pathSoftmaxMul'fuse_attention: failed to match qk path&fuse_attention: failed to match q path&fuse_attention: failed to match k pathFr&   r   match_parent_pathr1   re   )r   r   r   another_input	qkv_nodes_r   r   
matmul_qkvv_nodesr   qk_nodes_softmax_qk_mul_qk	matmul_qk	_add_zeroq_nodes_transpose_qr   r   k_nodesr   r   r   r   r     sJ     
 

 

  

z$FusionAttentionUnet.match_qkv_torch1c                 C   s  |j d |krdnd}| j|dddddg|ddddg}|dkrFdS |\}}}}}| j|dddgdddg}	|	dkrtd dS |	\}}}
| j|d	dgddg}|dk	r|\}}ntd
 dS | j|ddddgddddg}|dkrtd dS |\}}}}| j|ddddgddddg}|dkrBtd dS |\}}}}| j|ddddddddgddddddddg}|dks|d |krtd dS d||||||
fS )z.Match Q, K and V paths exported by PyTorch 2.*r   r    r   rG   rX   r   Nr   r   r   r   r   r   SqrtDivCastSliceShapez*fuse_attention: failed to match mul_q pathTr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   mul_qr   r   r   r   _mul_kr   mul_q_nodesr   r   r   r     sL    



 
 


z$FusionAttentionUnet.match_qkv_torch2c                 C   s  |j d |krdnd}| j|dddddddg|ddddddg}|dkrNdS |\}}}}}}}| j|ddddgddddg}	|	dkrtd dS |	\}}}}
| j|d	d
dgdddg}|dk	r|\}}}nF| j|d	dd
dgddddg}|dk	r
|\}}}}ntd dS | j|ddddgddddg}|dkrPtd dS |\}}}}| j|dddddgdddddg}|dkrtd dS |\}}}}}d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*r   r    r   rG   rX   r   N+fuse_attention: failed to match LoRA v pathr   r   ,fuse_attention: failed to match LoRA qk path+fuse_attention: failed to match LoRA q path+fuse_attention: failed to match LoRA k pathFr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r     sJ     
 

 

  

z)FusionAttentionUnet.match_qkv_torch1_lorac                 C   s  |j d |krdnd}| j|ddddddg|dddddg}|dkrJdS |\}}}}}}| j|dddgdddg}	|	dkrtd dS |	\}}}
| j|d	dgddg}|dk	r|\}}ntd
 dS | j|ddddgddddg}|dkrtd dS |\}}}}| j|ddddgddddg}|dkrJtd dS |\}}}}| j|ddddddddgddddddddg}|dks|d |krtd dS d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*r   r    r   rG   rX   r   Nr   r   r   r   r   r   r   r   r   r   r   r   z/fuse_attention: failed to match LoRA mul_q pathTr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   5  sL    



 

 


z)FusionAttentionUnet.match_qkv_torch2_lora)add_nodec                 C   s   | j |ddgddg}|d k	r0|\}}||fS | j |dddgdddg}|d k	rf|\}}}||fS | j |ddddgddddg}|d k	r|\}}}}||fS d S )NrG   r    r   r   )r   r   )r   r   Z
lora_nodesZlora_matmul_2_nodeZlora_matmul_1_nodeZlora_mul_noder   r   r   r   r   h  s2    



z#FusionAttentionUnet.match_lora_path)F)F)__name__
__module____qualname____doc__r
   r+   boolr   r   r,   r0   r   r3   strr   r   r   r   r   r   r   r   r   __classcell__r   r   r   r   r      s^      
$
  
   YW1405r   )loggingr   typingr   r   Znumpyr(   Zfusion_baser   Zfusion_utilsr   Zonnxr   r   r	   Z
onnx_modelr
   r   r1   r   r   r   r   r   <module>   s   