import sys
sys.path.append('../ite/')
sys.path.append('../src/')
sys.path.append('../')
sys.path.append('ite/')
sys.path.append('src/')
import ite
import ot
import os
#import autosklearn.classification
import pickle as pkl
import numpy as np
from sklearn.model_selection import cross_validate, cross_val_predict
from sklearn import datasets, linear_model
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis ### todo review this!!!
from snapml import BoostingMachineClassifier as SnapBoostingMachineClassifier
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import ExtraTreesClassifier
import lightgbm as lgb
from sklearn.experimental import enable_iterative_imputer 
from sklearn.impute import KNNImputer, IterativeImputer
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neural_network import MLPRegressor
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.svm import SVC
from sklearn.kernel_approximation import RBFSampler, Nystroem
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
simplefilter("ignore", category=ConvergenceWarning)
import time
from sklearn import pipeline
import ite.cost as ic
from src.divergences.dp_div import dp_div
import torch
from smartsvm.divergence import divergence as hp_divergence
from smartsvm import hp_estimate
from autosklearn.experimental.askl2 import AutoSklearn2Classifier
import math

from src.imputers.domain_adaptation_wrapper import inb_wrapper, dd_wrapper
from src.imputers.supervised_imputer import SupervisedImputer
from src.imputers.adversarial_imputer_v2 import AdversarialIterativePerSampleImputerV2 as AdversarialImputer
from src.imputers.adversarial_imputer_v2 import AdversarialIterativeBatchWrapper
from src.divergences.ot_distances import ot_dist
from src.divergences.discriminator_divergence_estimator import div_multiple_clf, div_tv_clf, div_all_multiple_clf
from src.divergences.ite_divergence import ite_div
from src.divergences.ber_chevy.wrapper import ber_chevy_estimator
from src.divergences.feebee_wrapper import ghp_ber, kde_ber, knn_ber
from src.divergences.knn_divergence import naive_estimator, scipy_estimator, skl_estimator, skl_efficient, fast_estimator
from src.transform_data import (reference_query_split, indexes_to_manipulate, 
                                    manipulate_features, impute_features, 
                                    compute_rmse_of_manipulation)

from sklearn.model_selection import cross_val_predict, cross_val_score
import ray
from sklearn.impute import KNNImputer
from hyperimpute.plugins.imputers import Imputers

sys.path.append('../../')
sys.path.append('../../locate/')

import json
import argparse

# Command-line interface: a single positional argument pointing at the JSON
# config that drives the benchmark. Keys read further down in this file:
# "datasets" and "methods" (required), "max_dims" and "base_classifier_name"
# (optional).
parser = argparse.ArgumentParser(
    description="Run the imputation benchmark described by a JSON config file."
)
parser.add_argument("path", help="Path to the JSON config file")

# Parse command-line arguments
args = parser.parse_args()

# JSON is UTF-8 by specification; don't depend on the platform locale.
with open(args.path, "r", encoding="utf-8") as file:
    input_dict = json.load(file)


# When True, the original dataset and every imputed query are stored in the
# per-dataset results pickle alongside the metrics.
STORE_REF_QUE_IN_DICT = True
# When True, store_dict merges new results into an existing results pickle
# instead of overwriting it.
LOAD_DICT_IF_AVAILABLE = True


def dp2_div(reference, query, sqrt=True):
    """Return smartsvm's ``divergence`` estimate between two samples.

    Args:
        reference: first sample, passed straight to ``hp_divergence``.
        query: second sample, passed straight to ``hp_divergence``.
        sqrt: when true (the default) return the square root of the
            estimate, otherwise the raw estimate.
    """
    estimate = hp_divergence(reference, query)
    if not sqrt:
        return estimate
    return np.sqrt(estimate)


def kl_knn_wrapper(reference, query, estimator=scipy_estimator, ks=(3,), symmetric=True):
    """Average a k-NN divergence estimate over several neighbourhood sizes.

    Args:
        reference: first sample.
        query: second sample.
        estimator: callable invoked as ``estimator(a, b, k=k)``.
        ks: iterable of neighbourhood sizes to average over. A tuple default
            is used so that no mutable object is shared between calls.
        symmetric: when true, also add the mean estimate in the reverse
            direction (``estimator(query, reference, k=k)``).

    Returns:
        The averaged (and optionally symmetrised) estimate.

    Raises:
        ValueError: if ``ks`` is empty, where the mean would silently be NaN.
    """
    # Materialise once so that a generator is not exhausted before the
    # symmetric pass.
    ks = tuple(ks)
    if not ks:
        raise ValueError('ks must contain at least one neighbourhood size')
    kl = np.mean([estimator(reference, query, k=k) for k in ks])
    if symmetric:
        kl += np.mean([estimator(query, reference, k=k) for k in ks])
    return kl


def _average_outputs(outputs):
    """Average metric outputs collected over several subsets or batches.

    Each entry may be a scalar (Python or NumPy, any float width) or a
    sequence of scalars, such as the pair that ``evaluation_report`` unpacks
    from ``ot_dist``. Returns a NumPy scalar when the entries are scalars,
    otherwise a list holding one mean per position.
    """
    stacked = np.stack([np.asarray(out) for out in outputs])
    mean = np.mean(stacked, axis=0)
    return mean if np.ndim(mean) == 0 else list(mean)


def feature_subset_wrapper(function, reference, query, n_corrupted, max_feature_size=25000, n_iters=10,
                           subset_fraction=0.1):
    """Evaluate ``function(reference, query)``, subsampling columns for very wide data.

    If ``reference`` has fewer than ``max_feature_size`` columns, the output of
    ``function`` is returned unchanged. Otherwise ``function`` is evaluated
    ``n_iters`` times. Each run uses the same random ``subset_fraction`` of the
    columns of both arrays; the permutation is seeded with the iteration index,
    so results are reproducible. The outputs are then averaged.

    Args:
        function: metric called as ``function(reference, query)``, returning
            a scalar or a sequence of scalars.
        reference: 2-D array (rows x features).
        query: 2-D array with the same columns as ``reference``.
        n_corrupted: only used in the log message.
        max_feature_size: column count from which subsampling kicks in.
        n_iters: number of random column subsets to average over.
        subset_fraction: fraction of columns kept per subset (at least one).

    Returns:
        The raw output of ``function`` when no subsampling happens. Otherwise,
        a NumPy scalar or a list of per-position means.
    """
    n_features = reference.shape[1]
    if n_features < max_feature_size:
        return function(reference, query)

    print(f'feature_subset_wrapper: {n_features} features, averaging over {n_iters} '
          f'random subsets (n_corrupted={n_corrupted})')
    subset_size = max(1, int(n_features * subset_fraction))
    outputs = []
    for i in range(n_iters):
        # Index only the selected columns instead of materialising a fully
        # permuted copy of both arrays first.
        cols = np.random.RandomState(seed=i).permutation(n_features)[:subset_size]
        outputs.append(function(reference[:, cols], query[:, cols]))
    return _average_outputs(outputs)


def batched_wrapper(function, reference, query, n_corrupted, max_batch_size=10000):
    """Evaluate ``function`` on row batches of at most ``max_batch_size`` and average.

    Inputs with fewer than ``max_batch_size`` rows are passed straight to
    ``feature_subset_wrapper``. Otherwise the rows of both arrays are shuffled
    with one fixed permutation (seed 42). The shuffled rows are cut into
    ``ceil(n / max_batch_size)`` batches, each run through
    ``feature_subset_wrapper``, and the outputs are averaged. The last batch
    is anchored at the end, so every batch has exactly ``max_batch_size``
    rows and the last one may overlap the previous one.

    NOTE(review): the same row permutation indexes ``query``, so ``query`` is
    assumed to have at least as many rows as ``reference``.

    Returns:
        Same shape of result as ``feature_subset_wrapper``.
    """
    n_samples = reference.shape[0]
    if n_samples < max_batch_size:
        return feature_subset_wrapper(function, reference, query, n_corrupted)

    perm = np.random.RandomState(seed=42).permutation(n_samples)
    # Shuffle once up front rather than re-indexing the full arrays per batch.
    reference_shuffled = reference[perm, :]
    query_shuffled = query[perm, :]
    num_batches = math.ceil(n_samples / max_batch_size)
    outputs = []
    for i in range(num_batches):
        if i == num_batches - 1:
            start, end = n_samples - max_batch_size, n_samples
        else:
            start, end = i * max_batch_size, (i + 1) * max_batch_size
        outputs.append(feature_subset_wrapper(
            function, reference_shuffled[start:end, :], query_shuffled[start:end, :], n_corrupted
        ))
    return _average_outputs(outputs)


def evaluation_report(reference, query, n_corrupted, query_original=None):
    """Compute the divergence metrics between ``reference`` and ``query``.

    Runs ``ot_dist`` (unpacked into ``swd`` and ``w2``; presumably sliced and
    2-Wasserstein distances — confirm in ``ot_dist``), ``dp2_div`` and
    ``ite_div``. Each goes through ``batched_wrapper`` with its own batch size.
    When ``query_original`` is given, the mean squared difference between
    ``query`` and ``query_original`` is also computed.

    Args:
        reference: reference samples; converted to float64.
        query: query samples; converted to float64.
        n_corrupted: forwarded to the wrappers, where it is only logged.
        query_original: optional ground truth for ``query``; ``mse`` is 0
            when it is None.

    Returns:
        dict with keys 'swd', 'w2', 'dp2', 'ited', 'metric_times_log', 'mse'.
        'metric_times_log' keeps the keys of metrics that are currently
        disabled (with 0.0) so previously stored results stay comparable.
    """
    print('performing evaluation metrics...')
    reference = np.double(reference)
    query = np.double(query)
    import warnings
    # NOTE: installed outside catch_warnings, so this filter stays active
    # process-wide after the first call.
    warnings.filterwarnings("ignore", message="Variables are collinear")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        t_start = time.time()
        swd, w2 = batched_wrapper(ot_dist, reference, query, n_corrupted, max_batch_size=10000)
        t_w2 = time.time()
        print('\n\nAAA', swd, w2, t_w2 - t_start, '\n\n')
        dp2 = batched_wrapper(dp2_div, reference, query, n_corrupted, max_batch_size=1000)
        t_dp2 = time.time()
        print('\n\nBBB', t_dp2 - t_w2, '\n\n')
        ited = batched_wrapper(ite_div, reference, query, n_corrupted, max_batch_size=500)
        t_ite = time.time()
        print('\n\nCCC', t_ite - t_dp2, '\n\n')

        if query_original is not None:
            mse = np.mean(np.square(query - query_original))
        else:
            mse = 0

    # 'dp', 'ber_c', 'metrics_div', 'ghp_ber', 'knn_ber' and 'knn_kla' belong
    # to metrics that are not computed; they are kept as 0.0 so the schema of
    # stored result dicts does not change.
    metric_times_log = {
        'w2': t_w2 - t_start, 'dp2': t_dp2 - t_w2, 'dp': 0.0, 'ited': t_ite - t_dp2,
        'ber_c': 0.0, 'metrics_div': 0.0, 'ghp_ber': 0.0, 'knn_ber': 0.0, 'knn_kla': 0.0,
    }
    total_time = t_ite - t_start

    results_dict = {'swd': swd, 'w2': w2, 'dp2': dp2, 'ited': ited, 'metric_times_log': metric_times_log, 'mse': mse}

    print(f' swd: {swd:.4f}, w2: {w2:.4f}, dp2: {dp2:.4f}, ite: {ited:.4f},  mse: {mse:.4f}, '
          f'metric_times_log: {metric_times_log}, total eval time is: {total_time}')

    return results_dict




def store_dict(dict_obj, path):
    """Pickle ``dict_obj`` to ``path``, merging into an existing file if allowed.

    When ``LOAD_DICT_IF_AVAILABLE`` is true and ``path`` already exists, the
    stored dict is loaded and updated with every entry of ``dict_obj``;
    entries of ``dict_obj`` win on key clashes. The merged dict is then
    written back. Otherwise ``dict_obj`` replaces whatever is at ``path``.

    The pickle is first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``. An interrupted write therefore cannot leave
    a truncated file behind that would make the next ``pkl.load`` fail.
    """
    if LOAD_DICT_IF_AVAILABLE and os.path.isfile(path):
        print('loading existing dict ... ', path)
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files produced by this script.
        with open(path, 'rb') as handle:
            to_store = pkl.load(handle)
        to_store.update(dict_obj)
    else:
        to_store = dict_obj

    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, 'wb') as handle:
            pkl.dump(to_store, handle, protocol=pkl.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


        
def update_and_save_results_dataset_log(name, results_dataset, output_path, query_imputed, compute_time, results_dict):
    """Record one method's outcome in ``results_dataset`` and persist it.

    Keys written into ``results_dataset`` (mutated in place):
      - ``f'{name}_que_imp'``: the imputed query, only if STORE_REF_QUE_IN_DICT.
      - ``f'{name}_time'``: ``compute_time``.
      - ``f'{name}_metrics_dict'``: ``results_dict``.

    The whole dict is then saved with ``store_dict``. The same dict object is
    returned.
    """
    if STORE_REF_QUE_IN_DICT:
        results_dataset[f'{name}_que_imp'] = query_imputed
    results_dataset.update({
        f'{name}_time': compute_time,
        f'{name}_metrics_dict': results_dict,
    })
    store_dict(results_dataset, output_path)
    return results_dataset
        

# Global switches per method family. A family only runs when its switch is
# True AND it is requested in the config's "methods" list.
run_domain_adaptation = True
run_supervised = True
run_ours = True
run_hyperimpute = True
run_background = True

root_datasets = '../../data/'

### Simulation datasets
all_sim_datasets = [
    'gmm_shift-one-mixture_1.pkl', 'bernoulli_shift_5000_1.pkl', 'corr_mvg_mean_shift_1.pkl', 
    't_exp_corr_mvg_feat_shuffle_1.pkl', 'dmvg_var_shift_1.pkl', 't_sig_corr_mvg_sample_shuffle_1.pkl', 
    'bernoulli_shift_10000_1.pkl', 'bernoulli_shift_500_1.pkl', 't_sig_corr_mvg_mean_shift_1.pkl', 
    'bmm_collapse-means_0.7_1.pkl', 'dmvg_mean_shift_1.pkl', 't_exp_corr_mvg_mean_shift_1.pkl', 
    'bernoulli_shift_1000_1.pkl', 'bmm_shift-one-mixture_1.pkl', 't_sig_corr_mvg_feat_shuffle_1.pkl']


### Real datasets
continuous_datasets = ['gas', 'covid', 'energy', 'musk2', 'scene', 'mnist', 'cosine', 'polynomial', 'dilbert']
categorical_datasets = ['phenotypes', 'founders', 'embark']
# Derived from the two lists above so they can never drift apart.
all_real_datasets = continuous_datasets + categorical_datasets

datasets_to_process = input_dict["datasets"]
methods_to_process = input_dict["methods"]

# Optional config entries.
max_dims = input_dict.get("max_dims")
if max_dims is not None:
    max_dims = int(max_dims)

base_classifier_name = input_dict.get("base_classifier_name", "CatBoost")

print('Input dict is ', input_dict)

output_dir = '../../output/correct/'

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)


def get_simulated_dataset(root_dir, dataset_file):
    """Load a pickled simulated dataset and mask its corrupted query features.

    The pickle must be a dict with keys 'ref', 'que' and 'n_corrupted'. The
    first ``n_corrupted`` columns of the query are set to NaN in
    ``query_missing``. Integer/bool queries are cast to float first, since
    they cannot hold NaN.

    Args:
        root_dir: directory containing the dataset (trailing slash optional).
        dataset_file: file name, e.g. 'gmm_shift-one-mixture_1.pkl'.

    Returns:
        ``(dset_path, dset_name, dataset, reference, query, None,
        query_missing, n_corrupted)``. The ``None`` stands in for the
        uncorrupted query, which simulated datasets do not provide.
    """
    dset_path = os.path.join(root_dir, dataset_file)
    dset_name = os.path.basename(dset_path)
    if dset_name.endswith('.pkl'):
        dset_name = dset_name[:-len('.pkl')]

    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(dset_path, 'rb') as handle:
        dataset = pkl.load(handle)

    reference = dataset['ref']
    query = dataset['que']
    n_corrupted = dataset['n_corrupted']
    if query.dtype.kind in 'biu':
        query_missing = query.astype(float)
    else:
        query_missing = query.copy()
    query_missing[:, 0:n_corrupted] = np.nan
    print(reference.shape, query.shape, n_corrupted)

    return dset_path, dset_name, dataset, reference, query, None, query_missing, n_corrupted



def get_real_dataset(root_dir, dataset_file, fraction=0.25):
    """Load a real dataset and build a reference/query pair with masked features.

    The array at ``f'{root_dir}{dataset_file}.npy'`` is split by
    ``reference_query_split``. The columns chosen by ``indexes_to_manipulate``
    are set to NaN in ``query_missing``; integer/bool data is cast to float
    first so NaN can be stored. In all returned arrays the manipulated columns
    are moved to the front, so they occupy columns ``0..n_corrupted-1``.
    If the reference has fewer rows than the query, its first row is appended
    once to equalise the lengths.

    Args:
        root_dir: directory prefix (concatenated, so include the trailing '/').
        dataset_file: dataset name without the '.npy' extension.
        fraction: forwarded to ``indexes_to_manipulate``.

    Returns:
        ``(dset_path, dset_name, dataset, reference, query, query_original,
        query_missing, n_corrupted)``. ``query_original`` is the reordered,
        unmasked query.

    Raises:
        ValueError: if reference and query row counts still differ after padding.
    """
    dset_path = f'{root_dir}{dataset_file}.npy'
    dset_name = dataset_file
    # NOTE: allow_pickle=True lets np.load unpickle object arrays; only load trusted files.
    dataset = np.load(dset_path, allow_pickle=True)

    # Split the data into reference and query datasets.
    reference, query = reference_query_split(dataset)

    # Obtain the indexes of the features to manipulate in the query dataset.
    manipulated_idxs = indexes_to_manipulate(query, fraction, None)
    n_corrupted = len(manipulated_idxs)

    # Column mask built from the indexes themselves, rather than inferred
    # from NaNs in the first row (which would also catch pre-existing NaNs).
    isnan = np.zeros(query.shape[1], dtype=bool)
    isnan[np.asarray(manipulated_idxs, dtype=int)] = True

    if query.dtype.kind in 'biu':
        query_missing = query.astype(float)
    else:
        query_missing = query.copy()
    query_missing[:, isnan] = np.nan

    # Reorder columns: manipulated features first, untouched ones after.
    print(query_missing.shape, query.shape, reference.shape, isnan.shape)
    query_missing = np.concatenate([query_missing[:, isnan], query_missing[:, ~isnan]], axis=1)
    query = np.concatenate([query[:, isnan], query[:, ~isnan]], axis=1)
    reference = np.concatenate([reference[:, isnan], reference[:, ~isnan]], axis=1)

    if reference.shape[0] < query.shape[0]:
        # Assumes the split leaves the reference at most one row short.
        reference = np.concatenate([reference, reference[0, :][np.newaxis, :]])
        print(query_missing.shape, query.shape, reference.shape, isnan.shape)
        if reference.shape[0] != query.shape[0]:
            raise ValueError(
                f'reference has {reference.shape[0]} rows but query has {query.shape[0]}'
            )

    query_original = query

    return dset_path, dset_name, dataset, reference, query, query_original, query_missing, n_corrupted


def _make_base_classifier(base_classifier_name):
    """Instantiate the classifier used by the adversarial imputer.

    Raises:
        ValueError: if ``base_classifier_name`` is not a supported name.
    """
    factories = {
        'CatBoost': lambda: CatBoostClassifier(verbose=False),
        'RF': lambda: RandomForestClassifier(random_state=0),
        'Log Reg.': lambda: LogisticRegression(random_state=0, penalty='l1', solver='liblinear', max_iter=5000),
        'SVC': lambda: SVC(kernel='linear', probability=True, random_state=0, max_iter=5000),
        'ExtraTree': lambda: ExtraTreesClassifier(random_state=0, n_jobs=-1),
        'LGBM': lambda: lgb.LGBMClassifier(random_state=0, n_jobs=-1, importance_type='gain'),
    }
    if base_classifier_name not in factories:
        raise ValueError(f'{base_classifier_name} not supported.')
    return factories[base_classifier_name]()


def _evaluate_and_log(name, reference, query_imputed, n_corrupted, query_original,
                      results_dataset, output_path, elapsed):
    """Evaluate an imputed query against ``reference`` and persist it under ``name``."""
    metrics_dict = evaluation_report(reference, query_imputed, n_corrupted, query_original=query_original)
    return update_and_save_results_dataset_log(
        name, results_dataset, output_path, query_imputed, elapsed, metrics_dict
    )


def run_benchmark_in_dataset(root_dir, dataset_file, output_dir, methods_to_process, base_classifier_name):
    """Run every requested imputation method on one dataset and save the metrics.

    Names in ``dataset_file`` that appear in ``all_real_datasets`` are loaded
    with ``get_real_dataset``; anything else is treated as a simulated pickle.
    Recognised entries of ``methods_to_process``:
      - 'background': divergences without imputation (original query,
        non-corrupted features only, and reference vs. reference).
      - 'INB', 'DD': domain-adaptation wrappers.
      - 'Regressor', 'nn': supervised imputers and KNN imputation.
      - 'adversarial': AdversarialIterativeBatchWrapper with the classifier
        named by ``base_classifier_name``.
      - hyperimpute plugin names ('mean', 'median', ..., 'sinkhorn').
    Each family is also gated by its module-level ``run_*`` flag. Results are
    written to ``output_dir/<dset_name>_benchmark_results.pkl`` after every
    method.
    """
    line_s = '--------------------------------------------------------------\n'
    print(*[line_s] * 8)
    print('Running benchmark for ', root_dir, dataset_file, output_dir)
    print(*[line_s] * 4)

    loader = get_real_dataset if dataset_file in all_real_datasets else get_simulated_dataset
    (dset_path, dset_name, dataset, reference, query,
     query_original, query_missing, n_corrupted) = loader(root_dir, dataset_file)

    output_path = os.path.join(output_dir, dset_name + '_benchmark_results.pkl')

    results_dataset = {}
    if STORE_REF_QUE_IN_DICT:
        results_dataset['original_dataset'] = dataset
    results_dataset['dataset_path'] = dset_path
    num_feat_to_modify = n_corrupted

    print('start evaluating imputers')

    ## Background divergence
    if run_background and 'background' in methods_to_process:
        print('Running background divergence evaluation')
        metrics_dict = evaluation_report(reference, query, n_corrupted, query_original=query_original)
        results_dataset = update_and_save_results_dataset_log(
            'background0_original', results_dataset, output_path, None, 0, metrics_dict
        )
        metrics_dict = evaluation_report(
            reference[:, num_feat_to_modify:], query[:, num_feat_to_modify:], 0
        )
        # Key spelling ('filtererd') kept for compatibility with stored results.
        results_dataset = update_and_save_results_dataset_log(
            'background1_filtererd', results_dataset, output_path, None, 0, metrics_dict
        )
        # Reference vs. reference: even rows against odd rows; the odd half is
        # padded with its first row when it is one row short.
        ref0 = reference[0::2, :]
        ref1 = reference[1::2, :]
        if ref1.shape[0] < ref0.shape[0]:
            ref1 = np.concatenate([ref1, ref1[0, :][np.newaxis, :]])
        metrics_dict = evaluation_report(ref0, ref1, n_corrupted)
        results_dataset = update_and_save_results_dataset_log(
            'background2_ref-ref', results_dataset, output_path, None, 0, metrics_dict
        )
        del ref0, ref1

    ## Domain adaptation
    if run_domain_adaptation:
        da_configs = []
        if 'INB' in methods_to_process:
            da_configs += [(f'INB{n}', inb_wrapper, {'n_layers': n}) for n in [50, 100, 200]]
        if 'DD' in methods_to_process:
            da_configs += [(f'DD{n}', dd_wrapper, {'n_canonical_destructors': n}) for n in [10, 50]]
        for clf_name, wrapper, kwargs in da_configs:
            print(f'Evaluating Domain adaptation : {clf_name}  ----------------- ')
            tic = time.time()
            query_imputed = wrapper(reference.copy(), query_missing.copy(), **kwargs)
            toc = time.time()
            results_dataset = _evaluate_and_log(
                f'{clf_name}_dom_adap', reference, query_imputed, n_corrupted, query_original,
                results_dataset, output_path, toc - tic
            )

    ## Supervised Imputer
    if run_supervised:
        if 'Regressor' in methods_to_process:
            for clf in [LinearRegression(), MLPRegressor()]:
                clf_name = type(clf).__name__
                print(f'Evaluating Supervised : {clf_name}  ----------------- ')
                tic = time.time()
                imputer = SupervisedImputer(clf, num_feat_to_modify)
                imputer.fit(reference.copy())
                query_imputed = imputer.transform(query_missing.copy())
                toc = time.time()
                results_dataset = _evaluate_and_log(
                    f'{clf_name}_supervised', reference, query_imputed, n_corrupted, query_original,
                    results_dataset, output_path, toc - tic
                )

        ## KNN
        if 'nn' in methods_to_process:
            print('Evaluating knn regression  ----------------- ')
            for k in [10]:
                print(f'Eval of k={k}')
                tic = time.time()
                imputer = KNNImputer(n_neighbors=k)
                imputer.fit(reference.copy())
                query_imputed = imputer.transform(query_missing.copy())
                toc = time.time()
                results_dataset = _evaluate_and_log(
                    f'{k}nn', reference, query_imputed, n_corrupted, query_original,
                    results_dataset, output_path, toc - tic
                )

    ## Adversarial (OURS)
    if run_ours and 'adversarial' in methods_to_process:
        tic = time.time()
        # The adversarial imputer starts from zeros in the masked entries.
        query_zeros = query_missing.copy()
        query_zeros[np.isnan(query_zeros)] = 0

        base_classifier = _make_base_classifier(base_classifier_name)
        adv_imputer = AdversarialIterativeBatchWrapper(
            None, base_classifier, reference.copy(), query_zeros, num_feat_to_modify, max_dims=max_dims
        )
        del query_zeros

        query_imputed = adv_imputer.fit_transform()
        toc = time.time()

        metrics_dict = evaluation_report(reference, query_imputed, n_corrupted, query_original=query_original)
        print(metrics_dict)
        results_dataset = update_and_save_results_dataset_log(
            f'adversarial_v1_{base_classifier_name}', results_dataset, output_path, query_imputed, toc - tic, metrics_dict
        )

    ## Hyper Imputers library
    if run_hyperimpute:
        imputers_list = ['mean', 'median', 'most_frequent', 'hyperimpute', 'gain', 'miracle',
                         'ice', 'softimpute', 'missforest', 'sinkhorn']
        requested = [imp for imp in imputers_list if imp in methods_to_process]
        if requested:
            # Build the plugin registry once, and only when it is needed.
            imputers = Imputers()
        for imp in requested:
            print(f'running imputer ... {imp}')
            tic = time.time()
            # Impute on reference + query stacked so the plugin can learn from
            # the complete reference rows.
            data_x = np.concatenate([np.double(reference), np.double(query_missing)])
            imputer = imputers.get(imp)
            imputed_data = imputer.fit_transform(data_x)
            print(imputed_data.shape)
            query_imputed = np.double(np.array(imputed_data))[reference.shape[0]:, :]
            # Only the masked leading columns are taken from the imputer.
            query_imputed[:, num_feat_to_modify:] = np.double(query)[:, num_feat_to_modify:]
            toc = time.time()
            results_dataset = _evaluate_and_log(
                f'{imp}_hi', reference, query_imputed, n_corrupted, query_original,
                results_dataset, output_path, toc - tic
            )


# Entry point: benchmark each dataset named in the JSON config, in order.
root_dir = root_datasets
for dataset_file in datasets_to_process:
    run_benchmark_in_dataset(
        root_dir,
        dataset_file,
        output_dir,
        methods_to_process,
        base_classifier_name,
    )
