#!/usr/bin/env python
# coding: utf-8
import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/dida")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/")
#########################################################################################

import glob

import copy
import torch
import openml
import hparams
import collections
import numpy as np
import pandas as pd

import global_variables
from dida_network import DIDA
from tqdm.notebook import tqdm
from collections import OrderedDict

from extract_metafeatures import extract
from torch.utils.data import Dataset, DataLoader, sampler

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder


from hyperopt import hp
### HyperOpt Parameter Tuning
from hyperopt import tpe
from hyperopt import STATUS_OK
from hyperopt import Trials
from hyperopt import hp
from hyperopt import fmin

def warn(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``.

    Meant as a silent drop-in for ``warnings.warn``.
    """
    return None
import warnings
warnings.warn = warn

from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression

import hyperopt.pyll.stochastic
from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import StratifiedKFold




available_dataset = [40966, 40982, 40994, 40983, 40927, 40979, 4538, 40978, 1485, 4134, 1478, 1501, 6332, 4534, 1497, 1480, 1475, 1468, 1486, 40923, 40996, 41027, 40975, 40984, 1590, 40499, 23381, 1494, 1487, 1461, 1489, 40670, 40668, 1510, 40701, 23517, 1464, 300, 307, 11, 12, 6, 14, 15, 16, 18, 3, 188, 182, 23, 32, 28, 29, 22, 54, 151, 44, 46, 37, 38, 50, 1068, 31, 1462, 458, 1053, 1049, 1050, 554, 469, 1067, 1063]
list_files = ['1475_0', '4135_9', '1508_0', '1465_0', '793_0', '1557_1', '42016_0', '1540_0', '796_1', '41945_0', '974_0', '1015_0', '768_0', '792_0', '1538_0', '300_0', '1115_4', '811_1', '679_0', '740_0', '930_1', '467_0', '995_0', '893_0', '465_3', '882_0', '1491_0', '1006_15', '42046_0', '46_60', '38_22', '1510_0', '314_22', '311_0', '727_0', '758_0', '40709_2', '943_0', '40669_6', '1049_0', '346_0', '1524_0', '1413_0', '726_0', '1542_0', '1530_0', '749_0', '1061_0', '803_0', '919_0', '185_1', '40685_0', '181_0', '1463_5', '1490_0', '761_0', '41674_2', '878_0', '5_73', '807_0', '1513_0', '161_0', '40711_7', '54_0', '40704_0', '62_15', '40690_9', '874_0', '41679_2', '855_0', '1497_0', '935_0', '903_0', '41156_0', '41146_0', '60_0', '450_1', '459_0', '40498_0', '841_0', '42041_0', '734_0', '2_32', '61_0', '784_1', '1493_0', '942_1', '1067_0', '815_0', '1554_4', '151_1', '1488_0', '1063_0', '723_0', '959_8', '1046_0', '40650_20', '15_0', '808_0', '40496_0', '1512_0', '991_6', '1055_1', '988_1', '336_0', '756_0', '41946_0', '1528_0', '921_1', '719_4', '4534_30', '375_0', '1482_0', '35_33', '40678_7', '979_0', '24_22', '1121_3', '1444_0', '973_0', '1026_6', '1467_0', '1120_0', '795_0', '951_1', '934_4', '880_0', '949_1', '1167_1', '1056_0', '1527_0', '179_12', '31_13', '28_0', '37_0', '1073_0', '53_0', '889_0', '444_3', '40707_23', '860_0', '724_1', '184_6', '188_5', '987_2', '938_1', '762_0', '871_0', '1002_39', '21_6', '912_0', '917_0', '448_3', '1044_3', '1449_0', '774_0', '23499_9', '966_1', '1505_1', '267_0', '962_0', '783_0', '41228_0', '40708_23', '255_7', '827_0', '767_2', '964_1', '847_0', '843_0', '180_40', '40985_0', '40984_0', '750_0', '1556_5', '357_0', '1009_4', '40713_23', '1503_0', '34_8', '57_22', '1480_1', '472_0', '752_0', '40647_20', '780_0', '40648_20', '894_0', '342_3', '731_0', '785_0', '40994_0', '865_1', '4329_13', '40682_0', '1529_0', '1489_0', '1541_0', '12_0', '42036_0', '1568_8', '265_0', '925_0', '838_0', '162_0', '1075_0', '813_0', 
'772_0', '40701_4', '1455_5', '351_0', '335_0', '818_1', '1495_6', '829_0', '1450_0', '40981_8', '1560_0', '730_0', '40705_2', '1453_0', '862_2', '29_9', '1211_0', '42051_0', '1492_0', '4154_0', '1446_0', '1100_4', '833_0', '799_0', '41684_2', '751_0', '826_11', '49_7', '40660_11', '790_1', '40476_20', '42066_0', '733_0', '1555_3', '1483_2', '40686_12', '923_1', '464_0', '18_0', '897_2', '1441_0', '1000_22', '895_0', '763_0', '480_8', '40922_0', '1466_0', '41027_0', '754_0', '1543_0', '457_0', '285_3', '1219_0', '755_0', '40681_6', '1494_0', '835_4', '461_3', '846_0', '1048_0', '1499_0', '14_0', '41005_26', '1059_0', '1013_1', '40663_20', '955_2', '777_0', '26_8', '745_1', '42186_0', '717_0', '804_4', '479_3', '153_0', '928_1', '42071_0', '1459_0', '929_0', '1506_13', '822_0', '836_3', '469_4', '725_0', '735_0', '42021_0', '310_0', '1498_1', '1511_0', '1443_0', '1539_0', '775_0', '714_2', '187_0', '1552_4', '41007_26', '554_0', '13_9', '42_35', '41997_0', '23_7', '728_0', '927_0', '782_0', '764_2', '1016_2', '800_0', '1053_0', '729_0', '1597_0', '1487_0', '744_0', '40700_1', '41680_2', '1068_0', '950_1', '881_3', '6_0', '879_0', '1062_0', '916_0', '40499_0', '997_0', '1218_0', '1547_0', '776_0', '4538_0', '823_0', '869_0', '20_240', '868_0', '42098_0', '1523_0', '892_0', '884_0', '1018_41', '1532_0', '924_0', '40714_5', '468_0', '1600_0', '476_0', '1500_0', '40649_20', '4_8', '42056_0', '778_0', '913_0', '1544_0', '773_0', '1451_0', '1050_0', '1565_0', '845_0', '1526_0', '43_1', '55_13', '40668_42', '1504_0', '1460_0', '40475_20', '44_0', '41004_26', '1041_0', '339_1', '683_0', '40671_0', '1005_0', '820_0', '152_0', '340_3', '40497_0', '887_1', '1025_5', '1549_3', '891_2', '941_7', '3_36', '40999_26', '59_0', '36_0', '337_0', '307_2', '901_0', '40683_8', '42011_0', '896_0', '830_0', '870_0', '721_0', '787_0', '22_0', '753_0', '857_0', '715_0', '16_0', '789_0', '817_0', '1551_3', '40982_0', '1473_0', '1536_0', '40478_20', '41682_2', '1501_0', '30_0', '771_2', 
'182_0', '1060_0', '354_0', '42193_1', '41950_0', '40687_12', '812_0', '1014_4', '39_0', '40677_24', '1021_0', '41511_0', '748_2', '137_9', '926_0', '40710_8', '1496_0', '791_0', '1471_0', '1412_0', '259_0', '463_26', '1509_0', '1180_0', '1118_6', '1525_0', '1448_0', '952_0', '993_27', '933_0', '48_2', '902_4', '40693_9', '947_1', '40900_0', '908_0', '56_16', '41568_0', '41583_0', '1071_0', '42192_1', '40975_6', '40477_20', '746_0', '824_0', '41675_2', '1464_0', '23512_0', '1553_4', '41998_2', '40474_20', '40983_0', '1507_0', '890_1', '333_0', '848_1', '1116_1', '1012_26', '965_15', '40646_20', '694_0', '1570_0', '743_0', '41521_2', '42003_0', '910_0', '1011_0', '969_0', '41538_3', '936_0', '40691_0', '329_0', '900_0', '1217_0', '914_0', '40706_10', '819_0', '859_0', '41000_26', '251_0', '1117_0', '338_6', '1534_0', '119_9', '345_44', '1054_0', '42031_0', '42026_0', '915_3', '905_0', '9_10', '909_0', '931_0', '40664_0', '885_0', '682_0', '1559_0', '911_0', '1040_0', '759_0', '1533_0', '853_1', '736_0', '1442_0', '40998_26', '1567_0', '1535_0', '1545_0', '906_0', '1220_0', '40997_26', '886_0', '1019_0', '164_0', '976_0', '770_0', '816_0', '983_7', '958_0', '1065_0', '334_0', '1216_0', '737_0', '832_0', '40702_10', '40971_0', '446_0', '1537_0', '747_4', '1064_0', '825_3', '51_7', '1452_0', '41685_2', '41919_0', '40680_10', '475_0', '821_0', '23517_0', '849_0', '1069_0', '1481_3', '1462_0', '907_0', '41671_0', '1546_0', '875_3', '1036_0', '685_0', '462_1', '41_0', '50_9', '794_0', '11_0', '32_0', '10_15', '945_3', '996_0', '1461_9', '41544_2', '801_0', '814_0', '765_2', '40_0', '864_1', '741_1', '946_0', '977_0', '1558_9', '779_0', '1502_0', '863_0', '994_0', '1447_0', '722_0', '720_1', '713_0', '1531_0', '867_2']

# For every id in `available_dataset`, record the position in `list_files` of
# the FIRST entry whose text before the first "_" equals that id.  Ids that
# match no entry are silently skipped.
id_to_use = []

for dt in available_dataset:
    first_match = next(
        (pos for pos, fname in enumerate(list_files) if int(fname.split("_")[0]) == dt),
        None,
    )
    if first_match is None:
        continue
    id_to_use.append(first_match)



def get_list_dataset(test_size, seed):
    """Load the pre-cleaned binary-classification datasets and split them.

    For every position in the module-level ``id_to_use`` the files
    ``<name>.x.npy``, ``<name>.y.npy`` and ``<name>.feat.npy`` are read from
    the clean-data directory (``<name>`` being the matching ``list_files``
    entry).  Columns whose ``feat`` flag equals 1 are one-hot encoded; all
    other columns are passed through unchanged.  Only datasets whose label
    vector has exactly two distinct values are kept.

    Args:
        test_size: forwarded to ``train_test_split``.
        seed: ``random_state`` forwarded to ``train_test_split``.

    Returns:
        ``(list_X, list_y, train_idx, test_idx)`` where the two index lists
        are positions into ``list_X`` / ``list_y``.
    """
    data_dir = "/linkhome/rech/genini01/uvp29is/Code/metanal/datasets/openml/clean_data/"

    list_X = []
    list_y = []

    for id_dt in id_to_use:
        prefix = data_dir + list_files[id_dt]

        # Check the labels first: non-binary datasets are discarded anyway, so
        # there is no point loading and encoding their features.
        y = np.load(prefix + ".y.npy")
        if np.unique(y).shape[0] != 2:
            continue

        X_init = np.load(prefix + ".x.npy")
        feat = np.load(prefix + ".feat.npy")

        categorical_cols = [i for i, col_type in enumerate(feat) if col_type == 1]
        ct = ColumnTransformer(
            [("categorical", OneHotEncoder(handle_unknown='ignore'), categorical_cols)],
            remainder="passthrough",
        )

        list_X.append(ct.fit_transform(X_init))
        list_y.append(y)

    train_idx, test_idx = train_test_split(list(range(len(list_X))), random_state=seed, test_size=test_size)
    return list_X, list_y, train_idx, test_idx

def check_nb_category(num_i, feat_name, dataset):
    """Tell whether a column has at most 10 entries in its ``value_counts()``.

    Args:
        num_i: position of the column name inside ``feat_name``.
        feat_name: sequence of column names.
        dataset: object indexable by column name whose columns expose
            ``value_counts()`` (e.g. a pandas DataFrame).

    Returns:
        ``True`` when ``value_counts()`` of the selected column has 10 rows
        or fewer, ``False`` otherwise.
    """
    column = dataset[feat_name[num_i]]
    n_levels = len(column.value_counts())
    return n_levels <= 10

def get_list_dataset_cc18(test_size, seed):
    """Download a fixed list of OpenML datasets, preprocess them and split them.

    For each dataset id the data is fetched through ``openml``, the default
    target attribute is label-encoded into ``y`` and removed from the
    features.  Missing values are then imputed and the features encoded:

    * a column with no non-null value at all is replaced by the constant 0;
    * a float64/int64 column with NaNs is filled with its median, any other
      column with NaNs is filled with its most frequent value;
    * categorical columns accepted by ``check_nb_category`` (at most 10
      ``value_counts()`` entries) are one-hot encoded;
    * non-categorical columns are min-max scaled;
    * every other column (categoricals rejected by ``check_nb_category``)
      is dropped.

    Args:
        test_size: forwarded to ``train_test_split``.
        seed: ``random_state`` forwarded to ``train_test_split``.

    Returns:
        ``(list_X, list_y, train_idx, test_idx)`` where the two index lists
        are positions into ``list_X`` / ``list_y``.
    """
    list_X = []
    list_y = []

    list_dataset = [3, 6, 11, 12, 14, 15, 16, 18, 22, 23, 28, 29, 31, 32, 37, 44, 46, 50, 54, 151, 182, 188, 38, 307, 300, 458, 469, 554, 1049, 1050, 1053, 1063, 1067, 1068, 1590, 4134, 1510, 1489, 1494, 1497, 1501, 1480, 1485, 1486, 1487, 1468, 1475, 1462, 1464, 4534, 6332, 1461, 4538, 1478, 23381, 40499, 40668, 40966, 40982, 40994, 40983, 40975, 40984, 40979, 40996, 41027, 23517, 40923, 40978, 40670, 40701]

    for id_dt in tqdm(list_dataset):
        d = openml.datasets.get_dataset(id_dt)
        X_init, _, feat_cat, feat_name = d.get_data()
        target = d.default_target_attribute

        # Remove the target's flag from a *copy* of the categorical mask so it
        # lines up position-for-position with the feature columns below.
        feat_cat = list(feat_cat)
        del feat_cat[feat_name.index(target)]

        y = preprocessing.LabelEncoder().fit_transform(X_init[target])

        # Non in-place drop: leaves the frame returned by openml untouched.
        X_init = X_init.drop([target], axis=1)

        # Column names AFTER dropping the target.  Indexing the original
        # ``feat_name`` with positions of the shortened ``feat_cat`` would be
        # shifted by one (or hit the dropped target) whenever the target is
        # not the last attribute.
        col_names = list(X_init.columns)

        # ``.items()`` instead of ``.iteritems()``, which pandas 2.0 removed.
        for col, has_nan in X_init.isna().any().items():
            if X_init[col].nunique() == 0:
                # Column holds no usable value at all.
                X_init[col] = 0
            elif has_nan:
                if not (X_init[col].dtype == np.float64 or X_init[col].dtype == np.int64):
                    # Most frequent value (value_counts is sorted by count).
                    X_init[col] = X_init[col].fillna(X_init[col].value_counts().index[0])
                else:
                    X_init[col] = X_init[col].fillna(X_init[col].median())

        categorical_cols = [i for i, col_type in enumerate(feat_cat)
                            if (col_type and check_nb_category(i, col_names, X_init))]
        numeric_cols = [i for i, col_type in enumerate(feat_cat) if not col_type]

        ct = ColumnTransformer([("categorical", OneHotEncoder(handle_unknown='ignore', sparse=False),
                                 categorical_cols),
                                ("nominal", preprocessing.MinMaxScaler(),
                                 numeric_cols)], remainder="drop")

        X = ct.fit_transform(X_init)

        # All datasets are kept, whatever their number of classes.
        list_X.append(X)
        list_y.append(y)
    train_idx, test_idx = train_test_split(list(range(len(list_X))), random_state=seed, test_size=test_size)
    return list_X, list_y, train_idx, test_idx


class DatasetOpenML(Dataset):
    """Batch-indexed dataset pairing sub-sampled tables with sampled hyperparameter configurations.

    Each of the ``size_training`` items is a randomly chosen entry of
    ``list_X`` / ``list_y`` associated with one configuration drawn from
    ``config_space``.  ``__getitem__`` takes a *list* of item positions and
    returns a whole batch.
    """

    def __init__(self, list_X, list_y, configs, config_space, size_training, dida_model=None, handcrafted_mf=None):
        """Draw the item-to-dataset assignment and the hyperparameter configurations.

        Args:
            list_X: feature matrices, one per dataset.
            list_y: label vectors aligned with ``list_X``.
            configs: object exposing ``training.classifier.name``,
                ``training.seed`` and ``training.dataloader.npoints``.
            config_space: object exposing ``get_hyperparameters()`` and
                ``sample_configuration()``.
            size_training: number of items (dataset draws and configurations).
            dida_model: stored as-is on the instance.
            handcrafted_mf: table of hand-crafted meta-features; it is used to
                fit the min-max scaler and its ``columns`` are kept, so
                despite the ``None`` default it must be provided.
        """
        self.list_X = list_X
        self.list_y = list_y
        self.configs = configs
        self.dida_model = dida_model

        training_cfg = configs.training
        self.classifier = training_cfg.classifier.name
        self.seed = training_cfg.seed
        self.npoints = training_cfg.dataloader.npoints
        self.choice_solver = ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga']

        # Item i of this dataset is backed by dataset number list_index[i].
        self.list_index = np.random.randint(len(self.list_X), size=size_training)
        self.config_space = config_space

        self.mm_scaler = preprocessing.MinMaxScaler()
        self.mm_scaler.fit(handcrafted_mf)
        self.list_cols_metafeatures = handcrafted_mf.columns

        # Hyperparameters offering more than two choices get one-hot encoded;
        # everything else is passed through untouched.
        multi_choice_positions = []
        for position, hyperparameter in enumerate(self.config_space.get_hyperparameters()):
            if hasattr(hyperparameter, "choices") and len(hyperparameter.choices) > 2:
                multi_choice_positions.append(position)

        one_hot = OneHotEncoder(handle_unknown='ignore')
        self.encode_hp = ColumnTransformer(
            [("categorical", one_hot, multi_choice_positions)],
            remainder="passthrough",
        )

        self.list_configurations = [self.config_space.sample_configuration() for _ in range(size_training)]
        raw_vectors = [cfg.get_array() for cfg in self.list_configurations]
        # NaNs in the configuration vectors are replaced by 0 before encoding.
        self.list_encoded_configurations = self.encode_hp.fit_transform(np.nan_to_num(raw_vectors))

    def __len__(self):
        """Number of items, i.e. ``size_training``."""
        return len(self.list_index)

    def convert_hp_dict_to_array(self, config):
        """Encode one configuration with the fitted hyperparameter encoder.

        ``config`` must expose ``get_array()``; NaNs are replaced by 0 first.
        Returns the single encoded row.
        """
        as_batch = np.nan_to_num([config.get_array()])
        return self.encode_hp.transform(as_batch)[0]

    @ignore_warnings(category=ConvergenceWarning)
    def objective(self, params, X, y, n_folds = 3):
        """Cross-validated loss of one hyperparameter setting.

        The model returned by ``hparams.init_model`` is scored with shuffled
        stratified ``n_folds``-fold cross-validation using balanced accuracy,
        all warnings being silenced during scoring.

        Returns:
            ``{'loss': value}`` where ``value`` is one minus the mean fold
            score, or 1 when that quantity is NaN.
        """
        folds = StratifiedKFold(random_state=self.seed, n_splits=n_folds, shuffle=True)
        model = hparams.init_model(self.classifier, params, self.seed)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fold_scores = cross_val_score(model, X, y, cv=folds, scoring='balanced_accuracy')

        # The score is to be maximised, the loss minimised.
        loss = 1 - np.mean(fold_scores)
        if np.isnan(loss):
            return {'loss': 1}
        return {'loss': loss}

    @ignore_warnings(category=ConvergenceWarning)
    def objective_cv(self, params, X, y, n_folds = 3):
        """Per-fold balanced-accuracy scores of one hyperparameter setting.

        Same cross-validation set-up as ``objective`` but the individual fold
        scores are returned, with NaNs replaced by 0.
        """
        folds = StratifiedKFold(random_state=self.seed, n_splits=n_folds, shuffle=True)
        model = hparams.init_model(self.classifier, params, self.seed)
        fold_scores = cross_val_score(model, X, y, cv=folds, scoring='balanced_accuracy')
        return np.nan_to_num(fold_scores)

    def set_surrogate_model(self, surrogate, x_surrogate, y_surrogate, list_X, list_y, list_dida_mf):
        """Attach a surrogate model and its training data to the instance.

        Only ``surrogate``, ``x_surrogate`` and ``y_surrogate`` are stored;
        ``list_X``, ``list_y`` and ``list_dida_mf`` are accepted but ignored.
        """
        self.surrogate = surrogate
        self.x_surrogate = x_surrogate
        self.y_surrogate = y_surrogate

    def __getitem__(self, list_idx):
        """Build one batch from a list of item positions.

        For each position, ``npoints`` rows are drawn (``np.random.choice``
        default, i.e. with replacement) with weights that give every class
        the same total probability, and a random subset of columns is drawn
        without replacement.  The number of columns is the smallest column
        count among the datasets of the batch, so all samples share a shape.

        Returns:
            A 5-tuple of ``torch.FloatTensor``: sampled features, sampled
            labels, encoded configurations, cross-validation losses, and
            min-max scaled hand-crafted meta-features.
        """
        n_shared_cols = min(self.list_X[self.list_index[pos]].shape[1] for pos in list_idx)

        batch_X, batch_y, batch_params, batch_loss, batch_mf = [], [], [], [], []
        for pos in list_idx:
            dataset_id = self.list_index[pos]
            X = self.list_X[dataset_id]
            y = self.list_y[dataset_id]
            n_rows, n_features = X.shape

            # Weight of a row = 1 / (size of its class * number of classes).
            class_counts = collections.Counter(y)
            n_classes = len(class_counts)
            row_weights = [1/(class_counts[label] * n_classes) for label in y]

            rows = np.random.choice(n_rows, self.npoints, p=row_weights)
            cols = np.random.choice(n_features, n_shared_cols, replace=False)

            sub_X = X[rows[:, None], cols]
            sub_y = y[rows]
            batch_X.append(sub_X)
            batch_y.append(sub_y)

            batch_mf.append(extract(X=sub_X, y=sub_y, list_columns=self.list_cols_metafeatures))

            # Configuration (and its encoding) is tied to the item position.
            batch_params.append(self.list_encoded_configurations[pos])

            hp_values = self.list_configurations[pos].get_dictionary()
            batch_loss.append(self.objective(hp_values, sub_X, sub_y)['loss'])

        # NOTE(review): fit_transform re-fits the scaler on this batch alone,
        # discarding the fit done on handcrafted_mf in __init__ -- confirm
        # whether a plain transform() was intended.
        return (
            torch.FloatTensor(batch_X),
            torch.FloatTensor(batch_y),
            torch.FloatTensor(batch_params),
            torch.FloatTensor(batch_loss),
            torch.FloatTensor(self.mm_scaler.fit_transform(batch_mf)),
        )



class DatasetOpenMLCC18(Dataset):
    """Batch-indexed dataset yielding class-balanced sub-samples of tables.

    Each of the ``nb_sample_dataset`` items is a randomly chosen entry of
    ``list_X`` / ``list_y``.  ``__getitem__`` takes a *list* of item
    positions and returns a whole batch.
    """

    def __init__(self, list_X, list_y, npoints, seed, nb_sample_dataset):
        """Store the data and draw which dataset backs each item.

        Args:
            list_X: 2-D feature arrays, one per dataset.
            list_y: 1-D label arrays aligned with ``list_X``.
            npoints: number of rows drawn per sample.
            seed: stored on the instance (not used for the draws below,
                which rely on the global NumPy random state).
            nb_sample_dataset: number of items.
        """
        self.list_X = list_X
        self.list_y = list_y
        self.seed = seed
        self.npoints = npoints

        # Item i of this dataset is backed by dataset number list_index[i].
        self.list_index = np.random.randint(len(self.list_X), size=nb_sample_dataset)

    def __len__(self):
        """Number of items, i.e. ``nb_sample_dataset``."""
        return len(self.list_index)

    def __getitem__(self, list_idx):
        """Build one batch from a list of item positions.

        For each position, ``npoints`` rows are drawn (with replacement)
        using weights that give every class the same total probability, and
        a random subset of columns is drawn without replacement.  The number
        of columns is the smallest column count among the datasets of the
        batch, so all samples share a shape.

        The weights work for any set of labels; for labels {0, 1} they are
        exactly ``1 / (2 * class size)``.

        Returns:
            ``(list_X, list_y)``: two lists with one sampled feature array
            and one sampled label array per requested position.
        """
        min_col = min(self.list_X[self.list_index[i]].shape[1] for i in list_idx)

        list_X, list_y = [], []
        for idx in list_idx:
            index = self.list_index[idx]
            X, y = self.list_X[index], self.list_y[index]

            # Weight of a row = 1 / (size of its class * number of classes),
            # which always sums to 1 whatever the labels are.
            _, inverse, counts = np.unique(y, return_inverse=True, return_counts=True)
            prob = 1.0 / (counts[np.ravel(inverse)] * len(counts))

            selected_row = np.random.choice(X.shape[0], self.npoints, p=prob)
            selected_col = np.random.choice(X.shape[1], min_col, replace=False)

            list_X.append(X[selected_row[:, None], selected_col])
            list_y.append(y[selected_row])

        return list_X, list_y
