U
    h                     @   s  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlZd dlZd dlmZmZmZmZ d dlZddlmZmZmZ G dd	 d	ZG d
d dZG dd deZG dd de jdZG dd dZG dd deZ G dd deZ!G dd de!Z"G dd de!Z#G dd de!Z$G dd de jdZ%G dd  d e%Z&dd!ej'd"i fee(ef e
ee(  d#d$d%Z)dS )&    N)Enum)Path)DictOptionalSequenceTupleUnion)
ModelProtoTensorProtohelpernumpy_helper   )
apply_plotload_model_with_shape_infersmooth_distributionc                   @   sB   e Zd ZedddddddgZdd	 Zed
d Zedd ZdS )
TensorDataavgstdlowesthighesthist
hist_edgesbinsc                 K   sD   |  D ]6\}}|tjkr2td|dtj dt| || qd S )NzUnexpected value z not in .)itemsr   _allowed
ValueErrorsetattr)selfkwargskv r"   F/tmp/pip-unpacked-wheel-socb9apf/onnxruntime/quantization/calibrate.py__init__   s    
zTensorData.__init__c                 C   s4   t | drt | ds(tdt|  d| j| jfS )Nr   r   z0Attributes 'lowest' and/or 'highest' missing in r   )hasattrAttributeErrordirr   r   r   r"   r"   r#   range_value!   s    zTensorData.range_valuec                 C   s4   t | drt | ds(tdt|  d| j| jfS )Nr   r   z)Attributes 'avg' and/or 'std' missing in r   )r%   r&   r'   r   r   r(   r"   r"   r#   avg_std'   s    zTensorData.avg_stdN)	__name__
__module____qualname__	frozensetr   r$   propertyr)   r*   r"   r"   r"   r#   r      s   
r   c                   @   sR   e Zd Zeeeeef f dddZdd Z	dd Z
dd	 Zd
d Zdd ZdS )TensorsDatadatac              	   C   s   || _ i | _| D ]\}}t|ts:tdt| dt|tr|tj	krvt
|dkrvt|d |d d| j|< qt
|dkrt|d |d |d |d d	| j|< qtd
|ddt
| d| dt|tstdt| d|| j|< qd S )NzKeys must be strings not r      r   r   )r   r         )r   r   r   r   zUnexpected tuple for rz	, it has z elements: zValues must be TensorData not )calibration_methodr2   r   
isinstancestr	TypeErrortypetupleCalibrationMethodMinMaxlenr   )r   r7   r2   r    r!   r"   r"   r#   r$   /   s     

&"
zTensorsData.__init__c                 c   s   | j E d H  d S Nr1   r(   r"   r"   r#   __iter__A   s    zTensorsData.__iter__c                 C   s
   || j kS r@   r1   r   keyr"   r"   r#   __contains__D   s    zTensorsData.__contains__c                 C   s
   | j | S r@   r1   rB   r"   r"   r#   __getitem__G   s    zTensorsData.__getitem__c                 C   s(   || j krtd|d|| j |< d S )Nz)Only an existing tensor can be modified, z is not.)r2   RuntimeError)r   rC   valuer"   r"   r#   __setitem__J   s    
zTensorsData.__setitem__c                 C   s
   | j  S r@   )r2   valuesr(   r"   r"   r#   rI   O   s    zTensorsData.valuesN)r+   r,   r-   r   r9   r   r   r   r$   rA   rD   rE   rH   rI   r"   r"   r"   r#   r0   .   s   r0   c                   @   s   e Zd ZdZdZdZdZdS )r=   r   r   r3   r5   N)r+   r,   r-   r>   Entropy
PercentileDistributionr"   r"   r"   r#   r=   S   s   r=   c                   @   s<   e Zd Zedd ZejedddZdd Z	dd	 Z
d
S )CalibrationDataReaderc                 C   s   t |drt|jptS )Nget_next)r%   callablerN   NotImplemented)clssubclassr"   r"   r#   __subclasshook__[   s    z&CalibrationDataReader.__subclasshook__returnc                 C   s   t dS )z9generate the input data dict for ONNXinferenceSession runNNotImplementedErrorr(   r"   r"   r#   rN   _   s    zCalibrationDataReader.get_nextc                 C   s   | S r@   r"   r(   r"   r"   r#   rA   d   s    zCalibrationDataReader.__iter__c                 C   s   |   }|d krt|S r@   )rN   StopIteration)r   resultr"   r"   r#   __next__g   s    zCalibrationDataReader.__next__N)r+   r,   r-   classmethodrS   abcabstractmethoddictrN   rA   rZ   r"   r"   r"   r#   rM   Z   s   
rM   )	metaclassc                   @   s~   e Zd Zdeeef eee  dddZdgfdd	Z	d
d Z
edddZdd Zdd ZedddZedddZdS )CalibraterBaseNaugmented_model.onnxF
model_pathop_types_to_calibratec                 C   sh   t |trtt|| _nt |tr0t|| _ntd|| _|| _|| _|| _	d| _
d| _dg| _dS )a  
        :param model_path: ONNX model to calibrate. It should be a model file path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        z model_path should be model path.NCPUExecutionProvider)r8   r9   r   r   modelr   rd   augmented_model_path	symmetricuse_external_data_formatZaugment_modelinfer_sessionexecution_providers)r   rc   rd   rg   rh   ri   r"   r"   r#   r$   o   s    

zCalibraterBase.__init__re   c                 C   s   || _ |   dS )zz
        reset the execution providers to execute the collect_data. It triggers to re-creating inference session.
        N)rk   create_inference_session)r   rk   r"   r"   r#   set_execution_providers   s    z&CalibraterBase.set_execution_providersc                 C   s,   t  }t jj|_t j| j|| jd| _dS )z9
        create an OnnxRuntime InferenceSession.
        )sess_optionsZ	providersN)	onnxruntimeZSessionOptionsZGraphOptimizationLevelZORT_DISABLE_ALLZgraph_optimization_levelZInferenceSessionrg   rk   rj   )r   rn   r"   r"   r#   rl      s    
z'CalibraterBase.create_inference_sessionrf   c           	      C   s   dd |j jD }|dd |j jD  |dd |j jD  dd |j jD }t }tjh}|j j	D ]h}| j
r|j| j
krjt|j|jD ]@}||kr|| }|jdr|jjj|kr||kr|| qqj||fS )z
        select input/output tensors of candidate nodes to calibrate.
        returns:
            tensors (set): set of tensor name.
            value_infos (dict): tensor name to value info.
        c                 S   s   i | ]}|j |qS r"   name).0vir"   r"   r#   
<dictcomp>   s      z>CalibraterBase.select_tensors_to_calibrate.<locals>.<dictcomp>c                 S   s   i | ]}|j |qS r"   rq   )rs   Zotr"   r"   r#   ru      s      c                 S   s   i | ]}|j |qS r"   rq   )rs   itr"   r"   r#   ru      s      c                 S   s   h | ]
}|j qS r"   rq   )rs   initr"   r"   r#   	<setcomp>   s     z=CalibraterBase.select_tensors_to_calibrate.<locals>.<setcomp>tensor_type)graphZ
value_infoupdateoutputinputinitializersetr
   FLOATnoderd   Zop_type	itertoolschainr;   ZHasFieldry   Z	elem_typeadd)	r   rf   value_infosr~   tensors_to_calibrateZtensor_type_to_calibrater   tensor_namert   r"   r"   r#   select_tensors_to_calibrate   s&    
z*CalibraterBase.select_tensors_to_calibratec                 C   s   | j S )zP
        return: augmented onnx model. Call after calling augment_graph
        rp   r(   r"   r"   r#   get_augment_model   s    z CalibraterBase.get_augment_modelc                 C   s   t dS )z
        abstract method: augment the input model to prepare for collecting data. It will:
            1. augment the model to be able to collect desired statistics data
            2. save augmented model to augmented_model_paths
        NrV   r(   r"   r"   r#   augment_graph   s    zCalibraterBase.augment_graphdata_readerc                 C   s   t dS )z
        abstract method: collect the tensors that will be used for range computation. It can be called multiple times.
        NrV   )r   r   r"   r"   r#   collect_data   s    zCalibraterBase.collect_datarT   c                 C   s   t dS )ze
        abstract method: compute data based on the calibration method stored in TensorsData
        NrV   r(   r"   r"   r#   compute_data   s    zCalibraterBase.compute_data)Nra   FF)r+   r,   r-   r   r9   r   r   r   r$   rm   rl   r	   r   r   r   rM   r   r0   r   r"   r"   r"   r#   r`   n   s       

r`   c                       sj   e Zd Zdeeef eee  d fddZdd	 Z	d
d Z
edddZdd ZedddZ  ZS )MinMaxCalibraterNra   F{Gz?rb   c                    st   t  j|||||d g | _d| _t| jjj| _dd | jjjD | _	|| _
|rj|dk sb|dkrjtd|| _dS )a  
        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param moving_average: compute the moving average of the minimum and maximum values instead of the global minimum and maximum.
        :param averaging_constant: constant smoothing factor to use when computing the moving average.
        rd   rg   rh   ri   Nc                 S   s   h | ]
}|j qS r"   rq   rs   r|   r"   r"   r#   rx      s     z,MinMaxCalibrater.__init__.<locals>.<setcomp>r   r   z;Invalid averaging constant, which should not be < 0 or > 1.)superr$   intermediate_outputscalibrate_tensors_ranger?   rf   rz   r|   num_model_outputsmodel_original_outputsmoving_averager   averaging_constant)r   rc   rd   rg   rh   ri   r   r   	__class__r"   r#   r$      s    zMinMaxCalibrater.__init__c                    s    j\}}tt  ttjdgtj	d }jj
j|  fdd}|D ]}||d ||d qXtjjjjd dS )	z
        Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in
        model and ensures their outputs are stored as part of the graph output
        :return: augmented ONNX model
        r   Zdtypec                    s   d}| d | }|d }t jj|| g|g||d}t jjd| g|g|d}jjj||g jjjt	|t
jdg d S )Nr   _Z_Reshape)keepdimsrr   ZReshape)inputsoutputsrr   )onnxr   Z	make_noderf   rz   r   extendr|   appendZmake_tensor_value_infor
   r   )r   Zreduce_op_namer   Zreduce_outputintermediate_outputZreduce_nodeZreshape_nodeZreshape_shape_namer   r"   r#   add_reduce_min_max  s$        z:MinMaxCalibrater.augment_graph.<locals>.add_reduce_min_maxZ	ReduceMinZ	ReduceMaxZsave_as_external_dataN)r   rf   r9   uuidZuuid4r   Z
from_arraynparrayint64rz   r~   r   r   saverg   ri   )r   Ztensorsr   Zreshape_shaper   tensorr"   r   r#   r      s    
zMinMaxCalibrater.augment_graphc                 C   s
   g | _ d S r@   r   r(   r"   r"   r#   clear_collected_data(  s    z%MinMaxCalibrater.clear_collected_datar   c                 C   sn   |  }|sq&| j| jd | q t| jdkr<td|  }t|t	sbt
dt| d|   d S )Nr   No data is collected.z+compute_data must return a TensorsData not r   )rN   r   r   rj   runr?   r   r   r8   r0   r:   r;   r   )r   r   r   tr"   r"   r#   r   +  s    
zMinMaxCalibrater.collect_datac                 C   s   |s|S |  D ]\}}| jrd|d | j|| d |d    }|d | j|| d |d    }n,t|d || d }t|d || d }||f||< q|S )Nr   r   )r   r   r   minmax)r   Z	old_rangeZ	new_rangerC   rG   	min_value	max_valuer"   r"   r#   merge_range:  s    "$zMinMaxCalibrater.merge_rangerT   c                    s  t jdkrjS fddtt jd D fddjD }i |D ](}| D ]\}}|g | q\qPjd   fddtdt  dD }fdd	D }g }tdt  dD ]}d}	d}
jrt	j
| |  dd
}t	j
| |d   dd
}n$t| |  }t| |d   }t|tsP|jdkrXt|}	t|tsp|jdkrxt|}
jrtt|	t|
}|t| |g q|t|	|
g qttjtt||}jrj|_n|_jS )z
        Compute the min-max range of tensor
        :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs }
        r   c                    s   g | ]} j  | jqS r"   rj   get_outputsrr   rs   ir(   r"   r#   
<listcomp>R  s     z1MinMaxCalibrater.compute_data.<locals>.<listcomp>c                    s   g | ]}t t |qS r"   r^   ziprs   r   output_namesr"   r#   r   S  s    Nc                    s   g | ]} |  d d qS )r   r   )
rpartitionr   )added_output_namesr"   r#   r   \  s    r3   c                    s    i | ]}|j kr| | qS r"   )r   r   )merged_output_dictr   r"   r#   ru   `  s    
  z1MinMaxCalibrater.compute_data.<locals>.<dictcomp>)Zaxisr   )r?   r   r   ranger   
setdefaultr   r   r   r   Zmeanr   r   r8   intsizefloatrh   absr<   r0   r=   r>   r^   r   r   )r   output_dicts_listdr    r!   Zcalibrate_tensor_namesZmerged_added_output_dictpairsr   r   r   Zmin_value_arrayZmax_value_arrayZmax_absolute_valueZnew_calibrate_tensors_ranger"   )r   r   r   r   r#   r   I  sN     

zMinMaxCalibrater.compute_data)Nra   FFFr   )r+   r,   r-   r   r9   r   r   r   r$   r   r   rM   r   r   r0   r   __classcell__r"   r"   r   r#   r      s         

#+r   c                	       sb   e Zd Zdeeef eee  d	 fd
dZdd Z	dd Z
edddZedddZ  ZS )HistogramCalibraterNra   F
percentile      -X@samerb   c                    sv   t  j|||||d g | _d| _t| jjj| _dd | jjjD | _	d| _
|| _|| _|| _|	| _d| _|
| _dS )a=  
        :param model_path: ONNX model to calibrate. It is a model path.
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param num_quantized_bins: number of quantized bins. Default 128.
        :param percentile: A float number between [0, 100]. Default 99.99.
        :param scenario: see :class:`DistributionCalibrater`
        r   Nc                 S   s   h | ]
}|j qS r"   rq   r   r"   r"   r#   rx     s     z/HistogramCalibrater.__init__.<locals>.<setcomp>)r   r$   r   r   r?   rf   rz   r|   r   r   	collectormethodnum_binsnum_quantized_binsr   r   scenario)r   rc   rd   rg   ri   r   rh   r   r   r   r   r   r"   r#   r$     s$    zHistogramCalibrater.__init__c                 C   sV   |  | j\| _}| jD ]"}|| jkr| jjj||  qtj| j| j	| j
d dS )z
        make all quantization_candidates op type nodes as part of the graph output.
        :return: augmented ONNX model
        r   N)r   rf   r   r   rz   r|   r   r   r   rg   ri   )r   r   r   r"   r"   r#   r     s    

z!HistogramCalibrater.augment_graphc                 C   s
   g | _ d S r@   r   r(   r"   r"   r#   r     s    z(HistogramCalibrater.clear_collected_datar   c                    s   |  }|sq&jjd| q tjdkr<tdfddttjd D fddjD }i  |D ](}| D ]\}} 	|g | qqx fdd D }j
stjjjjjjd	_
j
|   dS )
zy
        Entropy Calibrator collects operators' tensors as well as generates tensor histogram for each operator.
        Nr   r   c                    s   g | ]} j  | jqS r"   r   r   r(   r"   r#   r     s     z4HistogramCalibrater.collect_data.<locals>.<listcomp>c                    s   g | ]}t t |qS r"   r   r   r   r"   r#   r     s    c                    s    i | ]}|j kr| | qS r"   )r   r   )merged_dictr   r"   r#   ru     s     
  z4HistogramCalibrater.collect_data.<locals>.<dictcomp>)r   rh   r   r   r   r   )rN   r   r   rj   r   r?   r   r   r   r   r   HistogramCollectorr   rh   r   r   r   r   collectr   )r   r   r   r   r   r    r!   Zclean_merged_dictr"   )r   r   r   r#   r     s4     
z HistogramCalibrater.collect_datarT   c                 C   sh   | j stdt| tr tj}n8t| tr2tj}n&t| trDtj	}nt
dt|  dt|| j  S )z
        Compute the min-max range of tensor
        :return: dictionary mapping: {tensor name: (min value, max value)}
        z9No collector created and can't generate calibration data.zUnknown calibrater z". This method must be overwritten.)r   r   r8   EntropyCalibraterr=   rJ   PercentileCalibraterrK   DistributionCalibraterrL   r:   r;   r0   compute_collection_result)r   calr"   r"   r#   r     s    


z HistogramCalibrater.compute_data)	Nra   Fr   Fr   r   r   r   )r+   r,   r-   r   r9   r   r   r   r$   r   r   rM   r   r0   r   r   r"   r"   r   r#   r     s"            

,&r   c                       s6   e Zd Zd	eeef eee  d fddZ  Z	S )
r   Nra   Fentropyr   rb   c	           	   
      s    t  j||||||||d dS )a  
        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param num_quantized_bins: number of quantized bins. Default 128.
        )r   rh   r   r   Nr   r$   )	r   rc   rd   rg   ri   r   rh   r   r   r   r"   r#   r$     s    zEntropyCalibrater.__init__)Nra   Fr   Fr   r   
r+   r,   r-   r   r9   r   r   r   r$   r   r"   r"   r   r#   r     s          

r   c                       s6   e Zd Zd
eeef eee  d fdd	Z  Z	S )r   Nra   Fr   r   r   rb   c	           	   
      s    t  j||||||||d dS )a  
        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_quantized_bins: number of quantized bins. Default 128.
        :param percentile: A float number between [0, 100]. Default 99.99.
        )r   rh   r   r   Nr   )	r   rc   rd   rg   ri   r   rh   r   r   r   r"   r#   r$     s    zPercentileCalibrater.__init__)Nra   Fr   Fr   r   r   r"   r"   r   r#   r     s          

r   c                       s6   e Zd Zd
eeef eee  d fdd	Z  Z	S )r   Nra   Fdistributionr   r   rb   c              	      s   t  j|||||||d dS )a  
        :param model_path: ONNX model to calibrate. It is a model path
        :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors.
        :param augmented_model_path: save augmented model to this path.
        :param use_external_data_format: use external data format to store model which size is >= 2Gb
        :param method: A string. One of ['entropy', 'percentile', 'distribution'].
        :param symmetric: make range of tensor symmetric (central point is 0).
        :param num_bins: number of bins to create a new histogram for collecting tensor values.
        :param scenario: for float 8 only, if `scenario="same"`,
            the algorithm weights and float 8 follow the same distribution,
            if `scenario="p3"`, it assumes the weights follow
            a gaussian law and float 8 ~ X^3 where X is a gaussian law
        )r   r   r   Nr   )r   rc   rd   rg   ri   r   r   r   r   r"   r#   r$   @  s    zDistributionCalibrater.__init__)Nra   Fr   r   r   r   r"   r"   r   r#   r   ?  s         

r   c                   @   s,   e Zd ZdZejdd Zejdd ZdS )CalibrationDataCollectorzL
    Base class for collecting data for calibration-based quantization.
    c                 C   s   t dS )z
        Generate informative data based on given data.
            name_to_arr : dict
                tensor name to NDArray data
        NrV   r   name_to_arrr"   r"   r#   r   g  s    z CalibrationDataCollector.collectc                 C   s   t dS )z?
        Get the optimal result among collection data.
        NrV   r(   r"   r"   r#   r   p  s    z2CalibrationDataCollector.compute_collection_resultN)r+   r,   r-   __doc__r\   r]   r   r   r"   r"   r"   r#   r   b  s
   
r   c                   @   sv   e Zd ZdZdd Zdd Zdd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd ZedddZdd Zdd ZdS )r   a`  
    Collecting histogram for each tensor. Percentile and Entropy method are supported.

    ref: https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
    ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/_modules/
                 pytorch_quantization/calib/histogram.html
    c                 C   s.   i | _ || _|| _|| _|| _|| _|| _d S r@   )histogram_dictr   rh   r   r   r   r   )r   r   rh   r   r   r   r   r"   r"   r#   r$     s    zHistogramCollector.__init__c                 C   s   | j S r@   )r   r(   r"   r"   r#   get_histogram_dict  s    z%HistogramCollector.get_histogram_dictc                 C   sN   t d | jdkr| |S | jdkrB| jr6| |S | |S ntdd S )Nz/Collecting tensor data and making histogram ...>   r   r   r   DOnly 'entropy', 'percentile' or 'distribution' methods are supported)printr   collect_valuerh   collect_absolute_valuer   r   r"   r"   r#   r     s    



zHistogramCollector.collectc                 C   sP  |  D ]@\}}t|}| }|jdkrDt|}t|}nd}d}t|}|| jkrtj	|| j
d\}}||||f| j|< q| j| }|d }	|d }
|d }|d }t|}||d kr|d |d  }t|d | || |}t||f}tj	||d\}}|dt|  |7  < ||t|	|t|
|f| j|< qdS )z5
        Collect histogram on absolute value
        r   )r   r3   r5   r   N)r   r   asarrayflattenr   r   r   absoluter   	histogramr   ZarangeZhstackr?   )r   r   r   data_arrr   r   r   r   old_histogramold_minold_maxold_histold_hist_edgesZ	temp_amaxwidthZnew_bin_edgesr"   r"   r#   r     s2    






z)HistogramCollector.collect_absolute_valuec           
      C   s   |  D ]\}}t|}| }|jdkrBt|}t|}nd}d}tt|t|}|| jkr| j| }| 	|||||| j|< qtj
|| j| |fd\}}	||	|||f| j|< qdS )z1
        Collect histogram on real value
        r   r   N)r   r   r   r   r   r   r   r   r   merge_histogramr   r   )
r   r   r   r   r   r   	thresholdr   r   r   r"   r"   r#   r     s2    




    z HistogramCollector.collect_valuec                 C   s  |\}}}}	}
||
krRt j|t||
 |
fd\}}|| |t||t|	||
fS |
dkrt j|t|| |fd\}}||7 }nrt|}d|
 | }t||
 | d }|d|  }|| |
 }t j||| |fd\}}||||   |7  < ||t||t|	||fS d S )Nr   r   r3   r   )r   r   r?   r   r   r   )r   r   r   Znew_minZnew_maxZnew_thresholdr   r   r   r   Zold_thresholdZnew_histr   r   r   Zold_num_binsZ
old_strideZhalf_increased_binsZnew_num_binsr"   r"   r#   r     s2    
z"HistogramCollector.merge_histogramc                 C   sp   | j rt| j dkrtdtd| j d | jdkr@|  S | jdkrR|  S | jdkrd|  S tdd S )	Nr   z=Histogram has not been collected. Please run collect() first.z0Finding optimal threshold for each tensor using z algorithm ...r   r   r   r   )r   r?   r   r   r   compute_entropycompute_percentilecompute_distributionr(   r"   r"   r#   r     s    


z,HistogramCollector.compute_collection_resultc                 C   s  | j dk s| j dkrtd| j}| j }i }tdt|  td| j  tdd|  d| d	 | D ]8\}}|d }|d
 }| }t	|| }	| j
rt|	|d }
t||
  t||
 f||< nDd| d }t|	d| }
t|	|}t|| t||
 f||< |d }|d }|| d |k rP||| d
 f||< || d
 |krv|| d |f||< || |d d ||< tjdddkrpt|| qp|S )Nr   d   z<Invalid percentile. Must be in range 0 <= percentile <= 100.Number of tensors : Number of histogram bins : zPercentile : (g      Y@,)r   g      i@g      ?r3   r5   QUANTIZATION_DEBUGr   1)r   r   r   r   r?   r   r   sumr   Zcumsumrh   Zsearchsortedr   osenvirongetr   )r   r   r   thresholds_dictr   r   r   r   totalZcdfZ	idx_rightZpercent_to_cut_one_sideZidx_leftr   r   r"   r"   r#   r     sD    



z%HistogramCollector.compute_percentilec                 C   s   | j }| j}i }tdt|  td| j td| j  | D ]T\}}| ||}|||< ||d d ||< tj	
dddkrJt|d |d  qJ|S )	Nr  zWNumber of histogram bins : {} (The number may increase depends on the data it collects)zNumber of quantized bins : r3   r  r   r  r   )r   r   r   r?   formatr   r   get_entropy_thresholdr  r  r  r   )r   r   r   r  r   r   optimal_thresholdr"   r"   r#   r   =  s"    z"HistogramCollector.compute_entropyr   c                 C   sX  |dkrt d| d|d d |dd   d }|dkr|| |  |   }| |d   |   |d  d }||fS t||krt|d dkr| ||   |   }| || | d   |   d }||fS t|| }d|t|< d|t|< t|| | }| |  |   }| |d   |   |d  d }||fS )Nr   zpower=z <= 0 is invalid.r   r   g      ?r3   )r   r
  r   r   r   isnanisinf)r   r   powerrI   r   r   Zfactr"   r"   r#   _avg_stdV  s$    $$$zHistogramCollector._avg_stdc           	      C   s   | j dk rtd| j}i }tdt|  td| j   td| jd | D ]\}}|d }|d }| jd	kr| j||dd
\}}n(| jdkr| j||dd
\}}ntdt||||d||< t	j
dddkrXt|| qX|S )Ni   z3Invalid num_bins. Must be in range 512 <= num_bins.r  r  zScenario : r  r   r   r   )r  Zp3gUUUUUU?z,Invalid scenario. Must be in {'same', 'p3'}.)r   r   r   r   r  r  )r   r   r   r   r?   r   r   r  r   r  r  r  r   )	r   r   r  r   r   r   r   Zavg_coefZstd_coefr"   r"   r#   r  l  s&    


z'HistogramCollector.compute_distributionc                 C   s  ddl }ddlm} |d }|d }|j}|d }|d }	t||	 d }
dd t|
jD }t|	|d dD ]}|| }|| d |kr|| d n|}t|| t|| f|||	 < |||| }|  }t	|d| }t	||d }|d  |7  < |d  |7  < |dk
tj}tj|tjd	}|j| }t|D ]*}|| }|| }t	||| ||< qJ|d  t	||| d 7  < tj|jtjd	}t|D ]L}|| }|| }t	||| }|dkrt|| t| |||< qt|}t|}t|tjr0||||
||	 < qptd
|
||	 < qpt|
}|| }|d }|d }|d |k r~||d f}|d |kr|d |f}|S )aF  Given a dataset, find the optimal threshold for quantizing it.
        The reference distribution is `q`, and the candidate distribution is `p`.
        `q` is a truncated version of the original distribution.
        Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
        r   N)r   r   r3   c                 S   s   g | ]}d qS ))r   r   r"   r   r"   r"   r#   r     s     z<HistogramCollector.get_entropy_threshold.<locals>.<listcomp>r   r   infr5   )copyZscipy.statsr   r   r   zerosr   r   deepcopyr
  Zastyper   r   r8   ZndarrayZargmin)r   r   r   r  r   r   r   r   Zzero_bin_indexZnum_half_quantized_binZkl_divergenceZ
thresholdsr   start_indexZ	end_indexZsliced_distributionpZleft_outliers_countZright_outliers_countZnonzerosZquantized_binsZnum_merged_binsindexstartendqZnormZmin_kl_divergence_idxr  r   r   r"   r"   r#   r    sd     


 
 
z(HistogramCollector.get_entropy_thresholdN)r   )r+   r,   r-   r   r$   r   r   r   r   r   r   r   r   staticmethodr  r  r  r"   r"   r"   r#   r   x  s   	%  .r   ra   F)rf   rd   c              	   C   s  d }|t jkrdd|krdn|d }d|kr.dn|d }d|krBdn|d }	t| ||||||	d}n|t jkrd|krzdn|d }
d	|krdn|d	 }d|krdn|d }t| |||||
|d
}n|t jkr$d|krdn|d }
d|krdn|d }d|krdn|d }t| |||||
|d}nL|t jkrpd|kr>dn|d }
d|krTdn|d }t| ||||
|d}|r|	  |
  |S td| d S )Nrh   Fr   r   r   )ri   rh   r   r   r   r   r   )ri   rh   r   r   r   r   r   T)ri   rh   r   r   r   r   )ri   r   r   zUnsupported calibration method )r=   r>   r   rJ   r   rK   r   rL   r   r   rl   r   )rf   rd   rg   Zcalibrate_methodri   Zextra_optionsZ
calibratorrh   r   r   r   r   r   r   r"   r"   r#   create_calibrator  sp    

	
	
	r"  )*r\   r   r  r   enumr   pathlibr   typingr   r   r   r   r   Znumpyr   r   r	   r
   r   r   ro   Zquant_utilsr   r   r   r   r0   r=   ABCMetarM   r`   r   r   r   r   r   r   r   r>   r9   r"  r"   r"   r"   r#   <module>   sF   %k *y""#  t

