U
    hy                     @   sT  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlZd dl	m
Z
 dddddZdd	d
ddZdd Zd?ddZeeedddZeeeedddZeeeedddZeedddZeedddZeeeeeeeeeeedd d!Zeeeeeeeeeed"
d#d$Zeeeeeeeeeed%
d&d'Zd@eeeeeeeeeeeed*d+d,ZdAeeeeeeeeeeeeed-d.d/ZdBeeeeeeeeeeed0d1d2ZdCeeeeeeeeeeed0d3d4Zeeeeeeeeeed5
d6d7Zd8d9 ZdDd:d;Z d<d= Z!e"d>krPd dl#Z#z
e!  W n$ e$k
rN   e#j%e&   Y nX dS )E    Nmeasure_memoryzrunwayml/stable-diffusion-v1-5zstabilityai/stable-diffusion-2z stabilityai/stable-diffusion-2-1z+stabilityai/stable-diffusion-xl-refiner-1.0)1.5z2.02.1zxl-1.0CUDAExecutionProviderROCMExecutionProviderZMIGraphXExecutionProviderZTensorrtExecutionProvider)cudarocmZmigraphxtensorrtc               
   C   s$   ddddddddd	d
g
} d}| |fS )Nz.a photo of an astronaut riding a horse on marsz@cute grey cat with blue eyes, wearing a bowtie, acrylic paintingzia cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital paintingzdan illustration of a house with large barn with many cute flower pots and beautiful blue sky sceneryzgone apple sitting on a table, still life, reflective, full color photograph, centered, close-up productzWbackground texture of stones, masterpiece, artistic, stunning photo, award winner photozSnew international organic style house, tropical surroundings, architecture, 8k, hdrznbeautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstationzcblue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realisticzldelicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8kz*bad composition, ugly, abnormal, malformed )promptsnegative_promptr   r   ^/tmp/pip-unpacked-wheel-socb9apf/onnxruntime/transformers/models/stable_diffusion/benchmark.pyexample_prompts#   s    r   c                 C   s   t d|| |dS )NT)Zis_gpufuncmonitor_typestart_memoryr   )r   r   r   r   r   r   measure_gpu_memory6   s    r   )
model_name	directorydisable_safety_checkerc           	      C   s   ddl m}m} dd l}|d k	rJtj|s0t| }|j	|||d}n|j	| d|dd}|
|jj|_|jdd |rd |_d |_|S )Nr   )DDIMSchedulerOnnxStableDiffusionPipeline)providerZsess_optionsZonnxT)revisionr   Zuse_auth_tokendisable)	diffusersr   r   onnxruntimeospathexistsAssertionErrorZSessionOptionsfrom_pretrainedfrom_config	schedulerconfigset_progress_bar_configsafety_checkerfeature_extractor)	r   r   r   r   r   r   r   Zsession_optionspiper   r   r   get_ort_pipeline:   s,    r+   )r   r   enable_torch_compileuse_xformersc           	      C   s   ddl m}m} ddlm}m} |j| |dd}|jj|d |rN|	  |rt
|j|_t
|j|_t
|j|_td ||jj|_|jdd	 |rd |_d |_|S )
Nr   )r   StableDiffusionPipeline)channels_lastfloat16torch_dtyper   )Zmemory_formatz)Torch compiled unet, vae and text_encoderTr   )r   r   r.   torchr/   r0   r#   toZunetZ*enable_xformers_memory_efficient_attentioncompileZvaeZtext_encoderprintr$   r%   r&   r'   r(   r)   )	r   r   r,   r-   r   r.   r/   r0   r*   r   r   r   get_torch_pipelineX   s"    r7   )enginer   
batch_sizer   c                 C   s6   | dd dd}|  d| d| |r0dnd S )	N/zstable-diffusion-sd__b Z_safe)splitreplace)r8   r   r9   r   Zshort_model_namer   r   r   get_image_filename_prefixs   s    rB   )r9   image_filename_prefixc
                    sN  ddl m}
 t|
stt \}} fdd}t|	||}t|	||}|  g }t|D ]\}}||krx qt|D ]}t }|g  |g  ddj	}t }|| }|
| td|dd	 t|D ]*\}}|| d
| d
| d
| d qqqbddlm} d| ||t|t| t|||dS )Nr   )r   c                      s   d d d S Nwarm up)num_inference_stepsZnum_images_per_promptr   r   r9   heightr*   stepswidthr   r   warmup   s    z run_ort_pipeline.<locals>.warmup      @)rF   r   guidance_scaleInference took .3f secondsr=   .jpg__version__r   r8   versionrH   rJ   rI   r9   batch_countnum_promptsaverage_latencymedian_latencyfirst_run_memory_MBsecond_run_memory_MB)r   r   
isinstancer"   r   r   	enumeraterangetimeimagesappendr6   saver   rS   sumlen
statisticsmedian)r*   r9   rC   rH   rJ   rI   rW   rV   r   memory_monitor_typer   r   r   rK   first_run_memorysecond_run_memorylatency_listipromptjinference_startr`   inference_endlatencykimageort_versionr   rG   r   run_ort_pipelinex   sR    

(rt   c
                    sJ  t  \}
} fdd}t|	||}t|	||}|  td g }t|
D ]\}}||krh qtj  t|D ]}t }|g  d|g  d dj	}tj  t }|| }|
| td|dd t|D ]*\}}|| d	| d	| d	| d
 qqzqRdtj ||t|t| t|||dS )Nc                      s   d d d S rD   r   r   rG   r   r   rK      s    z"run_torch_pipeline.<locals>.warmupFrL   )rl   rH   rJ   rF   rM   r   	generatorrN   rO   rP   r=   rQ   r3   rT   )r   r   r3   set_grad_enabledr]   r   Zsynchronizer^   r_   r`   ra   r6   rb   rS   rc   rd   re   rf   )r*   r9   rC   rH   rJ   rI   rW   rV   r   rg   r   r   rK   rh   ri   rj   rk   rl   rm   rn   r`   ro   rp   rq   rr   r   rG   r   run_torch_pipeline   sT    





(rw   )r   r   r   r9   r   rH   rJ   rI   rW   rV   tuningc                 C   s   |}|r|dkr|dddf}t   }t| |||}t   }td||  d td| ||}t||||||||	|
|
}|| ||dd|d	d
 |S )N)r   r      )Ztunable_op_enableZtunable_op_tuning_enableModel loading took rP   ZortExecutionProviderr?   Fr   r   r   r   enable_cuda_graph)r_   r+   r6   rB   rt   updaterA   )r   r   r   r9   r   rH   rJ   rI   rW   rV   r   rg   rx   Zprovider_and_options
load_startr*   load_endrC   resultr   r   r   run_ort   s:    
	r   )
rU   r   r9   r   rH   rJ   rI   rW   rV   r}   c           #         s  |dkst ddlm} ddlm} ddlm} || }| }|j|dd}|j||| ||d	| j
d	tjd
 fdd}t|
||	}t|
||	}|  td| |}g }t \}}t|D ]\}}||kr qt|D ]}t }|g  |g  dj}t }|| }|| td|dd t|D ],\} }!|!| d| d| d|  d qNqqddlm}" |d|"|ddj ||t|t| t|||||dS )Nr   r   r   PipelineInfo)&OnnxruntimeCudaStableDiffusionPipeliner%   Z	subfolder)r%   requires_safety_checkerr}   pipeline_infor   r1   c                      s   dg  d d S )NrE   )image_heightimage_widthrF   r   r   rG   r   r   rK   U  s    z"export_and_run_ort.<locals>.warmupZort_cudar   rF   rN   rO   rP   r=   rQ   rR   r   r{   r?   r   r8   rU   r   r   rH   rJ   rI   r9   rV   rW   rX   rY   rZ   r[   r   r}   )r"   r   r   diffusion_modelsr   Zonnxruntime_cuda_txt2imgr   namer#   set_cached_folderr4   r3   r0   r   rB   r   r]   r^   r_   r`   ra   r6   rb   r   rS   rA   
engine_dirrc   rd   re   rf   )#rU   r   r9   r   rH   rJ   rI   rW   rV   r   rg   r}   r   r   r   r   r   r%   rK   rh   ri   rC   rj   r   r   rk   rl   rm   rn   r`   ro   rp   rq   rr   rs   r   rG   r   export_and_run_ort0  st    	


*
r   )
rU   r9   r   rH   rJ   rI   rW   rV   max_batch_sizer}   c           $         s  ddl m} ddlm} ddlm} || }| } |
ks@t|j|dd}|j|dt	j
|| |||
d||d	j|dd
 d fdd}t|	||}t|	||}|  td| |}g }t \}}t|D ]\}}||kr qt|D ]}t }|g  |g  dj}t }|| }|| td|dd t|D ],\} }!|!| d| d| d|  d qVqqddlm}" ddlm}# |d|#d|" dj|| ||t|t| t|||||dS )Nr   r   r   )*OnnxruntimeTensorRTStableDiffusionPipeliner%   r   Zfp16   )
r   r2   r%   r   r   r   r   
onnx_opsetr}   r   )r   r   c                      s   dg  dg  d d S )NrE   negativer   r   r   r9   r*   rI   r   r   rK     s    zrun_ort_trt.<locals>.warmuport_trtr   rN   rO   rP   r=   rQ   rR   r   	tensorrt()r   )r   r   r   r   Zonnxruntime_tensorrt_txt2imgr   r   r"   r#   r3   r0   r   r4   r   rB   r   r]   r^   r_   r`   ra   r6   rb   r
   rS   r   r   rc   rd   re   rf   )$rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r}   r   r   r   r   r   r%   rK   rh   ri   rC   rj   r   r   rk   rl   rm   rn   r`   ro   rp   rq   rr   trt_versionrs   r   r   r   run_ort_trt  s    


*
r   FT)work_dirrU   r9   r   rH   rJ   rI   rW   rV   r   nvtx_profileuse_cuda_graphc           .         s\  t d ddlm} |   |ks&tddlm} ||}| }ddlm}m	} ddl
m} |j}|| ||\}}}}}||d|d d|||||d	
jj|||d
 ddddtj d    fdd}t|
||	}t|
||	}|  td| |}g }t \} }!t| D ]\}"}#|"|kr< qt|D ]}$t }%j|#g  |!g  dddd\}&}'|&}&t }(|(|% })||) t d|)dd|'dd t|&D ],\}*}+|+| d|" d|$ d|* d qqDq$  ddlm}, ddl m}- |! d|-d|, d| ||t"|t#| t$%|||||dS )Nzd[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)r   init_trt_pluginsr   
EngineTypeget_engine_pathsTxt2ImgPipelineDDIMF	r%   
output_dirhf_tokenverboser   r   r   framework_model_direngine_typer   Topt_image_heightopt_image_widthopt_batch_sizeZforce_engine_rebuildstatic_batchZstatic_image_shapeZmax_workspace_sizeZ	device_idc                      s&   j dg  dg  dd d S NrE   r   T)denoising_stepsrK   runr   r9   rH   pipelinerI   rJ   r   r   rK   5  s         z"run_ort_trt_static.<locals>.warmupr   rL   {   r   guidanceseedrK   End2End took rO    seconds. Inference latency: .1f msr=   rQ   rR   r   r   r   r   )&r6   trt_utilitiesr   r"   r   r   
short_nameengine_builderr   r   pipeline_txt2imgr   ORT_TRTbackendbuild_enginesr3   r   current_deviceload_resourcesr   rB   r   r]   r^   r_   r   to_pil_imagera   rb   teardownr
   rS   r   r   rc   rd   re   rf   ).r   rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   r   r   r   r   r   r   r   r   onnx_dirr   r   r   r=   rK   rh   ri   rC   rj   r   r   rk   rl   rm   rn   r`   pipeline_timero   rp   rq   rr   r   rs   r   r   r   run_ort_trt_static  s    




.
r   )r   rU   r   r9   r   rH   rJ   rI   rW   rV   r   r   r   c           1         sr  t d ddlm} ddlm} |   |ks2tddlm} ||}ddlm	}m
} ddlm} |j}|| ||\}}}}}||d|d d	||d
|d	jj|||d d	d	d	d
d
d	d	d	|d d tj j }||\}}j|    fdd}t|||
} t|||
}!|  td| |}"g }#t \}$}%t|$D ]\}&}'|&|krr q,t|	D ]}(t })j|'g  |%g  ddd
d\}*}+|*}*t },|,|) }-|#|- t d|-dd|+dd t|*D ],\}.}/|/|" d|& d|( d|. d qqzqZ  dd l }0d|0j!d |	|t"|#t#|# t$%|#| |!|dS )N][I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)r   cudartr   r   r   r   r   FT)r%   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zforce_exportZforce_optimizeZforce_buildr   Zstatic_shapeZenable_refitZenable_previewZenable_all_tacticstiming_cacheZonnx_refit_dirc                      s&   j dg  dg  dd d S r   r   r   r   r   r   rK     s         z#run_tensorrt_static.<locals>.warmuptrtrL   r   r   r   rO   r   r   r   r=   rQ   r
   default)r8   rU   r   rH   rJ   rI   r9   rV   rW   rX   rY   rZ   r[   r}   )&r6   r   r   r   r   r"   r   r   r   r   r   r   r   TRTr   load_enginesmaxmax_device_memory
cudaMallocactivate_enginesr   r   rB   r   r]   r^   r_   r   r   ra   rb   r   r
   rS   rc   rd   re   rf   )1r   rU   r   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r=   shared_device_memoryrK   rh   ri   rC   rj   r   r   rk   rl   rm   rn   r`   r   ro   rp   rq   rr   r   r   r   r   run_tensorrt_staticz  s      




.r   )r   rU   r9   r   rH   rJ   rI   rW   rV   r   r   c           ,         s  t d dd l}ddlm} ddlm} d dksHd dkr^td d d|  ksptdd	lm	} dd
l
m m  	f	dd}ddlm} ddlm} ||}|||||dd}|||tj j }||\}}j| j|   d#fdd	

fdd}t|
||	}t|
||	}|  | }td||}g }t \} }!t| D ]\}"}#|"|kr qt|D ]}$t }%	r|  
|#g |!g ddd\}&}'	r|   |&}&t }(|(|% })|!|) t d|)dd|'dd t|&D ],\}*}+|+"| d|" d|$ d|* d q\qΐq#  #  |d |j$d!||t%|t&| t'(|||d"S )$Nr   r   r   r      CImage height and width have to be divisible by 8 but specified as:  and .r   r   c           	         sj    j }||\}}}}}| |d|d d||d
}|jj|||ddddddddd|d d |S )Nr   Fr   r   Tr   )r   r   r   )	pipeline_classr   r   r   r   r   r   r   r   	r   r9   r   rH   r   r   r   rJ   r   r   r   init_pipeline3  sL      z-run_tensorrt_static_xl.<locals>.init_pipelineImg2ImgXLPipelineTxt2ImgXLPipelineTZ
is_refinerFc                    sL    j | |d||dd	\}}j | ||d||d	\}}||| fS Ng      @Zlatent)r   r   rK   r   return_type)r   r   rK   r   r   rl   r   r   rK   r`   Z	time_baseZtime_refiner	demo_basedemo_refinerr   r   rI   r   r   run_sd_xl_inferencep  s.    

z3run_tensorrt_static_xl.<locals>.run_sd_xl_inferencec                      s   dg  dg  dd d S NrE   r   T)rK   r   r   r9   r   r   r   rK     s    z&run_tensorrt_static_xl.<locals>.warmupr   r   r   rK   r   rO   r   r   r   r=   .pngr
   r   r   r8   rU   r   rH   rJ   rI   r9   rV   rW   rX   rY   rZ   r[   r}   )NF))r6   r
   r   r   r   r   
ValueErrorr"   r   r   r   r   r   pipeline_img2img_xlr   pipeline_txt2img_xlr   r   r   r   r   r   r   r   r   rB   r   r]   r^   r_   cudaProfilerStartcudaProfilerStopr   ra   rb   r   rS   rc   rd   re   rf   ),r   rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   r   r   r   r   r   r   r   base_pipeline_inforefiner_pipeline_infor   r=   r   rK   rh   ri   r   rC   rj   r   r   rk   rl   rm   rn   r`   r   ro   rp   rq   rr   r   r   r9   r   r   r   rH   r   r   r   r   r   rI   r   rJ   r   r   run_tensorrt_static_xl  s    *



   

.r   c           *         s  ddl m} d dks,d dkrBtd d dksNtddlm m  	f	dd	}dd
lm} ddl	m
} ddlm} ||}|||||dd}|||  d$fdd	

fdd}t|
||	}t|
||	}|  | }td||}g }t \}}t|D ]\}}||krh qBt|D ]}t } 	r|  
|g |g ddd\}!}"	r|  |!}!t }#|#|  }$||$ td|$dd|"dd t|!D ]:\}%}&| d| d| d|% d}'|&|' td|' qqpqP    ddlm}( ddlm}) |d |)d!|( d"||t|t | t!"|||d#S )%Nr   r   r   r   r   r   r   c           	         sf    j }||\}}}}}| |d|d d||d
}|jj|||dddddtj d |S )Nr   Fr   r   Tr   r   )r   r   r   r3   r   r   )	r   r   r   r   r   r   r   r=   r   r   r   r   r     sB      z%run_ort_trt_xl.<locals>.init_pipeliner   r   r   Tr   Fc                    sL    j | |d||dd	\}}j | ||d||d	\}}||| fS r   r   r   r   r   r   r     s.    

z+run_ort_trt_xl.<locals>.run_sd_xl_inferencec                      s   dg  dg  dd d S r   r   r   r   r   r   rK   /  s    zrun_ort_trt_xl.<locals>.warmupr   r   r   r   rO   r   r   r   r=   r   zImage saved torR   r   r
   r   r   )NF)#r   r   r   r"   r   r   r   r   r   r   r   r   r   r   r   r   rB   r   r]   r^   r_   r   r   r   ra   r6   rb   r   r
   rS   r   rc   rd   re   rf   )*r   rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   r   r   r   r   r   r   r   rK   rh   ri   r   rC   rj   r   r   rk   rl   rm   rn   r`   r   ro   rp   rq   rr   filenamer   rs   r   r   r   run_ort_trt_xl  s    %



   



r  )
r   r9   r   r,   r-   rH   rJ   rI   rW   rV   c                 C   s   dt jj_dt jj_t d t }t| |||}t }td||  d t	d| ||}|st 
   t||||||||	|
|
}W 5 Q R X nt||||||||	|
|
}|| d |rdn
|rdnd|dd	 |S )
NTFrz   rP   r3   r5   Zxformersr   r|   )r3   backendsZcudnnZenabledZ	benchmarkrv   r_   r7   r6   rB   Zinference_moderw   r~   )r   r9   r   r,   r-   rH   rJ   rI   rW   rV   r   rg   r   r*   r   rC   r   r   r   r   	run_torchq  sV    



	r  c                  C   s  t  } | jdddtddddgdd | jd	d
dtdtt dd | jddddd | jdddttt ddd | jdddtd dd | jdddtddd | jdddd d! | jdd" | jd#ddd$d! | jdd% | jd&ddd'd! | jdd( | jd)d*t	d+d+d,d-d.d/d0d1d2gd3d4 | jd5dt	d6d7d | jd8dt	d6d9d | jd:d;dt	d<d=d | jd>d?dt	d+d@d | jdAdBdt	t
d+dCdDdEd | jdFdGdt	t
d+d1d.dHd | jdIdJdddKd! | jddL |  }|S )MNz-ez--engineFr   r3   r
   z-Engines to benchmark. Default is onnxruntime.)requiredtyper   choiceshelpz-rz
--providerr   z8Provider to benchmark. Default is CUDAExecutionProvider.z-tz--tuning
store_truezsEnable TunableOp and tuning. This will incur longer warmup latency, and is mandatory for some operators of ROCm EP.)actionr  z-vz	--versionr   z>Stable diffusion version like 1.5, 2.0 or 2.1. Default is 1.5.)r  r  r  r   r  z-pz
--pipelinez[Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.)r  r  r   r  z-wz
--work_dirr   z?Root directory to save exported onnx models, built engines etc.z--enable_safety_checkerzEnable safety checker)r  r	  r  )enable_safety_checkerz--enable_torch_compilez#Enable compile unet for PyTorch 2.0)r,   z--use_xformerszUse xformers for PyTorch)r-   z-bz--batch_sizery            r   
          z)Number of images per batch. Default is 1.)r  r   r  r  z--heighti   z$Output image height. Default is 512.z--widthz#Output image width. Default is 512.z-sz--steps2   zNumber of steps. Default is 50.z-nz--num_promptsz Number of prompts. Default is 1.z-cz--batch_count      z(Number of batches to test. Default is 5.z-mz--max_trt_batch_sizezdMaximum batch size for TensorRT. Change the value may trigger TensorRT engine rebuild. Default is 4.z-gz--enable_cuda_graphz/Enable Cuda Graph. Requires onnxruntime >= 1.16)r}   )argparseArgumentParseradd_argumentstrlist	PROVIDERSkeys	SD_MODELSset_defaultsintr^   
parse_args)parserargsr   r   r   parse_arguments  s   




					

r!  c                    sL   dd l }|t }| D ]( | r<t fdddD rt j qd S )Nr   c                 3   s   | ]}| j kV  qd S )N)r    ).0xlibr   r   	<genexpr>X  s     z)print_loaded_libraries.<locals>.<genexpr>)ZlibcuZlibnvr
   )psutilProcessr   getpidZmemory_mapsanyr6   r    )Zcuda_related_onlyr'  pr   r$  r   print_loaded_librariesS  s
    r,  c                  C   sd  t  } t|  | jdkr| jdkr,dtjd< ddlm} ddlm} |	||	dkrbdtjd	< | j
r| jdkr| jd
kr| jd kstd|	||	dk rtdtjdd | jdkrdnd}t|d }td| t| j }t| j }| jdkr| jdkrd| jkrVtd t| j| j| jd| j| j| j| j| j||| jd| j
d}n| jrtd| j
rpdnd t| j| j| j | j| j| j| j| j||| j| j
d}nDtd t | j| j| j| j | j| j| j| j| j||| jd| j
d}n| jdkrl|dkrl| jd krltd| j
r,d nd! t!| j|| j| j | j| j| j| j| j||| j
d"}nr| jdkr| jrtj"#| jst$d#td$| d%| j  t%|| j|| j| j | j| j| j| j| j||| jd&}n| jdkr@d| jkr@td' t&| j| j| jd| j| j| j| j| j||| jd| j
d}n| jdkrtd( t'| j| j|| jd| j| j| j| j| j||| jd| j
d)}nNtd*| j( d+| j) d, t*|| j| j | j(| j)| j| j| j| j| j||d-}t| t+d.d/d0d1L}d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAdBg}	t,j-||	dC}
|
.  |
/| W 5 Q R X | jdDkr`t0| jd
k d S )ENr   )r   1ZORT_DISABLE_TRT_FLASH_ATTENTIONr   )rU   rR   z1.16.0Z!ORT_ENABLE_FUSED_CAUSAL_ATTENTION)r   r
   z:The stable diffusion pipeline does not support CUDA graph.z1.16z.CUDA graph requires ONNX Runtime 1.16 or laterz%(funcName)20s: %(message)s)fmtr	   r   z&GPU memory used before loading models:r
   ZxlzNTesting Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.TF)r   rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   z;Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.zstatic input shapezdynamic batch size)rU   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r}   zLTesting Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.r   z[Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.ZstaticZdynamic)rU   r   r9   r   rH   rJ   rI   rW   rV   r   rg   r}   z?--pipeline should be specified for the directory of ONNX modelsz/Testing diffusers StableDiffusionPipeline with z provider and tuning=)r   r   r   r9   r   rH   rJ   rI   rW   rV   r   rg   rx   zGTesting Txt2ImgXLPipeline with static input shape. Backend is TensorRT.zETesting Txt2ImgPipeline with static input shape. Backend is TensorRT.)r   rU   r   r9   r   rH   rJ   rI   rW   rV   r   rg   r   r   r   zNTesting Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile=z, xformers=r   )r   r9   r   r,   r-   rH   rJ   rI   rW   rV   r   rg   zbenchmark_result.csvar?   )modenewliner   r   r8   rU   r   r   rH   rJ   rI   r9   rV   rW   rX   rY   rZ   r[   r}   )
fieldnamesry   )1r!  r6   r8   rU   r   environ	packagingr   rS   parser}   r   r   r   coloredlogsinstallr   r  r  r  r   r9   rH   rJ   rI   rW   rV   Zmax_trt_batch_sizerx   formatr   r
  r   r   r    isdirr"   r   r   r   r,   r-   r  opencsv
DictWriterwriteheaderwriterowr,  )r   rU   rs   rg   r   Zsd_modelr   r   Zcsv_fileZcolumn_namesZ
csv_writerr   r   r   main\  s   








"
r?  __main__)N)FT)FT)FT)FT)T)'r  r;  r   re   sysr_   __init__r6  r3   Zbenchmark_helperr   r  r  r   r   r  boolr+   r7   r  rB   rt   rw   r   r   r   r   r   r   r  r  r!  r,  r?  __name__	traceback	Exceptionprint_exceptionexc_infor   r   r   r   <module>   s2  
ED3]p      !   G   -B !
	 ]

