################################################################################
# Script to run all experiments to find manipulated features between a 
# reference and a query dataset with benchmarking techniques.
################################################################################

import sys
import os
import numpy as np
import pandas as pd
import random
import itertools
import pickle

from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

sys.path.append('../')

from src.preprocessing._transform_data import (
    reference_query_split, 
    indexes_to_manipulate,
    manipulate_features,
    impute_features,
    compute_rmse_of_manipulation
)
from src.preprocessing._data_manipulation import _convert_to_dataframe
from src.io._io import save_pkl
from benchmarking._benchmarking_methods import (
    filter_selectKbest, 
    filter_mutual_information, 
    filter_feature_shift_detection,
    filter_scikit_feature
)

########################################

# USER INPUTS
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Directory holding the input datasets
data_path = '../../data/'

# Directory where result pickles are written (one sub-folder per dataset)
output_path = '../../output/locate/'

# Directory holding the pre-trained MLP networks used by manipulation 7.0
mlp_path = '../../data/mlp_models/'

# Real datasets, grouped by the type of their features
continuous_datasets = [
    'gas', 'covid', 'energy', 'musk2', 'scene',
    'mnist', 'cosine', 'polynomial', 'dilbert',
]
categorical_datasets = ['phenotypes', 'founders', 'embark']

# Simulated datasets; each is stored with its own reference/query split and
# its own number of corrupted features
simulated_datasets = [
    'bernoulli_shift_10000_1',
    'bernoulli_shift_500_1',
    'bernoulli_shift_1000_1',
    'bernoulli_shift_5000_1',
    'corr_mvg_mean_shift_1',
    'bmm_collapse-means_0.7_1',
    'bmm_shift-one-mixture_1',
    'dmvg_var_shift_1',
    'dmvg_mean_shift_1',
    'gmm_shift-one-mixture_1',
    't_exp_corr_mvg_mean_shift_1',
    't_exp_corr_mvg_feat_shuffle_1',
    't_sig_corr_mvg_feat_shuffle_1',
    't_sig_corr_mvg_mean_shift_1',
    't_sig_corr_mvg_sample_shuffle_1',
]

# Manipulations applicable to continuous / categorical features.
# Numeric codes are handled by manipulate_features; the k-NN models are
# used to impute the selected features instead.
continuous_features = [
    1.0, 2.0, 3.0, 4.1, 4.2, 4.3, 5.0, 7.0, 8.0,
    KNeighborsRegressor(n_neighbors=5),
]
categorical_features = [
    2.0, 3.0, 6.1, 6.2, 6.3, 7.0, 8.0,
    KNeighborsClassifier(n_neighbors=5),
]

# Names of the benchmarking methods to run
methods = [
    'selectKbest', 'MI',
    'MB-SM', 'MB-KS', 'KNN-KS', 'DD-SM',
    'MRMR', 'FAST-CMIM',
]

# Grid of data configurations; every combination of values is run.
# NOTE(review): an earlier note here read "mnist true" -- presumably
# maxStd=True was used for mnist; confirm before relying on it.
data_params = {
    "dataset": continuous_datasets,
    "fraction": [0.05, 0.1, 0.25],
    "maxStd": [False],
    "type_manipulation": continuous_features,
}

# Parameter grid of the selectKbest method
selectKbest_params = {
    "method": ['selectKbest'],
    "significance_level": [0.008, 0.01, 0.02, 0.05, 0.1],
    "output_name": ['selectKbest.pkl'],
}

# Parameter grid of the mutual information (MI) method
mutual_information_params = {
    "threshold": [0.008, 0.01, 0.02, 0.03, 0.05, 0.1],
    "random_state": [0],
    "output_name": ['mutual_information.pkl'],
}

# Parameter grid shared by 'MB-SM', 'MB-KS', 'KNN-KS' and 'DD-SM'.
# The partition must agree with is_n_selected_features_specified:
#   False -> 'X_boot=reference(50%), Y_boot=reference(50%), X=reference, Y=query'
#   True  -> 'X_boost=X=reference, Y_boost=Y=query'
feature_shift_params = {
    "partition": ['X_boot=reference(50%), Y_boot=reference(50%), X=reference, Y=query'],
    "is_n_selected_features_specified": [False],
    "n_selected_features": [None],
    "n_expectation": [30],
    "n_neighbors": [100],
    "n_bootstrap_runs": [25],
    "random_state": [0],
    "output_name": ['feature_shift.pkl'],
}

# Parameter grid shared by the scikit-feature methods
# ('MRMR', 'CMIM', 'CIFE', 'DISR', 'ICAP', 'JMI', 'MIFS', 'MIM', 'FAST-CMIM')
scikit_feature_params = {
    "is_n_selected_features_specified": [False],
    "n_selected_features": [None],
    "output_name": ['scikit_feature.pkl'],
}

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Expand data_params into the Cartesian product of its value lists: each
# combination becomes one dictionary describing a data configuration.
data_keys, data_values = zip(*data_params.items())
data_experiments = [
    dict(zip(data_keys, combination))
    for combination in itertools.product(*data_values)
]

# Expand the parameter grid of every requested method and gather all
# resulting configurations into one flat list.
all_experiments = []
for method in methods:
    # Pick the parameter grid of the family this method belongs to
    if method == 'selectKbest':
        method_params = selectKbest_params
    elif method == 'MI':
        method_params = mutual_information_params
    elif method in ['MB-SM', 'MB-KS', 'KNN-KS', 'DD-SM']:
        method_params = feature_shift_params
    elif method in ['MRMR', 'CMIM', 'CIFE', 'DISR', 'ICAP', 'JMI', 'MIFS', 'MIM', 'FAST-CMIM']:
        method_params = scikit_feature_params
    else:
        raise ValueError(f'{method} is not supported.')

    # One grid can serve several methods, so tag it with the concrete method
    # name before expanding it (this overwrites the tag of a previous pass).
    method_params['method'] = [method]
    method_keys, method_values = zip(*method_params.items())

    # All combinations of this method's parameter values
    method_experiments = [
        dict(zip(method_keys, combination))
        for combination in itertools.product(*method_values)
    ]
    all_experiments.extend(method_experiments)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def check_experiment_pending(output_file, data_config, method_config):
    """
    Check whether an experiment configuration still has to be executed.

    The stored results in ``output_file`` (a pickled dataframe with one row
    per finished experiment) are compared with the configuration of the new
    experiment. All cells are compared as strings, and the
    ``rmse_manipulation`` column is excluded from the comparison.

    Parameters
    ----------
    output_file : str
        Path to output pkl file.
    data_config : dict
        Configuration parameters for the data used in the experiment.
    method_config : dict
        Configuration parameters for the method used in the experiment.

    Returns
    -------
    bool
        True if no stored result matches this configuration (the experiment
        is pending), False if at least one matching result already exists.
    """
    # Without a previous output file every experiment is pending
    if not os.path.isfile(output_file):
        return True

    # Columns that identify an experiment
    columns = list(data_config.keys()) + list(method_config.keys())

    # Load previous results as a dataframe
    with open(output_file, "rb") as f:
        method_df = pickle.load(f)

    # If the stored results lack any configuration column (e.g. a parameter
    # added after they were produced), no stored row describes this
    # configuration, so it is still pending.
    if not set(columns).issubset(method_df.columns):
        return True

    # Select the identifying columns and convert all cells to strings
    method_df = method_df[columns].astype(str)

    # Dataframe with the configuration of the new experiment, as strings
    experiment_data = list(data_config.values()) + list(method_config.values())
    experiment_df = pd.DataFrame([experiment_data], columns=columns).astype(str)

    # rmse_manipulation is not part of the identity of an experiment;
    # errors='ignore' keeps this working when the column is absent
    method_df = method_df.drop(columns=['rmse_manipulation'], errors='ignore')
    experiment_df = experiment_df.drop(columns=['rmse_manipulation'], errors='ignore')

    # Pending unless at least one stored row matches in every column
    # (duplicated stored rows must also count as "already run")
    matches = (method_df == experiment_df.loc[0]).all(axis=1)
    return not matches.any()

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

def _run_benchmark(reference, query, y_true, experiment_config):
    """
    Run the benchmarking wrapper that implements experiment_config['method'].

    Parameters
    ----------
    reference, query : pd.DataFrame
        Reference and (possibly manipulated) query datasets.
    y_true : np.ndarray
        Ground-truth labels: 1 for manipulated features, 0 otherwise.
    experiment_config : dict
        Method configuration, forwarded unchanged to the wrapper.

    Returns
    -------
    dict-like
        Output of the selected benchmarking wrapper.

    Raises
    ------
    ValueError
        If no wrapper handles the requested method.
    """
    method_name = experiment_config['method']
    if method_name == 'selectKbest':
        return filter_selectKbest(reference, query, y_true, experiment_config)
    if method_name == 'MI':
        return filter_mutual_information(reference, query, y_true, experiment_config)
    if method_name in ['MB-SM', 'MB-KS', 'KNN-KS', 'DD-SM']:
        return filter_feature_shift_detection(reference, query, y_true, experiment_config)
    if method_name in ['MRMR', 'CMIM', 'CIFE', 'DISR', 'ICAP', 'JMI', 'MIFS', 'MIM', 'FAST-CMIM']:
        return filter_scikit_feature(reference, query, y_true, experiment_config)
    raise ValueError(f'{method_name} is not supported.')


# Iterate over different data configurations
for data_config in data_experiments:

    is_simulated = data_config['dataset'] in simulated_datasets

    if is_simulated:
        # Simulated datasets come with their own shift: no manipulation
        # parameters may be specified for them
        assert data_config['fraction'] is None and data_config['type_manipulation'] is None, 'Wrong parameters!'

        # Simulated data are stored as a pickled dict with the split and
        # the number of corrupted features
        with open(f'{data_path}{data_config["dataset"]}.pkl', "rb") as f:
            data = pickle.load(f)
        reference, query = data["ref"], data["que"]

        # The first data['n_corrupted'] features are the corrupted ones
        manipulated_idxs = range(data['n_corrupted'])
    else:
        data = np.load(f'{data_path}{data_config["dataset"]}.npy', allow_pickle=True)

        # Split the data into reference and query datasets with 50% of rows each
        reference, query = reference_query_split(data=data, random_state=0)

        # Obtain the indexes of the features to manipulate in the query dataset
        manipulated_idxs = indexes_to_manipulate(query=query, fraction=data_config["fraction"],
                                                 maxStd=data_config["maxStd"],
                                                 random_state=0)

    # Keep a copy of the query dataset before manipulation
    query_clean = query.copy()

    if len(manipulated_idxs) > 0 and not is_simulated:
        type_manipulation = data_config["type_manipulation"]

        # Numeric codes select a transformation; anything else is treated
        # as a model used to impute the selected features
        if isinstance(type_manipulation, (int, float)):
            # Manipulation 7.0 relies on a pre-trained MLP network
            if type_manipulation == 7.0:
                mlp_full_path = f'{mlp_path}{data_config["dataset"]}_{data_config["fraction"]}.pt'
            else:
                mlp_full_path = None

            if data_config["dataset"] in continuous_datasets:
                assert type_manipulation in continuous_features, \
                    f'{type_manipulation} is not a continuous feature manipulation.'
            elif data_config["dataset"] in categorical_datasets:
                assert type_manipulation in categorical_features, \
                    f'{type_manipulation} is not a categorical feature manipulation.'

            # Manipulate the features for the given manipulation type
            query[:, manipulated_idxs] = manipulate_features(query_Y=query[:, manipulated_idxs],
                                                             transformation=type_manipulation,
                                                             mlp_path=mlp_full_path,
                                                             random_state=0)
        else:
            # Impute the manipulated features in the query from the reference
            query[:, manipulated_idxs] = impute_features(reference=reference, query=query,
                                                         manipulated_idxs=manipulated_idxs,
                                                         model=type_manipulation)

    # Measure the level of corruption of each feature in the query by computing
    # the root mean squared error (RMSE) of each feature before/after manipulation
    data_config['rmse_manipulation'] = compute_rmse_of_manipulation(query_clean, query)

    # True filtering labels: 0 for original features, 1 for manipulated ones
    data_config['y_true'] = np.zeros(reference.shape[1])
    data_config['y_true'][manipulated_idxs] = 1

    # Number of corrupted features
    data_config['n_corrupted_features'] = sum(data_config['y_true'])

    # Convert reference and query datasets to dataframes whose column names
    # are the index of each column
    reference, query = _convert_to_dataframe(reference, query)

    # Iterate over different experiment configurations
    for experiment_config in all_experiments:

        # Work on a copy so per-dataset settings (score_func,
        # n_selected_features, ...) do not leak into the shared grid
        experiment_config = dict(experiment_config)

        print(experiment_config)

        method = experiment_config['method']

        if method == 'selectKbest':
            # chi2 requires binary data; f_classif is used otherwise
            if data_config['dataset'] in categorical_datasets:
                experiment_config['score_func'] = 'chi2'
                assert set(np.unique(reference)) == {0, 1}, 'Wrong chi2!'
            elif data_config['dataset'] in continuous_datasets:
                experiment_config['score_func'] = 'f_classif'
                assert set(np.unique(reference)) != {0, 1}, 'Wrong f_classif!'
            elif set(np.unique(reference)) == {0, 1}:
                experiment_config['score_func'] = 'chi2'
            else:
                experiment_config['score_func'] = 'f_classif'

        elif method in ['MB-SM', 'MB-KS', 'KNN-KS', 'DD-SM']:
            partition = experiment_config['partition']

            # Each partition is only valid with one setting of
            # is_n_selected_features_specified
            if partition == 'X_boost=X=reference, Y_boost=Y=query':
                assert experiment_config['is_n_selected_features_specified'], 'Wrong parameters!'
            if partition == 'X_boot=reference(50%), Y_boot=reference(50%), X=reference, Y=query':
                assert not experiment_config['is_n_selected_features_specified'], 'Wrong parameters!'

            # If 'is_n_selected_features_specified' is set, the number of
            # compromised features is provided to the models; otherwise all
            # features are assumed to be potentially compromised
            if experiment_config['is_n_selected_features_specified']:
                experiment_config['n_selected_features'] = int(data_config['n_corrupted_features'])
            else:
                experiment_config['n_selected_features'] = int(reference.shape[1])

            # Only KNN-KS uses a number of neighbors
            if method in ['MB-SM', 'MB-KS', 'DD-SM']:
                experiment_config['n_neighbors'] = None

        elif method in ['MRMR', 'CMIM', 'CIFE', 'DISR', 'ICAP', 'JMI', 'MIFS', 'MIM', 'FAST-CMIM']:
            # If 'is_n_selected_features_specified' is True, the number of
            # compromised features is explicitly provided to the models.
            # If False, the number of selected features is determined by
            # thresholding and is not specified beforehand
            if experiment_config['is_n_selected_features_specified']:
                experiment_config['n_selected_features'] = int(data_config['n_corrupted_features'])
            else:
                experiment_config['n_selected_features'] = None

        # Output location: one sub-folder per dataset
        output_dir = f'{output_path}{data_config["dataset"]}/'
        output_file = f'{output_dir}{experiment_config["output_name"]}'

        # Skip experiments whose results are already stored
        if not check_experiment_pending(output_file, data_config, experiment_config):
            print('Experiment already exists.')
            continue

        print(data_config['dataset'], data_config['fraction'], data_config['maxStd'], data_config['type_manipulation'])
        print(experiment_config)

        output_dict = _run_benchmark(reference, query, data_config['y_true'], experiment_config)

        # One-row dataframe with data config, method config and method output
        experiment_cols = list(data_config.keys()) + list(experiment_config.keys()) + list(output_dict.keys())
        experiment_data = list(data_config.values()) + list(experiment_config.values()) + list(output_dict.values())
        experiment_df = pd.DataFrame([experiment_data], columns=experiment_cols)

        save_pkl(output_dir, experiment_config["output_name"], experiment_df)
