################################################################################
# Script to run all experiments with DF-Locate.
################################################################################

import sys
import os
import numpy as np
import pandas as pd
import random
import itertools
import pickle

from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

import lightgbm as lgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import f1_score
from catboost import CatBoostClassifier

import time

sys.path.append('../')

from src.datafix import DFLocate

from src.preprocessing._transform_data import (
    reference_query_split, 
    indexes_to_manipulate,
    manipulate_features,
    impute_features,
    compute_rmse_of_manipulation
)

from src.io._io import save_pkl

########################################

# USER INPUTS
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Directory prefixes. File names are appended to them directly (no path
# joining), so each one must end with '/'.
data_path = '../../data/'               # datasets: <name>.npy or <name>.pkl
que_imp_path = '../../output/correct/'  # <name>_benchmark_results.pkl with imputed queries
output_path = '../../output/locate/'    # results: <output_path><name>/shift_location.pkl
mlp_path = '../../data/mlp_models/'     # MLP networks used by manipulation 7.0

# Real-world datasets, grouped by the type of their features
continuous_datasets = [
    'gas', 'covid', 'energy', 'musk2', 'scene',
    'mnist', 'cosine', 'polynomial', 'dilbert',
]
categorical_datasets = [
    'phenotypes', 'founders', 'embark',
]

# Simulated datasets, loaded as ready-made reference/query pairs
simulated_datasets = [
    'bernoulli_shift_10000_1',
    'bernoulli_shift_500_1',
    'bernoulli_shift_1000_1',
    'bernoulli_shift_5000_1',
    'corr_mvg_mean_shift_1',
    'bmm_collapse-means_0.7_1',
    'bmm_shift-one-mixture_1',
    'dmvg_var_shift_1',
    'dmvg_mean_shift_1',
    'gmm_shift-one-mixture_1',
    't_exp_corr_mvg_mean_shift_1',
    't_exp_corr_mvg_feat_shuffle_1',
    't_sig_corr_mvg_feat_shuffle_1',
    't_sig_corr_mvg_mean_shift_1',
    't_sig_corr_mvg_sample_shuffle_1',
]

# Keys of the imputed queries stored in the benchmark results of the
# simulated datasets
simulated_imputed_queries = [
    'adversarial_v1_que_imp',
    '10nn_que_imp',
    'DD10_dom_adap_que_imp',
    'INB200_dom_adap_que_imp',
    'LinearRegression_supervised_que_imp',
    'MLPRegressor_supervised_que_imp',
    'gain_hi_que_imp',
    'hyperimpute_hi_que_imp',
    'ice_hi_que_imp',
    'mean_hi_que_imp',
    'miracle_hi_que_imp',
    'missforest_hi_que_imp',
    'sinkhorn_hi_que_imp',
    'softimpute_hi_que_imp',
]

# Manipulations for continuous / categorical features. Numeric codes select a
# transformation applied by ``manipulate_features``; estimator objects are
# used as imputation models by ``impute_features``.
continuous_features = [
    1.0, 2.0, 3.0, 4.1, 4.2, 4.3, 5.0, 7.0, 8.0,
    KNeighborsRegressor(n_neighbors=1),
    KNeighborsRegressor(n_neighbors=5),
]
categorical_features = [
    2.0, 3.0, 6.1, 6.2, 6.3, 7.0, 8.0,
    KNeighborsClassifier(n_neighbors=1),
    KNeighborsClassifier(n_neighbors=5),
]

# Data parameter grid: every combination of these values is one data
# configuration.
data_params = dict(
    dataset=continuous_datasets,
    que_imp=[None],
    fraction=[0.05, 0.1, 0.25],
    maxStd=[False],
    type_manipulation=continuous_features,
)

# Parameter grid for ``DFLocate`` / ``shift_location()``.
# Alternative estimators that can be plugged into "estimator":
#   RandomForestClassifier(random_state=0)
#   LogisticRegression(random_state=0, penalty='l1', solver='liblinear', max_iter=5000)
#   LinearSVC(random_state=0, penalty='l1', dual=False, max_iter=5000)
#   lgb.LGBMClassifier(random_state=0, n_jobs=-1, importance_type='gain')
#   ExtraTreesClassifier(random_state=0, n_jobs=-1)
#   CatBoostClassifier(random_state=0)
shift_location_params = dict(
    estimator=[RandomForestClassifier(random_state=0)],
    cv=[5],
    test_size=[0.2],
    scoring=[['balanced_accuracy', 'f1', 'TV', 'KL', 'JD']],
    n_jobs=[-1],
    return_estimator=[False],
    step=[None],
    percentage=[0.1],
    alpha=[1],
    threshold=[None],
    margin=[0.01],
    max_it=[None],
    max_features_to_filter=[0.5],
    patience=[None],
    random_state=[0],
    find_best=['knee-balanced'],
    window_length=[2],
    polyorder=[4],
    S=[5],
    online=[False],
)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def _expand_grid(param_grid):
    """
    Expand a dict of candidate-value lists into all parameter combinations.

    Parameters
    ----------
    param_grid : dict
        Maps each parameter name to an iterable of candidate values.

    Returns
    -------
    list of dict
        One dict per element of ``itertools.product`` over the candidate
        lists. Every dict contains all keys of ``param_grid``, in the same
        insertion order.
    """
    keys = list(param_grid.keys())
    return [dict(zip(keys, values))
            for values in itertools.product(*param_grid.values())]


# All combinations of data parameter values
data_experiments = _expand_grid(data_params)

# All combinations of shift location parameter values
shift_location_experiments = _expand_grid(shift_location_params)

# Drop combinations in which both parameters of a mutually exclusive pair are
# defined: only one of "alpha"/"threshold" and only one of "step"/"percentage"
# may be set at a time. "Defined" means a truthy value, so None, 0 and 0.0
# all count as unset.
# NOTE(review): confirm with DFLocate whether 0 should count as "defined".
_EXCLUSIVE_PARAM_PAIRS = (("alpha", "threshold"), ("step", "percentage"))
shift_location_experiments = [
    x for x in shift_location_experiments
    if not any(all(x.get(key) for key in pair) for pair in _EXCLUSIVE_PARAM_PAIRS)
]

def check_experiment_pending(output_path, data_config, shift_location_config):
    """
    Check whether an experiment configuration still has to be executed.

    Looks for the results file ``{output_path}{dataset}/shift_location.pkl``.
    If it exists, the given configuration is compared, as strings, against
    every stored row. The ``rmse_manipulation`` column is ignored because it
    is a measured float, not a configuration parameter.

    Parameters
    ----------
    output_path : str
        Directory prefix where per-dataset results are stored. It must end
        with '/', because the dataset name is appended to it directly.
    data_config : dict
        Configuration parameters of the data used in the experiment. Must
        contain the key ``"dataset"``.
    shift_location_config : dict
        Configuration parameters of the shift location used in the experiment.

    Returns
    -------
    bool
        True if the experiment still has to run. That is the case when there
        is no results file, when the stored results lack one of the
        configuration columns, or when no stored row matches. False if at
        least one stored row matches the configuration.
    """
    results_file = f'{output_path}{data_config["dataset"]}/shift_location.pkl'
    if not os.path.isfile(results_file):
        return True

    # Columns and values that identify this experiment
    columns = list(data_config.keys()) + list(shift_location_config.keys())
    values = list(data_config.values()) + list(shift_location_config.values())

    # NOTE: pickle can execute arbitrary code; only load results written by
    # this script.
    with open(results_file, "rb") as f:
        stored_df = pickle.load(f)

    # Results written with a different set of parameters cannot match
    if not set(columns).issubset(stored_df.columns):
        return True

    # Compare as strings so that heterogeneous cells (estimators, arrays,
    # None) can be matched. The measured RMSE is ignored.
    stored_df = stored_df[columns].astype(str).drop(
        columns=['rmse_manipulation'], errors='ignore')
    experiment_df = pd.DataFrame([values], columns=columns).astype(str).drop(
        columns=['rmse_manipulation'], errors='ignore')

    # Any matching row means the experiment was already run. Duplicated rows
    # must not make it look pending again.
    matches = (stored_df == experiment_df.iloc[0]).all(axis=1)
    return not matches.any()

# Iterate over different data configurations
for data_config in data_experiments:
    dataset = data_config["dataset"]
    is_simulated = dataset in simulated_datasets

    # Simulated datasets are read as given, so they must not request a
    # manipulated fraction or a manipulation type. Raise explicitly rather
    # than using ``assert``, so the check also runs under ``python -O``.
    if is_simulated and (data_config['fraction'] is not None
                         or data_config['type_manipulation'] is not None):
        raise ValueError('Wrong parameters!')

    # Read the reference and query datasets and the indexes of the corrupted
    # features
    if is_simulated:
        que_imp = data_config["que_imp"]
        if que_imp in simulated_imputed_queries:
            # Imputed query taken from the benchmark results of this dataset
            benchmark_file = f'{que_imp_path}{dataset}_benchmark_results.pkl'
            print(f'Reading ref and que imp from {benchmark_file}')
            with open(benchmark_file, "rb") as f:
                data = pickle.load(f)
            reference = data["original_dataset"]["ref"]
            query = data[que_imp]
            # The first n_corrupted columns are treated as the corrupted ones
            manipulated_idxs = range(data["original_dataset"]["n_corrupted"])
        else:
            data_file = f'{data_path}{dataset}.pkl'
            print(f'Reading ref and que from {data_file}')
            with open(data_file, "rb") as f:
                data = pickle.load(f)
            reference, query = data["ref"], data["que"]
            manipulated_idxs = range(data["n_corrupted"])
    else:
        data = np.load(f'{data_path}{dataset}.npy', allow_pickle=True)

        # Split the data into reference and query datasets with 50% of rows each
        reference, query = reference_query_split(data=data, random_state=0)

        # Obtain the indexes of the features to manipulate in the query dataset
        manipulated_idxs = indexes_to_manipulate(query=query,
                                                 fraction=data_config["fraction"],
                                                 maxStd=data_config["maxStd"],
                                                 random_state=0)

    # Keep an untouched copy of the query to measure the manipulation later
    query_clean = query.copy()

    if len(manipulated_idxs) > 0 and not is_simulated:
        manipulation = data_config["type_manipulation"]

        if isinstance(manipulation, (int, float)):
            # Numeric codes select a transformation. Code 7.0 also needs the
            # MLP network stored for this dataset and fraction.
            if manipulation == 7.0:
                mlp_full_path = f'{mlp_path}{dataset}_{data_config["fraction"]}.pt'
            else:
                mlp_full_path = None

            # The manipulation must match the feature type of the dataset
            if dataset in continuous_datasets:
                if manipulation not in continuous_features:
                    raise ValueError(
                        f'{manipulation} is not a continuous feature manipulation.')
            elif dataset in categorical_datasets:
                if manipulation not in categorical_features:
                    raise ValueError(
                        f'{manipulation} is not a categorical feature manipulation.')

            # Manipulate the selected features with the given transformation
            query[:, manipulated_idxs] = manipulate_features(
                query_Y=query[:, manipulated_idxs],
                transformation=manipulation,
                mlp_path=mlp_full_path,
                random_state=0)
        else:
            # Impute the selected query features based on the reference
            # dataset, using the configured estimator as the model
            query[:, manipulated_idxs] = impute_features(
                reference=reference, query=query,
                manipulated_idxs=manipulated_idxs,
                model=manipulation)

    # Level of corruption of each query feature: RMSE before/after manipulation
    data_config['rmse_manipulation'] = compute_rmse_of_manipulation(query_clean, query)

    # True filtering labels: 1 for manipulated features, 0 for original ones
    data_config['y_true'] = np.zeros(reference.shape[1])
    data_config['y_true'][manipulated_idxs] = 1

    # Number of corrupted features. Kept as the numpy float returned by
    # ``sum`` so the stored value matches results written by earlier runs.
    data_config['n_corrupted_features'] = sum(data_config['y_true'])

    # Iterate over different shift localization configurations
    for shift_location_config in shift_location_experiments:
        if not check_experiment_pending(output_path, data_config, shift_location_config):
            print('Experiment already exists.')
            continue

        print(shift_location_config)
        # Run DFLocate with the current shift location configuration
        datafix = DFLocate(**shift_location_config)
        datafix = datafix.shift_location(reference=reference, query=query)

        # One-row dataframe: data configuration followed by every attribute
        # of the DFLocate object
        datafix_attrs = vars(datafix)
        shift_location_cols = list(data_config.keys()) + list(datafix_attrs.keys())
        shift_location_data = list(data_config.values()) + list(datafix_attrs.values())
        shift_location_df = pd.DataFrame([shift_location_data], columns=shift_location_cols)

        # Store dataframe in pickle format
        save_pkl(f'{output_path}{dataset}', 'shift_location.pkl', shift_location_df)