################################################################################
# Util functions for running shift location benchmarking methods.
################################################################################

import sys
import numpy as np
import time
import random
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import (
    SelectKBest, 
    f_classif, 
    chi2, 
    mutual_info_classif
)
from sklearn.metrics import f1_score

from .FeatureShift.fsd.featureshiftdetector import FeatureShiftDetector
from .FeatureShift.fsd.divergence import FisherDivergence, ModelKS, KnnKS
from .FeatureShift.fsd.models import Knn, GaussianDensity, DeepDensity

from .ScikitFeature.skfeature.function.information_theoretical_based.MRMR import mrmr
from .ScikitFeature.skfeature.function.information_theoretical_based.CMIM import cmim
from .ScikitFeature.skfeature.function.information_theoretical_based.CIFE import cife
from .ScikitFeature.skfeature.function.information_theoretical_based.DISR import disr
from .ScikitFeature.skfeature.function.information_theoretical_based.ICAP import icap
from .ScikitFeature.skfeature.function.information_theoretical_based.JMI import jmi
from .ScikitFeature.skfeature.function.information_theoretical_based.MIFS import mifs
from .ScikitFeature.skfeature.function.information_theoretical_based.MIM import mim

from .FastCMIM.fast_cmim import fast_cmim

sys.path.append("../")
from src.preprocessing._data_manipulation import _create_source_labels

def filter_selectKbest(reference, query, y_true, config):
    """
    Apply ``SelectKBest`` from sklearn to identify corrupted features 
    between a reference and a query dataset. All features with a 
    p_value < significance_level are classified as corrupted.
    
    Parameters
    ----------
    reference : DataFrame
        Reference dataset. 
    query : DataFrame
        Query dataset.
    y_true : array-like
        True filtering labels where original features are assigned label 0, 
        and manipulated features are assigned label 1.
    config : dict
        significance_level : float
            Features with p_value < significance_level are considered corrupted.
        score_func : ['f_classif', 'chi2'].
            If 'f_classif', ANOVA-F is used (features are continuous).
            If 'chi2', Chi-Square is used (features are categorical).
    
    Returns
    -------
    output_df : dict
        n_corrupted_features_ : int
            Total number of features detected as being corrupted.
        mask_ : array of length (n_features_in_)
            The mask of corrupted features, where 1 indicates a variable is 
            corrupted and 0 otherwise.
        F1 Score : float
            F1 Score in the filtering task (computed with ``zero_division=1``).
        Runtime : float
            Elapsed time, in seconds, of the fit and thresholding steps.
    
    Raises
    ------
    ValueError
        If ``config['score_func']`` is not one of the supported names.
    """
    # Resolve the scoring function first so that an unsupported name fails 
    # with a clear message instead of an unbound local variable later on.
    score_funcs = {"f_classif": f_classif, "chi2": chi2}
    score_func_name = config['score_func']
    if score_func_name not in score_funcs:
        raise ValueError(
            f"{score_func_name} not supported. "
            f"Expected one of {sorted(score_funcs)}."
        )
    
    # Obtain the concatenation of all samples from the reference and query 
    # datasets, and an array of labels indicating the source of each sample 
    # (0 for reference, 1 for query)
    X, y = _create_source_labels(reference, query)
    
    # Measure time
    start_time = time.time()
    
    # Fit ``SelectKBest`` from sklearn; k='all' keeps every feature so that a 
    # p-value is available for each of them
    select = SelectKBest(score_func=score_funcs[score_func_name], k='all').fit(X, y)
    
    # Obtain indexes of the features with p_value < significance_level
    manipulated_idxs = select.pvalues_ < config['significance_level']
    
    # Measure time
    end_time = time.time()
    
    # Define empty dictionary to store output metrics
    output_df = {}
    
    # Store mask with 1 = corrupted, 0 = not corrupted
    output_df['mask_'] = np.zeros(reference.shape[1]).astype(int)
    output_df['mask_'][manipulated_idxs] = 1
    
    # Store total number of features detected as being corrupted
    output_df['n_corrupted_features_'] = np.sum(output_df['mask_'])
    
    # Store F1 Score in the filtering task and total runtime
    output_df['F1 Score'] = f1_score(y_true, output_df['mask_'], zero_division=1)
    output_df['Runtime'] = end_time - start_time
    
    return output_df


def filter_mutual_information(reference, query, y_true, config):
    """
    Flag corrupted features with ``mutual_info_classif`` from sklearn.
    
    The mutual information (MI) is estimated between every feature and the 
    label telling whether a sample comes from the reference or from the query 
    dataset. All features with MI > threshold are classified as corrupted.
    
    Parameters
    ----------
    reference : DataFrame
        Reference dataset. 
    query : DataFrame
        Query dataset.
    y_true : array-like
        True filtering labels where original features are assigned label 0, 
        and manipulated features are assigned label 1.
    config : dict
        threshold : float
            Features with MI > threshold are considered corrupted.
        random_state: int
            Controls randomness by passing an integer for reproducible output.
    
    Returns
    -------
    output_df : dict
        mask_ : array of length (n_features_in_)
            The mask of corrupted features, where 1 indicates a variable is 
            corrupted and 0 otherwise.
        n_corrupted_features_ : int
            Total number of features detected as being corrupted.
        F1 Score : float
            F1 Score in the filtering task (computed with ``zero_division=1``).
        Runtime : float
            Elapsed time, in seconds, of the MI estimation and thresholding.
    """
    # Stack reference and query samples and label each sample by its source 
    # (0 for reference, 1 for query)
    samples, source_labels = _create_source_labels(reference, query)
    
    # Only the MI estimation and the thresholding are timed
    tic = time.time()
    mi_scores = mutual_info_classif(
        samples, source_labels, random_state=config['random_state']
    )
    is_corrupted = mi_scores > config['threshold']
    toc = time.time()
    
    # Binary mask: 1 = corrupted, 0 = not corrupted
    mask = np.zeros(reference.shape[1]).astype(int)
    mask[is_corrupted] = 1
    
    # Collect the mask, the number of flagged features, the F1 Score in the 
    # filtering task and the runtime
    return {
        'mask_': mask,
        'n_corrupted_features_': np.sum(mask),
        'F1 Score': f1_score(y_true, mask, zero_division=1),
        'Runtime': toc - tic,
    }


def filter_feature_shift_detection(reference, query, y_true, config): 
    """
    Apply ``MB-SM``, ``MB-KS``, ``KNN-KS``, or ``DD-SM`` feature shift 
    detection techniques to identify corrupted features between a reference and 
    a query dataset.
    
    Parameters
    ----------
    reference : DataFrame
        Reference dataset. 
    query : DataFrame
        Query dataset.
    y_true : array-like
        True filtering labels where original features are assigned label 0, 
        and manipulated features are assigned label 1.
    config : dict
        partition : str
            Defines which data is used to fit the detector (X_boot, Y_boot). 
            Detection is always run on X=reference, Y=query. Supported values:
            'X_boost=X=reference, Y_boost=Y=query'
                Fit on the full reference and query datasets.
            'X_boot=reference(50%), Y_boot=reference(50%), X=reference, Y=query'
                Fit on two halves of the reference dataset.
        method : ['MB-SM', 'MB-KS', 'KNN-KS', 'DD-SM']
            Model and statistic.
        n_selected_features : int
            The fixed budget of features which can be checked if a shift is detected. 
            (i.e. the number of features suspected to have been compromised).
        n_expectation : int
            The number of samples used to estimate the expectation of the divergence 
            of p(x) and q(x).
        n_neighbors : int
            The number of neighbors to consider and return for each neighborhood sample.
            Only read when method is 'KNN-KS'.
        n_bootstrap_runs : int
            The number of bootstrap runs to perform when bootstrapping 
            (i.e. {X_boot, Y_boot}).
        random_state: int
            Controls randomness by passing an integer for reproducible output.
    
    Returns
    -------
    output_df : dict
        n_corrupted_features_ : int
            Total number of features detected as being corrupted.
        mask_ : array of length (n_features_in_)
            The mask of corrupted features, where 1 indicates a variable is 
            corrupted and 0 otherwise.
        scores_ : array-like
            Scores assigned to each feature.
        F1 Score : float
            F1 Score in the filtering task (computed with ``zero_division=1``).
        Runtime : float
            Elapsed time, in seconds, from model construction to detection.
    
    Raises
    ------
    ValueError
        If ``config['method']`` or ``config['partition']`` is not supported.
    """
    # NOTE: this seeds Python's global ``random`` module
    random.seed(config['random_state'])
    
    # Measure time
    start_time = time.time()
    
    if config['method'] == 'MB-SM':
        model = GaussianDensity()
        statistic = FisherDivergence(model, n_expectation=config['n_expectation'])
    elif config['method'] == 'MB-KS':
        model = GaussianDensity()
        statistic = ModelKS(model, n_expectation=config['n_expectation'])
    elif config['method'] == 'KNN-KS':
        model = Knn(n_neighbors=config['n_neighbors'])
        statistic = KnnKS(model, n_expectation=config['n_expectation'])
    elif config['method'] == 'DD-SM':
        model = DeepDensity()
        statistic = FisherDivergence(model, n_expectation=config['n_expectation'])
    else:
        raise ValueError(f'{config["method"]} not supported.')
    
    # Create FeatureShiftDetector object
    fsd = FeatureShiftDetector(statistic=statistic, bootstrap_method='simple', 
          n_bootstrap_samples=config['n_bootstrap_runs'], n_compromised=config['n_selected_features'])
    
    # Select the data used to fit the detector. An unknown partition is 
    # rejected explicitly; otherwise the results below would be unbound.
    if config['partition'] == 'X_boost=X=reference, Y_boost=Y=query':
        X_boot, Y_boot = reference, query
    elif config['partition'] == 'X_boot=reference(50%), Y_boot=reference(50%), X=reference, Y=query':
        X_boot, Y_boot = train_test_split(reference, random_state=config['random_state'], test_size=0.5)
    else:
        raise ValueError(f'{config["partition"]} not supported.')
    
    # Sets the detection threshold
    fsd.fit(X_boot=X_boot, Y_boot=Y_boot, random_state=config['random_state'])
    
    # Detection is run on the full reference and query datasets for every 
    # supported partition
    _, manipulated_idxs, scores = fsd.detect_and_localize(
        X=reference, Y=query, random_state=config['random_state'], return_scores=True
    )
    
    # The detector may return None instead of indexes; treat it as 
    # "no feature flagged"
    if manipulated_idxs is None:
        manipulated_idxs = []
    
    # Measure time
    end_time = time.time()
    
    # Define empty dictionary to store output metrics
    output_df = {}
    
    # Store scores
    output_df['scores_'] = scores
    
    # Store mask with 1 = corrupted, 0 = not corrupted
    output_df['mask_'] = np.zeros(reference.shape[1]).astype(int)
    output_df['mask_'][manipulated_idxs] = 1
    
    # Store total number of features detected as being corrupted
    output_df['n_corrupted_features_'] = np.sum(output_df['mask_'])
    
    # Store F1 Score in the filtering task and total runtime
    output_df['F1 Score'] = f1_score(y_true, output_df['mask_'], zero_division=1)
    output_df['Runtime'] = end_time - start_time
    
    return output_df


def filter_scikit_feature(reference, query, y_true, config):
    """
    Apply ``MRMR``, ``CMIM``, ``CIFE``, ``DISR``, ``ICAP``, ``JMI``, ``MIFS``, 
    ``MIM``, ``FAST-CMIM`` feature selection techniques to identify 
    corrupted features between a reference and a query dataset.
    
    Parameters
    ----------
    reference : DataFrame
        Reference dataset. 
    query : DataFrame
        Query dataset.
    y_true : array-like
        True filtering labels where original features are assigned label 0, 
        and manipulated features are assigned label 1.
    config : dict
        method : ['MRMR', 'CMIM', 'CIFE', 'DISR', 'ICAP', 'JMI', 'MIFS', 'MIM', 
                'FAST-CMIM']
            Model. Not read when is_n_selected_features_specified is True 
            and n_selected_features is 0.
        n_selected_features : int or None
            The fixed budget of features which is compromised.
        is_n_selected_features_specified : bool
            True if n_selected_features is specified, False otherwise.
    
    Returns
    -------
    output_df : dict
        n_corrupted_features_ : int
            Total number of features detected as being corrupted.
        mask_ : array of length (n_features_in_)
            The mask of corrupted features, where 1 indicates a variable is 
            corrupted and 0 otherwise.
        J_CMIM_ : array-like of shape (n_features,) or None
            Corresponding objective function value of selected features. 
            None when method is 'FAST-CMIM'; empty when no selection is run.
        MIfy_ : array-like of shape (n_features,)
            Corresponding mutual information between selected features and response.
            Empty when no selection is run.
        F1 Score : float
            F1 Score in the filtering task (computed with ``zero_division=1``).
        Runtime : float
            Elapsed time, in seconds, of the selection step.
    
    Raises
    ------
    ValueError
        If a selection has to be run and ``config['method']`` is not supported.
    """
    # Methods sharing the (indexes, objective values, MI) return signature
    selectors = {
        'MRMR': mrmr,
        'CMIM': cmim,
        'CIFE': cife,
        'DISR': disr,
        'ICAP': icap,
        'JMI': jmi,
        'MIFS': mifs,
        'MIM': mim,
    }
    
    # Obtain the concatenation of all samples from the reference and query 
    # datasets, and an array of labels indicating the source of each sample 
    # (0 for reference, 1 for query)
    X, y = _create_source_labels(reference, query)
    
    # Measure time
    start_time = time.time()
    
    if config['is_n_selected_features_specified'] and config['n_selected_features'] == 0:
        # If is_n_selected_features_specified is set and there are no corrupted 
        # features, no feature will be detected as being corrupted
        manipulated_idxs, J_CMI, MIfy = [], [], []
    else:
        params = {}
        if config['is_n_selected_features_specified']:
            # If is_n_selected_features_specified is set, provide the number of 
            # selected features to the models
            params['n_selected_features'] = config['n_selected_features']
        
        method = config['method']
        X_array = np.array(X)
        
        if method == 'FAST-CMIM':
            # fast_cmim does not return objective function values
            manipulated_idxs, MIfy = fast_cmim(X_array, y, **params)
            J_CMI = None
        elif method in selectors:
            manipulated_idxs, J_CMI, MIfy = selectors[method](X_array, y, **params)
        else:
            # Fail explicitly rather than leaving the results unbound
            raise ValueError(f'{method} not supported.')
            
    # Measure time
    end_time = time.time()
    
    # Define empty dictionary to store output metrics
    output_df = {}
    
    output_df['J_CMIM_'] = J_CMI
    output_df['MIfy_'] = MIfy
    
    # Store mask with 1 = corrupted, 0 = not corrupted
    output_df['mask_'] = np.zeros(reference.shape[1]).astype(int)
    output_df['mask_'][manipulated_idxs] = 1
    
    # Store total number of features detected as being corrupted
    output_df['n_corrupted_features_'] = np.sum(output_df['mask_'])
    
    # Store F1 Score in the filtering task and total runtime
    output_df['F1 Score'] = f1_score(y_true, output_df['mask_'], zero_division=1)
    output_df['Runtime'] = end_time - start_time
    
    return output_df