U
    }hf                     @   s  d dl Z d dlZd dlZd dlZd dlmZmZ d dlZd dlZ	d dl
mZ d dlmZmZ d dlmZmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lm Z m!Z!m"Z"m#Z# d dl$m%Z% d dl&m'Z' d dl(m)Z)m*Z* d dl+m,Z,m-Z-m.Z. d dl/m0Z0m1Z1 d dl2m3Z3m4Z4m5Z5m6Z6m7Z7m8Z8 d dl9m:Z:m;Z;m<Z<m=Z=m>Z>m?Z?m@Z@mAZAmBZBmCZCmDZDmEZEmFZFmGZGmHZHmIZImJZJmKZKmLZLmMZMmNZNmOZOmPZPmQZQ d dlRmSZS d dlTmUZUmVZVmWZW G dd deXZYG dd deeZZG dd deZ[G dd deZ\G dd deZ]G dd deZ^G d d! d!eZ_G d"d# d#eZ`G d$d% d%eZaG d&d' d'eZbG d(d) d)eZcG d*d+ d+eZZdG d,d- d-eZZeG d.d/ d/eZZfG d0d1 d1eZgG d2d3 d3eZZhG d4d5 d5eZiG d6d7 d7eZjG d8d9 d9eZkG d:d; d;eZZlG d<d= d=eZmG d>d? d?eZnG d@dA dAeZoG dBdC dCe#ZpG dDdE dEepZqG dFdG dGeZrG dHdI dIe ZsG dJdK dKe ZtG dLdM dMe!ZuG dNdO dOeZvG dPdQ dQeZwdRdS ZxdTdU ZydVdW ZzdXdY Z{dZd[ Z|d\d] Z}d^d_ Z~d`da Zdbdc Zddde Zdfdg Zdhdi Zdjdk Zdldm ZG dndo doeeZdpdq Zdrds Zdtdu Zdvdw Zdxdy Zdzd{ Zed|krpe  d}d~ Zdd Zdd Zdd Zdd Zdd Zdd ZdS )    N)IntegralReal)config_context
get_config)BaseEstimatorClassifierMixinOutlierMixin)MiniBatchKMeans)make_multilabel_classification)PCA)ExtraTreesClassifier)ConvergenceWarningSkipTestWarning)LinearRegressionLogisticRegressionMultiTaskElasticNetSGDClassifier)GaussianMixture)KNeighborsRegressor)SVCNuSVC)
_array_apiall_estimators
deprecated)Interval
StrOptions)MinimalClassifierMinimalRegressorMinimalTransformerSkipTestignore_warningsraises)_NotAnArray_set_checking_parameters_yield_all_checkscheck_array_api_input-check_class_weight_balanced_linear_classifier"check_classifier_data_not_an_array<check_classifiers_multilabel_output_format_decision_function2check_classifiers_multilabel_output_format_predict8check_classifiers_multilabel_output_format_predict_proba(check_dataframe_column_names_consistency check_decision_proba_consistencycheck_estimator%check_estimator_get_tags_default_keyscheck_estimators_unfittedcheck_fit_check_is_fittedcheck_fit_score_takes_y%check_methods_sample_order_invariancecheck_methods_subset_invariancecheck_no_attributes_set_in_initcheck_outlier_contaminationcheck_outlier_corruption!check_regressor_data_not_an_arraycheck_requires_y_noneset_random_state)available_if)check_arraycheck_is_fitted	check_X_yc                   @   s   e Zd ZdZdS )CorrectNotFittedErrorzException class to raise if estimator is used before fitting.

    Like NotFittedError, it inherits from ValueError, but not from
    AttributeError. Used for testing only.
    N)__name__
__module____qualname____doc__ rC   rC   M/tmp/pip-unpacked-wheel-ig1s1lm8/sklearn/utils/tests/test_estimator_checks.pyr>   G   s   r>   c                   @   s   e Zd Zdd Zdd ZdS )BaseBadClassifierc                 C   s   | S NrC   selfXyrC   rC   rD   fitP   s    zBaseBadClassifier.fitc                 C   s   t |jd S Nr   nponesshaperH   rI   rC   rC   rD   predictS   s    zBaseBadClassifier.predictNr?   r@   rA   rK   rR   rC   rC   rC   rD   rE   O   s   rE   c                   @   s(   e Zd Zd	ddZd
ddZdd ZdS )ChangesDictr   c                 C   s
   || _ d S rF   )key)rH   rU   rC   rC   rD   __init__X   s    zChangesDict.__init__Nc                 C   s   |  ||\}}| S rF   _validate_datarG   rC   rC   rD   rK   [   s    zChangesDict.fitc                 C   s   t |}d| _t|jd S )Ni  r   )r;   rU   rN   rO   rP   rQ   rC   rC   rD   rR   _   s    zChangesDict.predict)r   )N)r?   r@   rA   rV   rK   rR   rC   rC   rC   rD   rT   W   s   

rT   c                   @   s    e Zd ZdddZdddZdS )	SetsWrongAttributer   c                 C   s
   || _ d S rF   )acceptable_key)rH   rZ   rC   rC   rD   rV   f   s    zSetsWrongAttribute.__init__Nc                 C   s   d| _ | ||\}}| S rL   wrong_attributerX   rG   rC   rC   rD   rK   i   s    zSetsWrongAttribute.fit)r   )Nr?   r@   rA   rV   rK   rC   rC   rC   rD   rY   e   s   
rY   c                   @   s    e Zd ZdddZdddZdS )	ChangesWrongAttributer   c                 C   s
   || _ d S rF   )r\   )rH   r\   rC   rC   rD   rV   p   s    zChangesWrongAttribute.__init__Nc                 C   s   d| _ | ||\}}| S N   r[   rG   rC   rC   rD   rK   s   s    zChangesWrongAttribute.fit)r   )Nr]   rC   rC   rC   rD   r^   o   s   
r^   c                   @   s   e Zd ZdddZdS )ChangesUnderscoreAttributeNc                 C   s   d| _ | ||\}}| S r_   )Z_good_attributerX   rG   rC   rC   rD   rK   z   s    zChangesUnderscoreAttribute.fit)Nr?   r@   rA   rK   rC   rC   rC   rD   ra   y   s   ra   c                       s0   e Zd Zd	ddZ fddZd
ddZ  ZS )RaisesErrorInSetParamsr   c                 C   s
   || _ d S rF   prH   re   rC   rC   rD   rV      s    zRaisesErrorInSetParams.__init__c                    s6   d|kr(| d}|dk r"td|| _t jf |S )Nre   r   zp can't be less than 0)pop
ValueErrorre   super
set_paramsrH   kwargsre   	__class__rC   rD   rj      s    
z!RaisesErrorInSetParams.set_paramsNc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    zRaisesErrorInSetParams.fit)r   )Nr?   r@   rA   rV   rj   rK   __classcell__rC   rC   rm   rD   rc      s   
rc   c                   @   s$   e Zd Ze fddZdddZdS )HasMutableParametersc                 C   s
   || _ d S rF   rd   rf   rC   rC   rD   rV      s    zHasMutableParameters.__init__Nc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    zHasMutableParameters.fit)N)r?   r@   rA   objectrV   rK   rC   rC   rC   rD   rq      s   rq   c                   @   s,   e Zd ZdedefddZdddZdS )HasImmutableParameters*   c                 C   s   || _ || _|| _d S rF   )re   qr)rH   re   ru   rv   rC   rC   rD   rV      s    zHasImmutableParameters.__init__Nc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    zHasImmutableParameters.fit)N)r?   r@   rA   rN   Zint32rr   rV   rK   rC   rC   rC   rD   rs      s   rs   c                       s0   e Zd Zd	ddZ fddZd
ddZ  ZS )"ModifiesValueInsteadOfRaisingErrorr   c                 C   s
   || _ d S rF   rd   rf   rC   rC   rD   rV      s    z+ModifiesValueInsteadOfRaisingError.__init__c                    s2   d|kr$| d}|dk rd}|| _t jf |S )Nre   r   )rg   re   ri   rj   rk   rm   rC   rD   rj      s    
z-ModifiesValueInsteadOfRaisingError.set_paramsNc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    z&ModifiesValueInsteadOfRaisingError.fit)r   )Nro   rC   rC   rm   rD   rw      s   
rw   c                       s0   e Zd Zd
ddZ fddZddd	Z  ZS )ModifiesAnotherValuer   method1c                 C   s   || _ || _d S rF   )ab)rH   rz   r{   rC   rC   rD   rV      s    zModifiesAnotherValue.__init__c                    s>   d|kr0| d}|| _|d kr0| d d| _t jf |S )Nrz   r{   Zmethod2)rg   rz   r{   ri   rj   )rH   rl   rz   rm   rC   rD   rj      s    

zModifiesAnotherValue.set_paramsNc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    zModifiesAnotherValue.fit)r   ry   )Nro   rC   rC   rm   rD   rx      s   
	rx   c                   @   s   e Zd Zdd ZdS )NoCheckinPredictc                 C   s   |  ||\}}| S rF   rW   rG   rC   rC   rD   rK      s    zNoCheckinPredict.fitNrb   rC   rC   rC   rD   r|      s   r|   c                   @   s   e Zd Zdd Zdd ZdS )NoSparseClassifierc                 C   s.   | j ||ddgd\}}t|r*td| S )Ncsrcsc)accept_sparseNonsensical Error)rX   spissparserh   rG   rC   rC   rD   rK      s    
zNoSparseClassifier.fitc                 C   s   t |}t|jd S rL   r;   rN   rO   rP   rQ   rC   rC   rD   rR      s    zNoSparseClassifier.predictNrS   rC   rC   rC   rD   r}      s   r}   c                   @   s   e Zd Zdd Zdd ZdS )CorrectNotFittedErrorClassifierc                 C   s&   |  ||\}}t|jd | _| S r_   )rX   rN   rO   rP   coef_rG   rC   rC   rD   rK      s    z#CorrectNotFittedErrorClassifier.fitc                 C   s    t |  t|}t|jd S rL   )r<   r;   rN   rO   rP   rQ   rC   rC   rD   rR      s    z'CorrectNotFittedErrorClassifier.predictNrS   rC   rC   rC   rD   r      s   r   c                   @   s   e Zd ZdddZdd ZdS )NoSampleWeightPandasSeriesTypeNc                 C   s:   | j ||dddd\}}ddlm} t||r6td| S )Nr~   r   Tr   multi_output	y_numericr   Seriesz>Estimator does not accept 'sample_weight'of type pandas.Series)rX   pandasr   
isinstancerh   )rH   rI   rJ   sample_weightr   rC   rC   rD   rK      s        

z"NoSampleWeightPandasSeriesType.fitc                 C   s   t |}t|jd S rL   r   rQ   rC   rC   rD   rR      s    z&NoSampleWeightPandasSeriesType.predict)NrS   rC   rC   rC   rD   r      s   
r   c                   @   s   e Zd ZdddZdd ZdS )BadBalancedWeightsClassifierNc                 C   s
   || _ d S rF   )class_weight)rH   r   rC   rC   rD   rV      s    z%BadBalancedWeightsClassifier.__init__c                 C   sV   ddl m} ddlm} | |}|j}|| j||d}| jdkrL|d7 }|| _| S )Nr   )LabelEncoder)compute_class_weight)classesrJ   Zbalanced      ?)Zsklearn.preprocessingr   sklearn.utilsr   rK   classes_r   r   )rH   rI   rJ   r   r   Zlabel_encoderr   r   rC   rC   rD   rK      s    
z BadBalancedWeightsClassifier.fit)Nr]   rC   rC   rC   rD   r      s   
r   c                   @   s   e Zd ZdddZdd ZdS )BadTransformerWithoutMixinNc                 C   s   |  |}| S rF   rW   rG   rC   rC   rD   rK     s    
zBadTransformerWithoutMixin.fitc                 C   s   t |}|S rF   )r;   rQ   rC   rC   rD   	transform  s    z$BadTransformerWithoutMixin.transform)N)r?   r@   rA   rK   r   rC   rC   rC   rD   r     s   
r   c                   @   s   e Zd Zdd Zdd ZdS )NotInvariantPredictc                 C   s   | j ||dddd\}}| S Nr   Tr   rW   rG   rC   rC   rD   rK     s        
zNotInvariantPredict.fitc                 C   s6   t |}|jd dkr&t|jd S t|jd S )Nr   r`   )r;   rP   rN   rO   zerosrQ   rC   rC   rD   rR   $  s    zNotInvariantPredict.predictNrS   rC   rC   rC   rD   r     s   r   c                   @   s   e Zd Zdd Zdd ZdS )NotInvariantSampleOrderc                 C   s"   | j ||dddd\}}|| _| S r   )rX   _XrG   rC   rC   rD   rK   -  s        
zNotInvariantSampleOrder.fitc                 C   sX   t |}ttj|ddtj| jddrH|| jk rHt|jd S |d d df S )Nr   )Zaxis)r;   rN   Zarray_equivsortr   anyr   rP   rQ   rC   rC   rD   rR   5  s     zNotInvariantSampleOrder.predictNrS   rC   rC   rC   rD   r   ,  s   r   c                   @   s,   e Zd ZdZd
ddZdddZdd	 ZdS )OneClassSampleErrorClassifierzoClassifier allowing to trigger different behaviors when `sample_weight` reduces
    the number of classes to 1.Fc                 C   s
   || _ d S rF   )raise_when_single_class)rH   r   rC   rC   rD   rV   E  s    z&OneClassSampleErrorClassifier.__init__Nc                 C   s   t ||dddd\}}d| _tj|dd\| _}| jjd }|dk rX| jrXd| _td|d k	rt|tj	rt
|dkrtt||}|dk rd| _td	| S )
Nr   Tr   F)Zreturn_inverser      znormal class errorr   )r=   has_single_class_rN   uniquer   rP   r   rh   r   ZndarraylenZcount_nonzeroZbincount)rH   rI   rJ   r   Z
n_classes_rC   rC   rD   rK   H  s(        
z!OneClassSampleErrorClassifier.fitc                 C   s6   t |  t|}| jr&t|jd S t|jd S rL   )r<   r;   r   rN   r   rP   rO   rQ   rC   rC   rD   rR   ^  s
    z%OneClassSampleErrorClassifier.predict)F)Nr?   r@   rA   rB   rV   rK   rR   rC   rC   rC   rD   r   A  s   

r   c                   @   s   e Zd Zdd ZdS )!LargeSparseNotSupportedClassifierc                 C   s~   | j ||ddddd\}}t|rz| dkrR|jjdksH|jjdkrztdn(| dkrzd|jj|j	jfkszt
d| S )N)r~   r   cooT)r   Zaccept_large_sparser   r   r   int64z(Estimator doesn't support 64-bit indices)r   r~   )rX   r   r   Z	getformatrowdtypecolrh   indicesZindptrAssertionErrorrG   rC   rC   rD   rK   g  s(    


z%LargeSparseNotSupportedClassifier.fitNrb   rC   rC   rC   rD   r   f  s   r   c                   @   s(   e Zd ZdddZd	ddZdd ZdS )
SparseTransformerNc                 C   s   |  |j| _| S rF   )rX   rP   X_shape_rG   rC   rC   rD   rK   ~  s    zSparseTransformer.fitc                 C   s   |  |||S rF   )rK   r   rG   rC   rC   rD   fit_transform  s    zSparseTransformer.fit_transformc                 C   s.   t |}|jd | jd kr$tdt|S )Nr`   zBad number of features)r;   rP   r   rh   r   
csr_matrixrQ   rC   rC   rD   r     s    zSparseTransformer.transform)N)N)r?   r@   rA   rK   r   r   rC   rC   rC   rD   r   }  s   

r   c                   @   s   e Zd Zdd Zdd ZdS )EstimatorInconsistentForPandasc                 C   sl   z<ddl m} t||r&|jd | _nt|}|d | _| W S  tk
rf   t|}|d | _|  Y S X d S )Nr   )	DataFrame)r   r   )r`   r   )r   r   r   Zilocvalue_r;   ImportError)rH   rI   rJ   r   rC   rC   rD   rK     s    


z"EstimatorInconsistentForPandas.fitc                 C   s    t |}t| jg|jd  S rL   )r;   rN   arrayr   rP   rQ   rC   rC   rD   rR     s    z&EstimatorInconsistentForPandas.predictNrS   rC   rC   rC   rD   r     s   r   c                       s,   e Zd Zd fdd	Zd fdd	Z  ZS )UntaggedBinaryClassifierNc                    s.   t  ||||| t| jdkr*td| S )Nr   Only 2 classes are supported)ri   rK   r   r   rh   )rH   rI   rJ   Z	coef_initZintercept_initr   rm   rC   rD   rK     s    zUntaggedBinaryClassifier.fitc                    s.   t  j||||d t| jdkr*td| S )N)rI   rJ   r   r   r   r   )ri   partial_fitr   r   rh   )rH   rI   rJ   r   r   rm   rC   rD   r     s    z$UntaggedBinaryClassifier.partial_fit)NNN)NN)r?   r@   rA   rK   r   rp   rC   rC   rm   rD   r     s   r   c                   @   s   e Zd Zdd ZdS )TaggedBinaryClassifierc                 C   s   ddiS )Nbinary_onlyTrC   rH   rC   rC   rD   
_more_tags  s    z!TaggedBinaryClassifier._more_tagsNr?   r@   rA   r   rC   rC   rC   rD   r     s   r   c                       s   e Zd Z fddZ  ZS )EstimatorMissingDefaultTagsc                    s   t    }|d= |S )N	allow_nan)ri   	_get_tagscopy)rH   tagsrm   rC   rD   r     s    z%EstimatorMissingDefaultTags._get_tags)r?   r@   rA   r   rp   rC   rC   rm   rD   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )RequiresPositiveXRegressorc                    s6   | j ||dd\}}|dk  r(tdt ||S )NTr   r   z negative X values not supported!rX   r   rh   ri   rK   rG   rm   rC   rD   rK     s    zRequiresPositiveXRegressor.fitc                 C   s   ddiS )NZrequires_positive_XTrC   r   rC   rC   rD   r     s    z%RequiresPositiveXRegressor._more_tagsr?   r@   rA   rK   r   rp   rC   rC   rm   rD   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )RequiresPositiveYRegressorc                    s6   | j ||dd\}}|dk r(tdt ||S )NTr   r    negative y values not supported!r   rG   rm   rC   rD   rK     s    zRequiresPositiveYRegressor.fitc                 C   s   ddiS )NZrequires_positive_yTrC   r   rC   rC   rD   r     s    z%RequiresPositiveYRegressor._more_tagsr   rC   rC   rm   rD   r     s   r   c                       s$   e Zd Z fddZdd Z  ZS )PoorScoreLogisticRegressionc                    s   t  |d S r_   )ri   decision_functionrQ   rm   rC   rD   r     s    z-PoorScoreLogisticRegression.decision_functionc                 C   s   ddiS )NZ
poor_scoreTrC   r   rC   rC   rD   r     s    z&PoorScoreLogisticRegression._more_tags)r?   r@   rA   r   r   rp   rC   rC   rm   rD   r     s   r   c                   @   s   e Zd Zdd Zdd ZdS )PartialFitChecksNamec                 C   s   |  || | S rF   rW   rG   rC   rC   rD   rK     s    zPartialFitChecksName.fitc                 C   s&   t | d }| j|||d d| _| S )N_fitted)resetT)hasattrrX   r   )rH   rI   rJ   r   rC   rC   rD   r     s    z PartialFitChecksName.partial_fitN)r?   r@   rA   rK   r   rC   rC   rC   rD   r     s   r   c                   @   s    e Zd ZdZdd Zdd ZdS )BrokenArrayAPIz=Make different predictions when using Numpy and the Array APIc                 C   s   | S rF   rC   rG   rC   rC   rD   rK     s    zBrokenArrayAPI.fitc                 C   s@   t  d }t|\}}|r,|dddgS tdddgS d S )NZarray_api_dispatchr`   r      )r   r   Zget_namespaceZasarrayrN   r   )rH   rI   ZenabledZxp_rC   rC   rD   rR     s
    
zBrokenArrayAPI.predictN)r?   r@   rA   rB   rK   rR   rC   rC   rC   rD   r     s   r   c                	   C   s   zt d W n tk
r*   tdY nX zt d W n tk
rV   tdY nX ttdd tdt dd W 5 Q R X d S )	NZarray_api_compatz-array_api_compat is required to run this testznumpy.array_apiz,numpy.array_api is required to run this testNot equal to tolerancematchr   )Zarray_namespace)	importlibimport_moduleModuleNotFoundErrorr   r!   r   r%   r   rC   rC   rC   rD   test_check_array_api_input  s      r   c               	   C   sH   t td} d}tt|d t|  W 5 Q R X t| d sDtd S )N
   z&Don't want to call array_function sum!r   )r"   rN   rO   r!   	TypeErrorsumZmay_share_memoryr   )Z	not_arraymsgrC   rC   rD    test_not_an_array_array_function  s
    r   c                  C   s    G dd dt } td|   d S )Nc                   @   s   e Zd Zeddd ZdS )zbtest_check_fit_score_takes_y_works_on_deprecated_fit.<locals>.TestEstimatorWithDeprecatedFitMethodz=Deprecated for the purpose of testing check_fit_score_takes_yc                 S   s   | S rF   rC   rG   rC   rC   rD   rK     s    zftest_check_fit_score_takes_y_works_on_deprecated_fit.<locals>.TestEstimatorWithDeprecatedFitMethod.fitN)r?   r@   rA   r   rK   rC   rC   rC   rD   $TestEstimatorWithDeprecatedFitMethod  s   r   test)r   r1   )r   rC   rC   rD   4test_check_fit_score_takes_y_works_on_deprecated_fit  s    r   c               	   C   s  d} t t| d tt W 5 Q R X d} tt  t t| d tt  W 5 Q R X d} t t| d tt  W 5 Q R X tj	dd}tt
  W 5 Q R X tdd |D kstt t| d tt  W 5 Q R X d	} t t| d tt  W 5 Q R X d
} t t| d tt  W 5 Q R X z6ddlm} d} t t| d tt  W 5 Q R X W n tk
rl   Y nX d} t t| d tt  W 5 Q R X d} t t| d tt  W 5 Q R X d} t t| d tt  W 5 Q R X tt  d} t t| d tt  W 5 Q R X tj}d}dj||d} t t| d tt  W 5 Q R X tj}d}dj||d} t t| d tt  W 5 Q R X tj}d| } t t| d tt  W 5 Q R X tj}| d} t t| d tt  W 5 Q R X d} t t| d tt   W 5 Q R X d} t t| d tt!  W 5 Q R X tt"  tt#  tt#dd tt$  tt%  tt&  d} t t| d tt'  W 5 Q R X tt(  d S )NzPassing a class was deprecatedr   zXParameter 'p' of estimator 'HasMutableParameters' is of type object which is not allowedz>get_params result does not match what was passed to set_paramsTrecordc                 S   s   g | ]
}|j qS rC   category.0ZrecrC   rC   rD   
<listcomp>2  s     z(test_check_estimator.<locals>.<listcomp>zobject has no attribute 'fit'Did not raiser   r   zkEstimator NoSampleWeightPandasSeriesType raises error if 'sample_weight' parameter is of type pandas.SerieszCEstimator NoCheckinPredict doesn't check for NaN and inf in predictz)Estimator changes __dict__ during predictzrEstimator ChangesWrongAttribute should not change or mutate  the parameter wrong_attribute from 0 to 1 during fit.zEstimator adds public attribute\(s\) during the fit method. Estimators are only allowed to add private attributes either started with _ or ended with _ but wrong_attribute addedrR   zY{method} of {name} is not invariant when applied to a datasetwith different sample order.)methodnamez={method} of {name} is not invariant when applied to a subset.z;Estimator %s doesn't seem to fail gracefully on sparse datazu failed when fitted on one label after sample_weight trimming. Error message is not explicit, it should have 'class'.ztEstimator LargeSparseNotSupportedClassifier doesn't seem to support \S{3}_64 matrix, and is not failing gracefully.*r   g{Gz?)Cr   ))r!   r   r-   rr   rs   r   rq   rw   warningscatch_warningsrc   UserWarningrx   AttributeErrorr   rE   r   r   rh   r   r   r|   rT   r^   ra   rY   r   r?   formatr   r}   r   r   r   r   r   r   r   r   r   r   )r   recordsr   r   r   rC   rC   rD   test_check_estimator  s    

  




r   c               	   C   sT   t ddddg} tt tdd|  W 5 Q R X t ddddg} tdd|  d S )Ng        r   g      ?       @r`   r   )rN   r   r!   r   r6   )ZdecisionrC   rC   rD   test_check_outlier_corruption  s
    
r   c                	   C   s$   t td tt  W 5 Q R X d S )Nz.*fit_transform.*)r!   r   r-   r   rC   rC   rC   rD   )test_check_estimator_transformer_no_mixin  s    r   c               
   C   s   ddl m}  |  }ttttttfD ]}tt	d. | }t
| t| t|}t| W 5 Q R X |t|ksvttt	dB | }t
| t| ||jd |j t|}t| W 5 Q R X |t|ks"tq"d S )Nr   )	load_irisr   r   )sklearn.datasetsr   r   r   r   r   r   r	   r    r   r#   r9   joblibhashr-   r   rK   datatarget)r   Ziris	EstimatorestZold_hashrC   rC   rD   test_check_estimator_clones  s0    	

r  c               	   C   s8   d} t t| d tdt  W 5 Q R X tdt  d S )Nr   r   	estimator)r!   r   r/   r}   r   r   rC   rC   rD   test_check_estimators_unfitted  s    r  c               	   C   s   G dd dt } G dd dt }G dd dt }d}tt|d td	|   W 5 Q R X d
}tt|d td	|  W 5 Q R X td	|  tdd td	| jdd W 5 Q R X d S )Nc                   @   s   e Zd Zdd ZdS )zNtest_check_no_attributes_set_in_init.<locals>.NonConformantEstimatorPrivateSetc                 S   s
   d | _ d S rF   )Zyou_should_not_set_this_r   rC   rC   rD   rV     s    zWtest_check_no_attributes_set_in_init.<locals>.NonConformantEstimatorPrivateSet.__init__Nr?   r@   rA   rV   rC   rC   rC   rD    NonConformantEstimatorPrivateSet  s   r  c                   @   s   e Zd ZdddZdS )zNtest_check_no_attributes_set_in_init.<locals>.NonConformantEstimatorNoParamSetNc                 S   s   d S rF   rC   )rH   Zyou_should_set_this_rC   rC   rD   rV     s    zWtest_check_no_attributes_set_in_init.<locals>.NonConformantEstimatorNoParamSet.__init__)Nr  rC   rC   rC   rD    NonConformantEstimatorNoParamSet  s   r  c                   @   s   e Zd ZddiZdS )zOtest_check_no_attributes_set_in_init.<locals>.ConformantEstimatorClassAttributefooTN)r?   r@   rA   Z9_ConformantEstimatorClassAttribute__metadata_request__fitrC   rC   rC   rD   !ConformantEstimatorClassAttribute  s   r
  zEstimator estimator_name should not set any attribute apart from parameters during init. Found attributes \['you_should_not_set_this_'\].r   estimator_namezPEstimator estimator_name should store all parameters as an attribute during initT)Zenable_metadata_routing)r	  )r   r!   r   r4   r   r   Zset_fit_request)r  r  r
  r   rC   rC   rD   $test_check_no_attributes_set_in_init  s4       r  c                  C   s(   t dd} t|  tdd} t|  d S )NZprecomputed)kernel)Zmetric)r   r-   r   )r  rC   rC   rD   test_check_estimator_pairwise  s    

r  c                	   C   s(   t tdd tdt  W 5 Q R X d S Nr   r   r  )r!   r   r'   r   rC   rC   rC   rD   'test_check_classifier_data_not_an_array!  s
     r  c                	   C   s(   t tdd tdt  W 5 Q R X d S r  )r!   r   r7   r   rC   rC   rC   rD   &test_check_regressor_data_not_an_array(  s
     r  c               	   C   sH   t  } d}tt|d t| jj|  W 5 Q R X t } t| jj|  d S )NzjEstimatorMissingDefaultTags._get_tags\(\) is missing entries for the following default tags: {'allow_nan'}r   )r   r!   r   r.   rn   r?   r   )r  err_msgrC   rC   rD   *test_check_estimator_get_tags_default_keys/  s    r  c               	   C   s|   d} t t| d tdt  W 5 Q R X tdt  t }t|jj| d|_d} t t| d t|jj| W 5 Q R X d S )Nz+Estimator does not have a feature_names_in_r   r  z;Docstring that does not document the estimator's attributeszNEstimator LogisticRegression does not document its feature_names_in_ attribute)	r!   rh   r+   rE   r   r   rn   r?   rB   )r  lrrC   rC   rD   -test_check_dataframe_column_names_consistency=  s    r  c                   @   s$   e Zd Zdd Zdd Zdd ZdS )_BaseMultiLabelClassifierMockc                 C   s
   || _ d S rF   response_output)rH   r  rC   rC   rD   rV   N  s    z&_BaseMultiLabelClassifierMock.__init__c                 C   s   | S rF   rC   rG   rC   rC   rD   rK   Q  s    z!_BaseMultiLabelClassifierMock.fitc                 C   s   ddiS )NZ
multilabelTrC   r   rC   rC   rD   r   T  s    z(_BaseMultiLabelClassifierMock._more_tagsN)r?   r@   rA   rV   rK   r   rC   rC   rC   rD   r  M  s   r  c            	   	   C   s   d\} }}t | d|ddddd\}}|| d  }G dd	 d	t}|| d
}d}tt|d t|jj| W 5 Q R X ||d d d df d
}d}tt|d t|jj| W 5 Q R X ||t	j
d
}d}tt|d t|jj| W 5 Q R X d S )Nd         r   r   2   Tr   	n_samplesZ
n_featuresZ	n_classesZn_labelslengthZallow_unlabeledZrandom_statec                   @   s   e Zd Zdd ZdS )z\test_check_classifiers_multilabel_output_format_predict.<locals>.MultiLabelClassifierPredictc                 S   s   | j S rF   r  rQ   rC   rC   rD   rR   f  s    zdtest_check_classifiers_multilabel_output_format_predict.<locals>.MultiLabelClassifierPredict.predictN)r?   r@   rA   rR   rC   rC   rC   rD   MultiLabelClassifierPredicte  s   r!  r  zdMultiLabelClassifierPredict.predict is expected to output a NumPy array. Got <class 'list'> instead.r   zbMultiLabelClassifierPredict.predict outputs a NumPy array of shape \(25, 4\) instead of \(25, 5\).zTMultiLabelClassifierPredict.predict does not output the same dtype than the targets.)r
   r  tolistr!   r   r)   rn   r?   ZastyperN   float64)	r  	test_size	n_outputsr   rJ   y_testr!  clfr  rC   rC   rD   7test_check_classifiers_multilabel_output_format_predictX  s6    

	r)  c            	   	      sd  d\} }}t | d|ddddd\}}|| d   G dd	 d	t}|t d
}d}tt|d t|jj| W 5 Q R X | 	 d
}d| d| d}tt
|d t|jj| W 5 Q R X  fddt|D }||d
}d}tt
|d t|jj| W 5 Q R X  fddt|D }||d
}d}tt
|d t|jj| W 5 Q R X  fddt|D }||d
}d}tt
|d t|jj| W 5 Q R X | d d d df d
}d}tt
|d t|jj| W 5 Q R X tj tjd}||d
}d}tt
|d t|jj| W 5 Q R X | d d
}d}tt
|d t|jj| W 5 Q R X d S )Nr  r   r   r  Tr   r  c                   @   s   e Zd Zdd ZdS )zgtest_check_classifiers_multilabel_output_format_predict_proba.<locals>.MultiLabelClassifierPredictProbac                 S   s   | j S rF   r  rQ   rC   rC   rD   predict_proba  s    zutest_check_classifiers_multilabel_output_format_predict_proba.<locals>.MultiLabelClassifierPredictProba.predict_probaN)r?   r@   rA   r*  rC   rC   rC   rD    MultiLabelClassifierPredictProba  s   r+  r  z|Unknown returned type .*csr_matrix.* by MultiLabelClassifierPredictProba.predict_proba. A list or a Numpy array is expected.r   zWhen MultiLabelClassifierPredictProba.predict_proba returns a list, the list should be of length n_outputs and contain NumPy arrays. Got length of z instead of .c                    s   g | ]}t  qS rC   )rN   Z	ones_liker   r   r'  rC   rD   r     s     zQtest_check_classifiers_multilabel_output_format_predict_proba.<locals>.<listcomp>zWhen MultiLabelClassifierPredictProba.predict_proba returns a list, this list should contain NumPy arrays of shape \(n_samples, 2\). Got NumPy arrays of shape \(25, 5\) instead of \(25, 2\).c                    s&   g | ]}t j jd  dft jdqS r   r   )rP   r   )rN   rO   rP   r   r-  r.  rC   rD   r     s    zwWhen MultiLabelClassifierPredictProba.predict_proba returns a list, it should contain NumPy arrays with floating dtype.c                    s&   g | ]}t j jd  dft jdqS r/  )rN   rO   rP   r$  r-  r.  rC   rD   r     s    zWhen MultiLabelClassifierPredictProba.predict_proba returns a list, each NumPy array should contain probabilities for each class and thus each row should sum to 1r"  zWhen MultiLabelClassifierPredictProba.predict_proba returns a NumPy array, the expected shape is \(n_samples, n_outputs\). Got \(25, 4\) instead of \(25, 5\).)r   znWhen MultiLabelClassifierPredictProba.predict_proba returns a NumPy array, the expected data type is floating.r   zWhen MultiLabelClassifierPredictProba.predict_proba returns a NumPy array, this array is expected to provide probabilities of the positive class and should therefore contain values between 0 and 1.)r
   r  r   r   r!   rh   r*   rn   r?   r#  r   rangerN   Z
zeros_liker   )	r  r%  r&  r   rJ   r+  r(  r  r  rC   r.  rD   =test_check_classifiers_multilabel_output_format_predict_proba  s    

	





r1  c            	   	   C   s   d\} }}t | d|ddddd\}}|| d  }G dd	 d	t}|| d
}d}tt|d t|jj| W 5 Q R X ||d d d df d
}d}tt|d t|jj| W 5 Q R X ||d
}d}tt|d t|jj| W 5 Q R X d S )Nr  r   r   r  Tr   r  c                   @   s   e Zd Zdd ZdS )zotest_check_classifiers_multilabel_output_format_decision_function.<locals>.MultiLabelClassifierDecisionFunctionc                 S   s   | j S rF   r  rQ   rC   rC   rD   r     s    ztest_check_classifiers_multilabel_output_format_decision_function.<locals>.MultiLabelClassifierDecisionFunction.decision_functionN)r?   r@   rA   r   rC   rC   rC   rD   $MultiLabelClassifierDecisionFunction  s   r2  r  zwMultiLabelClassifierDecisionFunction.decision_function is expected to output a NumPy array. Got <class 'list'> instead.r   r"  zMultiLabelClassifierDecisionFunction.decision_function is expected to provide a NumPy array of shape \(n_samples, n_outputs\). Got \(25, 4\) instead of \(25, 5\)z^MultiLabelClassifierDecisionFunction.decision_function is expected to output a floating dtype.)r
   r  r#  r!   r   r(   rn   r?   )	r  r%  r&  r   rJ   r'  r2  r(  r  rC   rC   rD   Atest_check_classifiers_multilabel_output_format_decision_function  sH    

	
r3  c                     sV   t jd   fddt D } dd | D }t }|| t }|| dS )z1Runs the tests in this file without using pytest.__main__c                    s    g | ]}| d rt |qS )Ztest_)
startswithgetattr)r   r   Zmain_modulerC   rD   r   6  s   
z,run_tests_without_pytest.<locals>.<listcomp>c                 S   s   g | ]}t |qS rC   )unittestZFunctionTestCase)r   fnrC   rC   rD   r   ;  s     N)sysmodulesdirr8  Z	TestSuiteZaddTestsZTextTestRunnerrun)Ztest_functionsZ
test_casesZsuiterunnerrC   r7  rD   run_tests_without_pytest3  s    


r?  c               	   C   s*   d} t t| d tdt W 5 Q R X d S )NzIClassifier estimator_name is not computing class_weight=balanced properlyr   r  )r!   r   r&   r   r  rC   rC   rD   2test_check_class_weight_balanced_linear_classifierB  s     r@  c               	   C   sF   t jdd} t }W 5 Q R X | r&t|D ]}|jjdr*tq*d S )NTr   r   )r   r   r   r   rn   r?   r5  )r   Z
estimatorsr  rC   rC   rD   test_all_estimators_all_publicK  s
    rA  r4  c               	   C   s<   t jdd} tt  W 5 Q R X tdd | D ks8td S )NTr   c                 S   s   g | ]
}|j qS rC   r   r   rC   rC   rD   r   a  s     z9test_xfail_ignored_in_check_estimator.<locals>.<listcomp>)r   r   r-   r   r   r   )r   rC   rC   rD   %test_xfail_ignored_in_check_estimator\  s    rB  c                  C   s*   t t t t g} | D ]}t| qd S rF   )r   r   r   r   r-   )Zminimal_estimatorsr  rC   rC   rD   (test_minimal_class_implementation_checksf  s    rC  c               	   C   s\   G dd dt } ttdd td| dd W 5 Q R X td| dd td| d	d d S )
Nc                   @   s2   e Zd ZdddZdd Zedd dd	 Zd
S )z1test_check_fit_check_is_fitted.<locals>.Estimator	attributec                 S   s
   || _ d S rF   behavior)rH   rF  rC   rC   rD   rV   r  s    z:test_check_fit_check_is_fitted.<locals>.Estimator.__init__c                 [   s&   | j dkrd| _n| j dkr"d| _| S )NrD  Tr   )rF  Z
is_fitted_
_is_fitted)rH   rI   rJ   rl   rC   rC   rD   rK   u  s
    

z5test_check_fit_check_is_fitted.<locals>.Estimator.fitc                 S   s
   | j dkS )N>   r   always-truerE  r   rC   rC   rD   <lambda>|      z:test_check_fit_check_is_fitted.<locals>.Estimator.<lambda>c                 S   s   | j dkrdS t| dS )NrH  TrG  )rF  r   r   rC   rC   rD   __sklearn_is_fitted__|  s    
zGtest_check_fit_check_is_fitted.<locals>.Estimator.__sklearn_is_fitted__N)rD  )r?   r@   rA   rV   rK   r:   rK  rC   rC   rC   rD   r   q  s   

r   z'passes check_is_fitted before being fitr   r  rH  rE  r   rD  )r   r!   	Exceptionr0   )r   rC   rC   rD   test_check_fit_check_is_fittedp  s
    rM  c               	   C   sJ   G dd dt } tjdd}td|   W 5 Q R X dd |D rFtd S )Nc                   @   s   e Zd Zdd ZdS )z-test_check_requires_y_none.<locals>.Estimatorc                 S   s   t ||\}}d S rF   )r=   rG   rC   rC   rD   rK     s    z1test_check_requires_y_none.<locals>.Estimator.fitNrb   rC   rC   rC   rD   r     s   r   Tr   r  c                 S   s   g | ]
}|j qS rC   )message)r   rv   rC   rC   rD   r     s     z.test_check_requires_y_none.<locals>.<listcomp>)r   r   r   r8   r   )r   r   rC   rC   rD   test_check_requires_y_none  s    rO  c                  C   sp   t ttfD ]`} tt|  }t|ks(tt|ks4tG dd d| }tt| }t|ks^tt|ks
tq
d S )Nc                   @   s   e Zd Zdd ZdS )z>test_non_deterministic_estimator_skip_tests.<locals>.Estimatorc                 S   s   ddiS )NZnon_deterministicTrC   r   rC   rC   rD   r     s    zItest_non_deterministic_estimator_skip_tests.<locals>.Estimator._more_tagsNr   rC   rC   rC   rD   r     s   r   )r   r   r   listr$   r2   r   r3   )r  Z	all_testsr   rC   rC   rD   +test_non_deterministic_estimator_skip_tests  s    rQ  c               
   C   s   G dd dt t} |  }t|jj|dks.tG dd d| }| }d}tt|d t|jj| W 5 Q R X ttdd	d
dg|j	d< | }t|jj| tt
ddd
dttddd
dttddd
dttdd	ddg}d}|D ]<}|g|j	d< | }tt|d t|jj| W 5 Q R X qdS )zHCheck the test for the contamination parameter in the outlier detectors.c                   @   s.   e Zd ZdZd
ddZdddZddd	ZdS )zJtest_check_outlier_contamination.<locals>.OutlierDetectorWithoutConstraintz.Outlier detector without parameter validation.皙?c                 S   s
   || _ d S rF   )contamination)rH   rS  rC   rC   rD   rV     s    zStest_check_outlier_contamination.<locals>.OutlierDetectorWithoutConstraint.__init__Nc                 S   s   | S rF   rC   )rH   rI   rJ   r   rC   rC   rD   rK     s    zNtest_check_outlier_contamination.<locals>.OutlierDetectorWithoutConstraint.fitc                 S   s   t |jd S rL   rM   rG   rC   rC   rD   rR     s    zRtest_check_outlier_contamination.<locals>.OutlierDetectorWithoutConstraint.predict)rR  )NN)Nr   rC   rC   rC   rD    OutlierDetectorWithoutConstraint  s   

rT  Nc                   @   s   e Zd ZdedhgiZdS )zGtest_check_outlier_contamination.<locals>.OutlierDetectorWithConstraintrS  autoN)r?   r@   rA   r   _parameter_constraintsrC   rC   rC   rD   OutlierDetectorWithConstraint  s   rW  zDcontamination constraints should contain a Real Interval constraint.r   r   g      ?right)closedrS  r`   r"  r   leftz<contamination constraint should be an interval in \(0, 0.5\])r   r   r5   rn   r?   r   r!   r   r   rV  r   )rT  detectorrW  r  Zincorrect_intervalsintervalrC   rC   rD    test_check_outlier_contamination  s0    

r]  c                  C   s   t dd} td|  dS )zCheck that in case with some probabilities ties, we relax the
    ranking comparison with the decision function.
    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/24025
    Zlog_loss)Zlossr   N)r   r,   )r  rC   rC   rD   test_decision_proba_tie_ranking  s    
r^  )r   r:  r8  r   Znumbersr   r   r   ZnumpyrN   Zscipy.sparsesparser   Zsklearnr   r   Zsklearn.baser   r   r   Zsklearn.clusterr	   r   r
   Zsklearn.decompositionr   Zsklearn.ensembler   Zsklearn.exceptionsr   r   Zsklearn.linear_modelr   r   r   r   Zsklearn.mixturer   Zsklearn.neighborsr   Zsklearn.svmr   r   r   r   r   r   Zsklearn.utils._param_validationr   r   Zsklearn.utils._testingr   r   r   r   r    r!   Zsklearn.utils.estimator_checksr"   r#   r$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   Zsklearn.utils.metaestimatorsr:   Zsklearn.utils.validationr;   r<   r=   rh   r>   rE   rT   rY   r^   ra   rc   rq   rs   rw   rx   r|   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r  r  r  r  r)  r1  r3  r?  r@  rA  r?   rB  rC  rM  rO  rQ  r]  r^  rC   rC   rC   rD   <module>   s    h

	
%	 
".+{5	


6