U
    h                     @   s   d dl Z d dlmZmZ d dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlmZ e eZG dd deZG d	d
 d
eZdS )    N)OptionalUnion)FusionAttention)Fusion)FunctionProto	NodeProtoTensorProtohelpernumpy_helper)	OnnxModelc                       sv   e Zd ZdZeeed fddZdeeeeeeeeeeee	e
 eedf ddd	Zd
d Zdd Zdd Z  ZS )FusionRotaryAttentionze
    Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
    )modelhidden_size	num_headsc              
      s$   t  j|||ddddddgd d S )NTZSimplifiedLayerNormalization SkipSimplifiedLayerNormalizationZLayerNormalizationSkipLayerNormalizationAdd)Zuse_multi_head_attentionZsearch_op_types)super__init__)selfr   r   r   	__class__ T/tmp/pip-unpacked-wheel-socb9apf/onnxruntime/transformers/fusion_rotary_attention.pyr      s    zFusionRotaryAttention.__init__ N)inputoutputq_rotaryk_rotaryv_matmul	attn_maskadd_qkpast_kpast_v	present_k	present_vscalereturnc                 C   s  | j dkst| jdkrF| j| j  dkrFtd| j d| j   d S | jd}|jd |jd |jd d||||	g}|g}|
r|r||
|g t	j
d|||d}d|_|jt	d| j g |d k	r|jt	d	|g | jd k	r
|jt	d
t| jg | d |S )Nr   z)fuse_rotary_attention: input hidden size z# is not a multiple of num of heads ZMultiHeadAttentionr   )inputsoutputsnamecom.microsoftr   r&   mask_filter_value)r   AssertionErrorr   loggerdebugr   create_node_namer   extendr	   	make_nodedomain	attributeZmake_attributer,   floatincrease_counter)r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   Zmha_node_nameZ
mha_inputsZmha_outputsZmha_noder   r   r   create_mha_node)   sB    
z%FusionRotaryAttention.create_mha_nodec	           1      C   sx  | j |dgdg}	| j |dgdg}
|	d ks8|
d kr<dS |	d |
d  }}| j |dddgdddg}| j |dddgdddg}| j |dddgdddg}| j |dddgdddg}|d ks|d ks|d ks|d krdS |\}}}|\}}}|jd |ks|jd |krdS |d j|jks>|d j|jkrBdS | j |dgdg}| j |dgdg}|d ks~|d krdS |d |d  }}| j |dd	ddgddddg}| j |dd
ddgddddg}| j |dddgdddg}| j |dddgdddg}|d ks4|d ks4|d ks4|d kr8dS |d j|jks|d j|jks|d j|jks|d j|jkrdS | j |dgdg}|d krdS |d }| j |dd	ddgddddg} | j |dd
ddgddddg}!| d ks|!d krdS | d j|jks*|!d j|jkr.dS | j |dgdg}"|"d krPdS |"d }#| j |#dd	ddgddddg}$| j |#dddgdddg}%|$d ks|%d krdS |$d j|jks|%d j|jkrdS |$d }&| d }'|d }(|jd })|&jd |)ks&|'jd |)ks&|(jd |)kr*dS | j |dddgdddg}*| j |ddddgddddg}+|*d k	r||*\}},}-n|+d k	r|+\}}},}-ndS |-jd dkrdS | j |,dd
ddgddddg}.| j |-dd
ddgddddg}/| j |-dgdg}0|.d ks|/d ks|0d kr"dS |.d j|/d jksN|.d j|/d jkrRdS |/d jd |0d jd krtdS dS )NConcat   Fr   	UnsqueezeGatherShape   Mulr   SliceCast>   attention_maskr    T)r   match_parent_pathr   r*   r   )1r   reshape_qkv_2reshape_qkv_1reshape_q_2reshape_k_2reshape_v_2reshape_v_1r!   
root_inputZconcat_qkv_2_pathZconcat_qkv_1_pathZconcat_qkv_2Zconcat_qkv_1Zreshape_qkv_2_path_1Zreshape_qkv_2_path_2Zreshape_qkv_1_path_1Zreshape_qkv_1_path_2_gather_1shape_1gather_2shape_2Zconcat_v_2_pathZconcat_v_1_pathZ
concat_v_2Z
concat_v_1Zreshape_v_2_path_1Zreshape_v_2_path_2Zreshape_v_1_path_1Zreshape_v_1_path_2Zconcat_k_2_pathZ
concat_k_2Zreshape_k_2_path_1Zreshape_k_2_path_2Zconcat_q_2_pathZ
concat_q_2Zreshape_q_2_path_1Zreshape_q_2_path_2Zmul_qZmul_kZmul_vZgather_1_outZattn_mask_path_1Zattn_mask_path_2Z
slice_qk_2Z
slice_qk_1Zslice_qk_2_pathZslice_qk_1_path_1Zslice_qk_1_path_2r   r   r   &check_runtime_shape_paths_for_functiona   s    

 $ 
 
 
 

 
 
 
 
$
 
 
$
0 

 
 
 
 
,z<FusionRotaryAttention.check_runtime_shape_paths_for_functionc                 C   s  | j |dgdg}|d kr dS |d }| j |dddgdddg}| j |dddgdddg}	|d ksp|	d krtdS |\}
}}|	\}
}}|jd |ks|jd |krdS | j |dgdg}|d krdS |d }| j |dddgdddg}| j |dddgdddg}|d ks|d kr dS |d j|jksD|d j|jkrHdS | j |dgdg}|d krjdS |d }| j |dddgdddg}| j |dddgdddg}|d ks|d krdS |d j|jks|d j|jkrdS | j |dgdg}|d krdS |d }| j |dddgdddg}| j |dddgdddg}|d ks`|d krddS |d j|jks|d j|jkrdS dS )	Nr8   r9   Fr   r:   r;   r<   T)r   rB   r   r*   )r   reshape_qkv	reshape_q	reshape_k	reshape_vrI   Zconcat_qkv_pathZ
concat_qkvZreshape_qkv_path_1Zreshape_qkv_path_2rJ   rK   rL   rM   rN   concat_v_pathconcat_vZreshape_v_path_1Zreshape_v_path_2concat_k_pathconcat_kZreshape_k_path_1Zreshape_k_path_2Zconcat_q_pathZconcat_qZreshape_q_path_1Zreshape_q_path_2r   r   r   #check_runtime_shape_paths_for_nodes   sV    	

$
$
$z9FusionRotaryAttention.check_runtime_shape_paths_for_nodesc           B      C   s  |j dkrd S d }| j|dddddgdddddg}| j|ddddgddddg}| j|dddddgdddddg}|d k	r|\}}	}}
}|}nD|d k	r|\}}}}|}n*|d k	r|\}}}}}|}ntd d S d	\}}}d }| j|ddd
dddgddddddg}| j|d
dddgddddg}| j|dddgdddg}| jj|dddd
dddgdddddddgfdddddd
dddd
dddgdddddddddddddgfddddddddd
dddd
dddgddddddddddddddddgfddddddd
dddd
dddgddddddddddddddgfddddd
dddd
dddgddddddddddddgfdd
dddd
dddg	dddddddddg	fdd
ddddd
dddg
ddddddddddg
fdd
dddd
dddg	dddddddddg	fdd
dddd
dddg	dddddddddg	fg	d d\}}}|d k	r|\}}}}}}|}| j|ddgddg}|d krttd d S |d jd }|d jd }|jd }n|d k	r|\}}}}|}|jd }|jd }n||d k	r|\}}}|}|jd }nX|d k	r:t|dkr:|d dd  \}}}}|}|jd }|jd }ntd d S | j|ddddgddddg}d \}}|d k	r|\}}}}ntd! d S d"\}} | j|d
ddgdddg}!| j|d#d
ddgddddg}"| j|ddd$d#dddgdddddddg}#| j|dd$d#dddgddddddg}$| j|dddd$d#dddgddddddddg}%| j|ddd$d#dddgdddddddg}&|!d k	r|!\}}'}(|'jd }n|"d k	r|"\}}}'}(|'jd }n|#d k	r| 	|#d jd } nb|$d k	r| 	|$d jd } nB|%d k	r&|%d jd } n(|&d k	r@|&d jd } ntd% d S d"\})}*d }+| j|ddd
dd&dgddddddg},| j|dd&dddgdddddg}-| j|dd
d&dddgddddddg}.| jj|ddddd
d&dddg	dddddddddg	fddddddd
dddd
d&dddgdddddddddddddddgfdddddddddd
dddd
d&dddgddddddddddddddddddgfdddddddd
dddd
d&dddgddddddddddddddddgfdddddd
dddd
d&dddgddddddddddddddgfddd
dddd
d&dddgdddddddddddgfddd
ddddd
d&dddgddddddddddddgfddd
dddd
d&dddgdddddddddddgfddd
dddd
d&dddgdddddddddddgfg	d d\}}/}|,d k		rt|,\}0}}1}}2}3|,}+| j|1ddgddg}4|4d k	r>td' d S |4d jd })|4d jd }5|1jd }*||5k
s8t
n|-d k		r|-\}}2}}6}3|-}+|2jd }*n|.d k		r|.\}}1}2}}6}3|.}+|1jd })|1jd }*nh|/d k	
r*t|/dk
r*|/d d(d  \}6}3|/d d)d* \}1}2|/}+|1jd })|1jd }*ntd+ d S d }7| j|ddd&dgddddg}8| j|d&dddgddddg}9|8d k	
r|8\}:}};}<|8}7n*|9d k	
r|9\};}}=}<|9}7ntd, d S |<jd |3jd k
r|3jd |jd k
rtd- d S d.}>||krD| |	|
|:|0||||<jd s8td/ d S |	jd }>n|||fkr| ||=|6||<jd sztd/ d S |jd }>|<jd |;jd< |3jd |2jd< |2jd0 |2jd< ||kr|dd  }| |<jd |>|;|2||| |)||*|}?|?d krtd1 d S | j|? | j| j|?j< | j|dd   ||krT| j|d d  n&|d d g}@|D ]}A| |A|@ qf| j| |+|,kr| j|+d d(  n|+|-kr| j|+d  | j|+d  | j|+d  n|+|.kr.| j|+d  | j|+d  | j|+d  | j|+d  n:|+|/krh|+d d |+d d g}@|+D ]}A| |A|@ qT|7|8kr| j|7d d(  n*|7|9kr| j|7d  | j|7d  d2| _d S )3N>   r   r   r   MatMulReshape	Transposer9   r   Z	AllReducez0fuse_rotary_attention: failed to match qkv nodes)r   r   r   r8   ZExpandr:   ZWhereZEqualr;   r<   r>   ZConstantOfShape   r=      )output_name_to_noder?   zDfuse_rotary_attention: failed to match past/present concat in v path	   z-fuse_rotary_attention: failed to match v pathZSoftmaxr   DivNNz/fuse_rotary_attention: failed to match qk nodes)r   r   r@   ZSubz;fuse_rotary_attention: failed to match attention mask nodesRotaryEmbeddingzDfuse_rotary_attention: failed to match past/present concat in k pathz.fuse_rotary_attention: failed to match k nodesz.fuse_rotary_attention: failed to match q nodeszKfuse_rotary_attention: failed to find the same root_input for q, k, v pathsr   z;fuse_rotary_attention: failed to verify runtime shape pathsZ	_output_0zSfuse_rotary_attention: failed to create multi-head attention with rotary embeddingsT)op_typer   rB   r.   r/   Zmatch_parent_paths_allr   r   lenZreshape_add_qkr-   rO   rX   r*   r7   nodes_to_addappendthis_graph_namenode_name_to_graph_namenodes_to_remover1   Z&add_nodes_to_remove_with_nodes_to_keepprune_graph)Br   Znormalize_nodeinput_name_to_nodesr^   Z	qkv_nodesZqkv_nodes_1Zqkv_nodes_2Zqkv_nodes_3rJ   rC   rD   Z
matmul_qkvrP   r#   r%   Zpast_seq_lenZv_nodesZ	v_nodes_1Z	v_nodes_2Z	v_nodes_3Z	v_nodes_4rG   rU   rH   Zmatmul_vrT   Ztranspose_vrS   Zqk_nodesr!   Z	matmul_qkr    Z
add_qk_strZattn_mask_nodes_1Zattn_mask_nodes_2Zattn_mask_nodes_3Zattn_mask_nodes_4Zattn_mask_nodes_5Zattn_mask_nodes_6Zslice_mask_1Zslice_mask_2r"   r$   Zk_nodesZ	k_nodes_1Z	k_nodes_2Z	k_nodes_3Z	k_nodes_4rF   rW   Zrotary_kZmatmul_krV   Zshared_past_seq_lenrR   Zq_nodesZ	q_nodes_1Z	q_nodes_2rE   Zrotary_qZmatmul_qrQ   Zroot_outputZnew_nodeZnodes_to_keepZ	temp_pathr   r   r   fuseF  s&   






"lp






















 &"  %  )















,















zFusionRotaryAttention.fuse)r   r   r   r   r   r   N)__name__
__module____qualname____doc__r   intr   strr   r   r5   r   r7   rO   rX   rq   __classcell__r   r   r   r   r      s>          
8 Ir   c                       s^   e Zd Zed fddZeedddZeddd	Ze	e	e	e	e	d
ddZ
dd Z  ZS )FusionRotaryEmbeddings)r   c                    s*   d| _ t || j | j | j d dg d S )Nrd   z.1r   )	base_namer   r   )r   r   r   r   r   r     s    zFusionRotaryEmbeddings.__init__)rot_emb_nodefunctionc                    s   g g  }}|j D ]X}|jdkr|jg kr|jd |jkr|| t|j|jd }||j|  qg }|D ]6}|jd j}	| j	
d|	_| j	|	 ||	j qrt||D ]>\ }
tt fdd| j	j	jj }|D ]}t| |
 qq|S )NConstantr   c                    s
    | j kS N)r   )entryZextra_outputr   r   <lambda>      z?FusionRotaryEmbeddings.reassign_extra_outputs.<locals>.<lambda>)noderh   r   r   rk   listindexr4   tr   r0   r*   add_initializerzipfiltergraphr   Zreplace_node_input)r   r{   r|   Zextra_constantsextra_outputsZfn_nodeZoutput_indexZextra_initializersZextra_constantZconstant_tensorprotoZextra_initializerZnodes_to_updateZnode_to_updater   r   r   reassign_extra_outputs  s"    

$
z-FusionRotaryEmbeddings.reassign_extra_outputsr   c                    sB  | j | j}| j ddgddg}|d k	r8|\}}ntd d S |jd jd g}tt	fdd| j j j
j}tt	fdd| j j j
j}d	\}	}
t|dkrt|dkr| j |	d kr| j |
d krt|d jd j }t|d jd j }tj|	tjt|j|  d
}| j || j tj|
tjt|j|  d
}| j || j | j|d |d g ||	|
g j}t|dkrtt	fdd| j j j}t|dkst|  |d  tt	 fdd|}t|dksttj!| j|||dd}d|_"| j#| |S )NrZ   rY   r   z.fuse_rotary_embeddings: failed to match MatMulr9   c                    s   | j d  jd kS )Nr   r=   r   r   Zconstantr   r   r   r     r   zOFusionRotaryEmbeddings.create_rotary_embeddings_from_function.<locals>.<lambda>c                    s   | j d  jd kS )Nr   r\   r   r   r   r   r   r     r   	cos_cache	sin_cacher*   Z	data_typeZdimsvalsc                    s   | j  jkS r~   )r*   rh   )fnr   r   r   r     r   c                    s   |  kS r~   r   )Zoutput_name)r   r   r   r   
  r   r(   r)   r*   Zinterleavedr+   )$r   r0   rz   rB   r.   r/   r   r   r   r   r   r   ri   get_initializerr
   to_arrayr4   r   squeezer	   make_tensorr   FLOATshapeflattentolistr   rl   rn   r1   Z	functionsr-   r   r2   r3   rk   )r   r   rotary_emb_node_nameZmatmul_pathZreshape_nodeZmatmul_nodeZrotary_emb_inputscos_cache_nodesin_cache_nodecos_cache_namesin_cache_namer   r   cos_cache_tensorsin_cache_tensorZrotary_emb_outputsfuncrotary_emb_noder   )r   r   r   &create_rotary_embeddings_from_function  sv    





z=FusionRotaryEmbeddings.create_rotary_embeddings_from_function)rI   position_ids	cos_slice	sin_slicer   c                    s  | j | j}tt fdd| j j jj}ttfdd| j j jj}d\}	}
t|dkr|t|dkr|| j |	d kr|| j |
d kr|t	
|d jd j }t	
|d jd j }|jd }|d d d |d f }|d d d |d f }tj|	tjt|j|  d}| j || j tj|
tjt|j|  d}| j || j | j|d |d g tj| j|||	|
g|g|dd	}d
|_|S )Nc                    s   | j d  kS Nr   r   r   )r   r   r   r   %  r   zLFusionRotaryEmbeddings.create_rotary_embeddings_from_nodes.<locals>.<lambda>c                    s   | j d  kS r   r   r   )r   r   r   r   &  r   r   r9   r   r=   r   r   r+   )r   r0   rz   r   r   r   r   ri   r   r
   r   r4   r   r   r   r	   r   r   r   r   r   r   rl   rn   r1   r2   r3   )r   rI   r   r   r   r   r   r   r   r   r   r   r   Z	head_sizer   r   r   r   )r   r   r   #create_rotary_embeddings_from_nodes  sR    





z:FusionRotaryEmbeddings.create_rotary_embeddings_from_nodesc                    s8  | j |jkr|jdkrd S d  |jdkrt|jdksD|jd dkrRtd d S | |  d krrtd d S | j| t	t
 fdd| jjjj}t|dkst| jjjj|d	  nB| j|d
ddddgdd	d	d	d	g}| j|d
ddddddddg	dd	d	d	dd	d	d	d	g	}|d ks2|d kr@td d S | j|d
dddgdd	dd	g}| j|d
dddddddgdd	ddd	d	d	d	g}|d ks|d krtd d S |d j|d jks
|d j|d jks
|d j|d jks
|d j|d jkrtd d S | j|d
dgd	d	g}	|	d krHtd d S d\}
}}| j|d
ddddddddg	ddd	d	d	d	dd	d	g	}| j|d
dddddddgddd	d	d	d	dd	g}| j|d
ddddddgddd	d	dd	d	g}| j|d
dddddgddd	d	dd	g}|d k	r(|}
|
d jd	 }n|d k	rF|}
|
d jd	 }nf|d k	rr|}
|
d jd	 }|
d jd }n:|d k	r|}
|
d jd	 }|
d jd }ntd d S d\}}| j|d
ddddddddg	d	dd	d	d	d	dd	d	g	}| j|d
dddddddgd	dd	d	d	d	dd	g}| j|d
ddddddgd	dd	d	dd	d	g}| j|d
dddddgd	dd	d	dd	g}|d k	r|}|d jd	 }n|d k	r|}|d jd	 }nf|d k	r|}|d jd	 }|d jd }n:|d k	r |}|d jd	 }|d jd }ntd d S |dkr| j|
d d gdg}| j|d d gdg}|d ksr|d ksr|d	 j|d	 jkrtd! d S |d	 jd	 }ng }g }d"\}}|
|kr||ks|
|kr||kr|
d# j|d# jks|
d j|d jkrtd$ d S n|
|kr||ks,|
|kr||kr|
d j|d jkrPtd% d S | j|
d ddgdd	g}| j|
d dddgd	d	d	g}|d ks|d ks| j|d jd	 d ks|d jdkrtd& d S n
td' | |d jd	 ||||jd	   d kr td d S | |g | |d d  | |d d  | |d d  | |d d  | |	d d  | |
 | | | |d d  | |d d  |d k	rt| j|d	 dkr| | |d k	r| |d d  | | j  | j| j j< | j  d(| _d S ))Nr   >   r]      r9   >   pos_idsposposition_idr   pos_idzLfuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding functionz=fuse_rotary_embeddings: failed to create RotaryEmbedding nodec                    s   | j  jd kS r   )r*   r   r   r   r   r   r   q  r   z-FusionRotaryEmbeddings.fuse.<locals>.<lambda>r   r>   r8   ZNegr?   r[   r:   rb   r;   r<   z9fuse_rotary_embeddings: failed to match x2 in rotate_halfr=   z9fuse_rotary_embeddings: failed to match x1 in rotate_halfr_   zCfuse_rotary_embeddings: failed to match common input in rotate_halfz8fuse_rotary_embeddings: failed to match x in rotate_half)Nr   r   ZSqueezera   rg   z>fuse_rotary_embeddings: failed to match sin path in apply_rope)Nr   r   rZ   zGfuse_rotary_embeddings: failed to match position ids path in apply_roperc   re   zdfuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cachezRfuse_rotary_embeddings: failed to match common Add node in sin cache and cos cachezKfuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len pathsz:fuse_rotary_embeddings: failed to match common cache pathsT)rz   rh   ri   r   r.   r/   r   rn   rk   r   r   r   r   Z
value_infor-   removerB   r*   Zfind_graph_inputr   r   Zadd_nodes_to_removeZget_childrenr6   rl   rm   rj   ro   )r   r   rp   r^   Zold_shape_inferZrotate_half_x2_path_1Zrotate_half_x2_path_2Zrotate_half_x1_path_1Zrotate_half_x1_path_2Zx_pathZsin_pathr   r   Z
sin_path_1Z
sin_path_2Z
sin_path_3Z
sin_path_4Zcos_pathr   Z
cos_path_1Z
cos_path_2Z
cos_path_3Z
cos_path_4Zposition_ids_from_sin_pathZposition_ids_from_cos_pathZpast_seq_len_pathZcurr_seq_len_pathr   r   r   rq   R  s   























,






$

zFusionRotaryEmbeddings.fuse)rr   rs   rt   r   r   r   r   r   r   rw   r   rq   rx   r   r   r   r   ry     s   L8ry   )loggingtypingr   r   Zfusion_attentionr   Zfusion_baser   Zonnxr   r   r   r	   r
   Z
onnx_modelr   	getLoggerrr   r.   r   ry   r   r   r   r   <module>   s   
       &