U
    }hz                     @   s  d Z ddlZddlZddlZddlZddlmZ ddlm	Z	m
Z
 ddlmZmZmZ ddlmZmZmZmZmZmZmZmZmZmZmZmZmZ ddlmZ ddlm Z m!Z!m"Z"m#Z#m$Z$ d	d
 Z%dd Z&dd Z'ej()dddgdd Z*dd Z+dd Z,dd Z-ej()de.ddgddgddgddgge.ddddge.dd gdfe.d!dgd"dgddgd#dgge.ddddge.ddgdfe.ddgddgddgddgge.ddddge.ej/d gdfe.d!dgd"dgddgd#dgge.ddddge.ej/ej/gdfgd$d% Z0ej()d&e.ddgddgddgddgge.ddddge.dd'ge.d(d)gdfe.d!dgd"dgddgd#dgge.ddddge.ddge.d(d(gdfe.ddgddgddgd"dgge.dddd"ge.e1ej2j3d*ge.dd+gdfe.d"dgddgddgddgge.dddd"ge.e1ej2j3d*ge.dd+gdfe.ddgddgddgddgge.ddddge.ej/d'ge.ej/d)gdfe.d!dgd"dgddgd#dgge.ddddge.ej/ej/ge.ej/ej/gdfe.ddgddgddgd"dgge.dddd"ge.ej4d*ge.dd+gdfe.d"dgddgddgddgge.dddd"ge.ej4d*ge.dd+gdfgd,d- Z5d.d/ Z6d0d1 Z7d2d3 Z8d4d5 Z9d6d7 Z:ej()d8ej;ej2gd9d: Z<d;d< Z=d=d> Z>d?d@ Z?dAdB Z@dCdD ZAdEdF ZBdGdH ZCej()dIdJdKdLgej()dMdd!dgdNdO ZDdPdQ ZEdRdS ZFdTdU ZGdVdW ZHdXdY ZIdZd[ ZJd\d] ZKd^d_ ZLd`da ZMdbdc ZNddde ZOdfdg ZPdhdi ZQdS )jz0
Todo: cross-check the F-value with stats model
    N)assert_allclose)sparsestats)	load_irismake_classificationmake_regression)GenericUnivariateSelect	SelectFdr	SelectFpr	SelectFweSelectKBestSelectPercentilechi2	f_classiff_onewayf_regressionmutual_info_classifmutual_info_regressionr_regression)	safe_mask)_convert_containerassert_almost_equalassert_array_almost_equalassert_array_equalignore_warningsc                  C   sj   t jd} | dd}d| dd }t||\}}t||\}}t ||sVtt ||sftd S )Nr   
         )nprandomRandomStateZrandnr   r   ZallcloseAssertionError)rngX1X2fpvf2pv2 r)   W/tmp/pip-unpacked-wheel-ig1s1lm8/sklearn/feature_selection/tests/test_feature_select.pytest_f_oneway_vs_scipy_stats)   s    r+   c                  C   sf   t jd} | jddd}t d}t||\}}t|t|\}}t||dd t||dd d S )Nr   r   )r   r   size   decimal)	r   r   r    randintaranger   astypefloatr   )r"   XyZfintZpintr%   pr)   r)   r*   test_f_oneway_ints4   s    
r8   c                  C   s   t ddddddddd	d
dd\} }t| |\}}tt| |\}}|dk sTt|dk sdt|dk  stt|d d dk  st|dd  dk stt|| t|| d S N      r      r      r           r   F	n_samples
n_featuresn_informativeZn_redundantZ
n_repeatedZ	n_classesZn_clusters_per_classZflip_yZ	class_sepshufflerandom_state   皙?-C6?)r   r   r   
csr_matrixallr!   r   r5   r6   Fr&   ZF_sparseZ	pv_sparser)   r)   r*   test_f_classifB   s,    

rL   centerTFc           	      C   s   t dddddd\}}t||| d}d|k  s4t|d	k  sDtt|d
}t||| d}t|| t||d d tjf f}tj	|dd}|d ddf }t
||dd d S )Ni  r;   rE   Fr   r@   rA   rB   rC   rD   rM   r   r   )Zrowvarr   r/   )r   r   rI   r!   r   r   r   ZhstackZnewaxisZcorrcoefr   )	rM   r5   r6   Zcorr_coeffsZsparse_XZsparse_corr_coeffsZZcorrelation_matrixZnp_corr_coeffsr)   r)   r*   test_r_regression^   s"        


rR   c                  C   s  t dddddd\} }t| |\}}|dk s4t|dk sDt|dk  sTt|d d dk  slt|dd  d	k stt| |d
d\}}tt| |d
d\}}t|| t|| t| |dd\}}tt| |dd\}}t|| t|| d S )Nr:   r;   rE   Fr   rN   r   rF   rG   TrO   )r   r   rI   r!   r   rH   r   rJ   r)   r)   r*   test_f_regressiont   s*        



rS   c                  C   sf   t jd} | dd}t dt}t||\}}t||t\}}t	||d t	||d d S )Nr   r   r;   rE   )
r   r   r    randr2   r3   intr   r4   r   )r"   r5   r6   F1Zpv1F2r(   r)   r)   r*   test_f_regression_input_dtype   s    rX   c                  C   s   t dddd} | j}t |}|d d d  d9  < d|d< t| |d	d
\}}t| |dd
\}}t||d  |d  | t|d d d S )N   rP   r   r<   g      r>   r   TrO   F      ?       @g@9w?)r   r2   Zreshaper-   onesr   r   r   )r5   r@   YrV   _rW   r)   r)   r*   test_f_regression_center   s    
r`   z&X, y, expected_corr_coef, force_finiter<   r   r   r.   r>   gI+?rE   r   r=   c              	   C   s@   t    t dt t| ||d}W 5 Q R X tj|| dS )zCheck the behaviour of `force_finite` for some corner cases with `r_regression`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15672
    errorforce_finiteN)warningscatch_warningssimplefilterRuntimeWarningr   r   testingr   )r5   r6   Zexpected_corr_coefrc   Z	corr_coefr)   r)   r*   test_r_regression_force_finite   s    '
ri   z;X, y, expected_f_statistic, expected_p_values, force_finiteg
[?r[   gSr.j?g?gajK?c              	   C   sR   t  $ t dt t| ||d\}}W 5 Q R X tj|| tj|| dS )zCheck the behaviour of `force_finite` for some corner cases with `f_regression`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15672
    ra   rb   N)rd   re   rf   rg   r   r   rh   r   )r5   r6   Zexpected_f_statisticZexpected_p_valuesrc   Zf_statisticZp_valuesr)   r)   r*   test_f_regression_corner_case   s
    M
rj   c                  C   s   t ddddddddd	d
dd\} }t| |\}}|dk s@t|dk sPt|dk  s`t|d d dk  sxt|dd  dk std S r9   )r   r   rI   r!   )r5   r6   rK   r&   r)   r)   r*   test_f_classif_multi_class.  s&    
rk   c                  C   s   t ddddddddd	d
dd\} }ttdd}|| || }ttddd| || }t|| | }t	d}d|d d< t|| d S Nr:   r;   r   r<   r   r=   r   r>   r   Fr?      
percentilero   modeparamrE   )
r   r   r   fit	transformr   r   get_supportr   zerosr5   r6   univariate_filterX_rX_r2supportgtruthr)   r)   r*   test_select_percentile_classifG  s6    
 

r}   c            	      C   s
  t ddddddddd	d
dd\} }t| } ttdd}|| || }ttddd| || }t|	 |	  |
 }td}d|d d< t|| ||}t|stt||}|j| jkstt|d d |f 	 |	  | | kstd S rl   )r   r   rH   r   r   rs   rt   r   r   Ztoarrayru   r   rv   inverse_transformissparser!   r   shapeZgetnnz)	r5   r6   rx   ry   rz   r{   r|   ZX_r2invZsupport_maskr)   r)   r*   %test_select_percentile_classif_sparseg  sD    

 



r   c                  C   s   t ddddddddd	d
dd\} }ttdd}|| || }ttddd| || }t|| | }t	d}d|d d< t|| d S )Nr:   r;   r   r<   r   r=   r   r>   r   Fr?   rE   kk_bestrp   )
r   r   r   rs   rt   r   r   ru   r   rv   rw   r)   r)   r*   test_select_kbest_classif  s6    
 

r   c                  C   sf   t ddddd\} }ttdd}|| || }t| | ttddd	| || }t|| d S )
Nr;   r   Fr   r@   rA   rC   rD   rI   r   r   rp   )r   r   r   rs   rt   r   r   )r5   r6   rx   ry   rz   r)   r)   r*   test_select_kbest_all  s"       

 r   dtype_inc              	   C   s   t ddddd\}}|| }ttdd}||| | }tjdtd}t	|| t
jtdd	 ||}W 5 Q R X |jd
kst|j| kstd S )Nr;   r   Fr   r   r   dtypeNo features were selectedmatch)r;   r   )r   r3   r   r   rs   ru   r   rv   boolr   pytestwarnsUserWarningrt   r   r!   r   )r   r5   r6   rx   r{   r|   
X_selectedr)   r)   r*   test_select_kbest_zero  s        


r   c                  C   s   t ddddddddd	d
dd\} }ttdd}|| || }td}d|d d< dD ]<}tt|dd| || }t|| |	 }t
|| qZd S )Nr:   r;   r   r<   r   r=   r   r>   r   Fr?   {Gz?alpharE   fdrZfprfwerp   )r   r   r   rs   rt   r   rv   r   r   ru   r   r5   r6   rx   ry   r|   rq   rz   r{   r)   r)   r*   test_select_heuristics_classif  s8    

 
r   c                 C   s:   | j }|  }tt|| t||  d   d S )N)Zscores_ru   r   r   sortsum)Zscore_filterscoresr{   r)   r)   r*   assert_best_scores_kept  s    r   c                  C   s   t dddddd\} }ttdd}|| || }t| ttd	dd
| || }t|| | }t	
d}d|d d< t|| |  }d|d d t	|f< t||| t|t||t d S )Nr:   r;   rE   Fr   rN   rm   rn   ro   rp   r   )r   r   r   rs   rt   r   r   r   ru   r   rv   copyZlogical_notr~   r3   r   )r5   r6   rx   ry   rz   r{   r|   ZX_2r)   r)   r*   !test_select_percentile_regression  s:        
 


 r   c                  C   s   t dddddd\} }ttdd}|| || }t| ttd	dd
| || }t|| | }t	
d}t|| d S )Nr:   r;   rE   Fr   rN   d   rn   ro   rp   )r   r   r   rs   rt   r   r   r   ru   r   r]   rw   r)   r)   r*   &test_select_percentile_regression_full"  s*        
 

r   c                  C   s   t ddddddd\} }ttdd}|| || }t| ttd	dd
| || }t|| | }t	
d}d|d d< t|| d S )Nr:   r;   rE   Fr   r   r@   rA   rB   rC   rD   Znoiser   r   rp   r   )r   r   r   rs   rt   r   r   r   ru   r   rv   rw   r)   r)   r*   test_select_kbest_regression7  s.    
	 

r   c                  C   s   t ddddddd\} }ttdd	}|| || }td}d
|d d< dD ]l}tt|dd| || }t|| |	 }t|d d tj
dtd t|dd  d
kdk sPtqPd S )Nr:   r;   rE   Fr   r   r   r   r   r   r   rp   rE   r   r   )r   r
   r   rs   rt   r   rv   r   r   ru   r]   r   r   r!   r   r)   r)   r*   !test_select_heuristics_regressionS  s0    
	
 
r   c                  C   sp  t ddgddgddgg} t dgdgdgg}t| |\}}t|t ddg t|t dd	g ttd
d}|| | | }t|t ddg ttdd}|| | | }t|t ddg t	tdd}|| | | }	t|	t ddg t
td
d}
|
| | |
 }t|t ddg ttd
d}|| | | }t|t ddg d S )Nr   r;      r   r   g      @ggm?gQaK?gX٬<y?皙?r   TFr   2   rn   )r   arrayr   r   r	   rs   ru   r   r   r   r
   r   )r5   r6   r   ZpvaluesZ
filter_fdrZsupport_fdrZfilter_kbestZsupport_kbestZfilter_percentileZsupport_percentileZ
filter_fprZsupport_fprZ
filter_fweZsupport_fwer)   r)   r*   test_boundary_case_ch2p  s2    r   r   gMbP?r   r   rB   c                    sP   dd t  fddtdD } |ks4t|dkrL| d ksLtd S )Nc              	   S   s   t dd|d|dd\}}tjdd@ tt| d}||||}ttd	| d
|||}W 5 Q R X t|| |	 }t
||d  dk}	t
|d | dk}
|	dkrdS |	|
|	  }|S )N   r;   Fr   r   T)recordr   r   rp   r   r   r>   )r   rd   re   r	   r   rs   rt   r   r   ru   r   r   )r   rB   rD   r5   r6   rx   ry   rz   r{   Znum_false_positivesZnum_true_positivesfalse_discovery_rater)   r)   r*   
single_fdr  s8    
	 
z.test_select_fdr_regression.<locals>.single_fdrc                    s   g | ]} |qS r)   r)   ).0rD   r   rB   r   r)   r*   
<listcomp>  s     z.test_select_fdr_regression.<locals>.<listcomp>r   r   r   )r   Zmeanranger!   )r   rB   r   r)   r   r*   test_select_fdr_regression  s    $r   c                  C   s   t dddddd\} }ttdd}|| || }ttd	dd
| || }t|| | }t	d}d|d d< t|d d tj
dtd t|dd  dkdk std S )Nr:   r;   rE   Fr   rN   r   r   r   rp   r   r   r   r<   )r   r   r   rs   rt   r   r   ru   r   rv   r]   r   r   r!   rw   r)   r)   r*   test_select_fwe_regression  s,        
 

r   c                  C   s   dddgdddgdddgdddgg} dg}dd }| D ]t}t |dd}t|j|g|}|jd dksjtt| t |dd}t|j|g|}|jd dkstt| q6d S )Nr   r   c                 S   s   | d | d fS Nr   r)   r5   r6   r)   r)   r*   <lambda>      z.test_selectkbest_tiebreaking.<locals>.<lambda>r   r<   )r   r   fit_transformr   r!   r   ZXsr6   Zdummy_scorer5   selr#   r$   r)   r)   r*   test_selectkbest_tiebreaking  s    $r   c                  C   s   dddgdddgdddgdddgg} dg}dd }| D ]t}t |dd}t|j|g|}|jd dksjtt| t |dd}t|j|g|}|jd dkstt| q6d S )	Nr   r   c                 S   s   | d | d fS r   r)   r   r)   r)   r*   r     r   z3test_selectpercentile_tiebreaking.<locals>.<lambda>"   rn   C   r<   )r   r   r   r   r!   r   r   r)   r)   r*   !test_selectpercentile_tiebreaking  s    $r   c                  C   s   t dddgdddgg} ddg}tdD ]p}| d d |f }ttdd||}|jd	ksbtd|ksntt	td
d||}|jd	kstd|ks,tq,d S )N'  '  i'  r   r   )r   r   r<   r<   r   )r<   r<   r   rn   )
r   r   	itertoolspermutationsr   r   r   r   r!   r   )ZX0r6   permr5   Xtr)   r)   r*   test_tied_pvalues  s    r   c                  C   s   t dddgdddgdddgg} ddgddgddgg}ttdd	| |}|jd
ksZtd|ksftttdd| |}|jd
kstd|kstd S )Nr   r   r   r   i  c   r   r<   r   )r   r<   r   rn   )r   r   r   r   r   r   r!   r   )r5   r6   r   r)   r)   r*   test_scorefunc_multilabel  s    "r   c                  C   st   t dddgdddgg} ddg}dD ]H}tt|d| |}|dddgg}t|d t d| d   q&d S )Nr   r   )r   r<   r   r   r<   r   )r   r   r   r   rs   rt   r   r2   )ZX_trainZy_trainrA   r   ZX_testr)   r)   r*   test_tied_scores   s    r   c                  C   st   dddgdddgdddgg} dddg}t tddttddfD ]0}t|j| | t|jd	d
tddg q>d S )Nr   r   rP         ?r<   r   r   rn   T)indices)	r   r   r   r   rs   r   ru   r   r   )r5   r6   selectr)   r)   r*   	test_nans+  s    


r   c               	   C   s|   dddgdddgdddgg} dddg}t t tdd| | W 5 Q R X t t tddd| | W 5 Q R X d S )	Nr   r   rP   r   r.   r   r   rp   )r   Zraises
ValueErrorr   rs   r   r   r)   r)   r*   test_invalid_k:  s    
r   c               	   C   sD   t ddd\} }d| d d df< tt t| | W 5 Q R X d S )Nr   rE   )r@   rA   r\   r   )r   r   r   r   r   r   r)   r)   r*   test_f_classif_constant_featureD  s    r   c               
   C   s   t jd} | dd}| jdddd}tdd||tdd||tdd||t	dd||t
dd	||g}|D ]J}t| t d tjtd
d ||}W 5 Q R X |jdkstqd S )Nr   (   r   r.   r,   r   r   rn   r   r   r   )r   r   )r   r   r    rT   r1   r   rs   r	   r
   r   r   r   ru   rv   r   r   r   rt   r   r!   )r"   r5   r6   Zstrict_selectorsselectorr   r)   r)   r*   test_no_feature_selectedM  s    r   c                  C   s   t dddddddddddd	\} }ttdd
}|| || }ttddd| || }t|| | }t	d}d|d d< t|| t
tdd}|| || }ttddd| || }t|| | }t	d}d|d d< t|| d S )Nr   rE   r   r   r<   r>   r   Fr?   r   r   rp   r   rn   ro   )r   r   r   rs   rt   r   r   ru   r   rv   r   rw   r)   r)   r*   test_mutual_info_classifb  sR    
 


 

r   c                  C   s   t ddddddd\} }ttdd}|| || }t| ttddd	| || }t|| | }t	
d}d
|d d< t|| ttdd}|| || }ttddd	| || }t|| | }t	
d}d
|d d< t|| d S )Nr   r   r<   Fr   r   r   r   rp   r   r;   rn   ro   )r   r   r   rs   rt   r   r   r   ru   r   rv   r   rw   r)   r)   r*   test_mutual_info_regression  sJ    

 


 

r   c                     s   t d} tddd\}}|tjtjd}| j|d dd|d< |j  fd	d
}t	|ddj
dd}|||}t|jdddg |j D ]\}}||j| kstqdS )zmCheck that the output datafarme dtypes are the same as the input.

    Non-regression test for gh-24860.
    ZpandasT)Z
return_X_yZas_frame)petal length (cm)petal width (cm)r   r   )Zbinspetal_width_binnedc                    s(   dddddd t  fddD S )	Nr   r<   r   r.   rE   )zsepal length (cm)zsepal width (cm)r   r   r   c                    s   g | ]} | qS r)   r)   )r   nameZrankingr)   r*   r     s     zBtest_dataframe_output_dtypes.<locals>.selector.<locals>.<listcomp>)r   Zasarrayr   Zcolumn_orderr   r*   r     s    z.test_dataframe_output_dtypes.<locals>.selectorr   r   )rt   r   N)r   Zimportorskipr   r3   r   float32float64Zcutcolumnsr   Z
set_outputr   r   Zdtypesitemsr!   )pdr5   r6   r   rx   outputr   r   r)   r   r*   test_dataframe_output_dtypes  s$    

 r   )R__doc__r   rd   Znumpyr   r   Znumpy.testingr   Zscipyr   r   Zsklearn.datasetsr   r   r   Zsklearn.feature_selectionr   r	   r
   r   r   r   r   r   r   r   r   r   r   Zsklearn.utilsr   Zsklearn.utils._testingr   r   r   r   r   r+   r8   rL   markZparametrizerR   rS   rX   r`   r   nanri   Zfinfor   maxinfrj   rk   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r)   r)   r)   r*   <module>   s   <
 	 	 	 !
 
 
 
 
 
 
 
 E
 , 
%"1
	,(