#!/usr/bin/env python
# coding: utf-8
import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/")
#########################################################################################

import torch
import copy
import collections
import numpy as np
import global_variables
from scipy.stats import norm

from sklearn import preprocessing
from scipy.stats import ttest_ind
import torch.multiprocessing as mp
from extract_metafeatures import extract
from dataloader_train import DatasetOpenML

def get_metafeatures(model):
    """Run one pass of optimizer updates on ``model`` over ``data_loader``.

    NOTE(review): despite its name this function trains ``model`` in place and
    returns ``None``. ``data_loader``, ``optimizer`` and ``loss_fn`` are not
    defined in this function or anywhere in the visible module, so calling it
    raises ``NameError`` unless they exist as module globals elsewhere.
    TODO: confirm whether this is a leftover stub.
    """
    for batch_inputs, batch_targets in data_loader:
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()


class DatasetOpenML_variable_ranking(DatasetOpenML):
    """Pairwise hyper-parameter ranking dataset built on top of ``DatasetOpenML``.

    Each item is a batch. For every dataset index in the batch:

    * a random sub-sample of rows and columns is drawn;
    * two hyper-parameter configurations are scored on it with 3-fold
      ``objective_cv``;
    * the pair is stored with a label saying which slot holds the
      lower-loss configuration.

    Attributes used here but not defined in this class (presumably provided by
    ``DatasetOpenML`` -- TODO confirm): ``list_index``, ``list_X``, ``list_y``,
    ``list_configurations``, ``list_encoded_configurations``,
    ``list_cols_metafeatures``, ``config_space``, ``mm_scaler``,
    ``objective_cv`` and ``convert_hp_dict_to_array``.
    """

    # NOTE(review): class-level mutable list. ``self.overall_hc_mf.extend`` below
    # mutates this shared object. Every instance (and subclass) in the same
    # process therefore pools its meta-features, and the scaler fit, together.
    # Confirm this is intended.
    overall_hc_mf = []

    @staticmethod
    def _dense_one_hot_encoder():
        """Return a dense-output ``OneHotEncoder`` for old and new scikit-learn."""
        try:
            # scikit-learn >= 1.2 (``sparse`` was removed in 1.4).
            return preprocessing.OneHotEncoder(sparse_output=False)
        except TypeError:
            # Older scikit-learn only knows ``sparse``.
            return preprocessing.OneHotEncoder(sparse=False)

    @staticmethod
    def _to_float_tensor(values):
        """Convert a (possibly nested) sequence of numbers/arrays to a float32 tensor."""
        # Going through a single NumPy array avoids torch's very slow
        # list-of-ndarrays construction path.
        return torch.from_numpy(np.asarray(values, dtype=np.float32))

    def add_scores(self, list_params_1, list_params_2,
                            loss_1, loss_2,
                            encoded_conf_1, encoded_conf_2,
                            list_loss, list_score):
        """Append one ranking example built from two scored configurations.

        Slot placement:

        * The lower-loss configuration (``encoded_conf_1`` on ties) goes in
          slot 1 with label 0 when ``np.random.rand() <= 0.5``.
        * Otherwise it goes in slot 2 with label 1.
        * The label therefore identifies the slot holding the better
          configuration.
        * If either loss is NaN, the comparison is False and ``encoded_conf_2``
          is treated as the better one.

        ``list_score`` receives ``[loss of slot 1, loss of slot 2]``. All four
        lists are mutated in place; nothing is returned.
        """
        if loss_1 <= loss_2:
            better, worse = (encoded_conf_1, loss_1), (encoded_conf_2, loss_2)
        else:
            better, worse = (encoded_conf_2, loss_2), (encoded_conf_1, loss_1)

        # Randomise the slot of the better configuration so labels are balanced.
        if np.random.rand() <= 0.5:
            first, second, label = better, worse, 0
        else:
            first, second, label = worse, better, 1

        list_params_1.append(first[0])
        list_params_2.append(second[0])
        list_loss.append(label)
        list_score.append([first[1], second[1]])

    def get_second_hp(self, select_X, select_y, loss_1, max_tries=20, alpha=0.5):
        """Sample a second configuration whose CV losses differ from ``loss_1``.

        Random configurations are drawn from ``self.config_space`` and scored
        with 3-fold ``objective_cv``. Sampling stops once a two-sample t-test
        between their per-fold losses and ``loss_1`` gives ``p < alpha``.

        Args:
            select_X, select_y: the sub-sampled data to score configurations on.
            loss_1: per-fold losses of the first configuration.
            max_tries: maximum number of configurations to sample.
            alpha: p-value threshold for the pair to count as distinguishable.

        Returns:
            ``(loss_2, conf_2, encoded_conf_2, failed)`` for the last sampled
            configuration. ``failed`` is True when no usable configuration was
            found: either ``max_tries`` samples were exhausted, or a sample
            reached a mean loss of exactly 0. Callers skip the example when it
            is True.

        Raises:
            ValueError: if ``max_tries`` is smaller than 1.
        """
        if max_tries < 1:
            raise ValueError("max_tries must be at least 1, got %r" % (max_tries,))

        for _ in range(max_tries):
            conf_2 = self.config_space.sample_configuration()
            encoded_conf_2 = self.convert_hp_dict_to_array(conf_2)
            loss_2 = self.objective_cv(conf_2.get_dictionary(), select_X, select_y, n_folds=3)
            # A zero-loss configuration is unusable: give up on this example,
            # mirroring the zero-loss rule callers apply to the first configuration.
            if np.mean(loss_2) == 0:
                return loss_2, conf_2, encoded_conf_2, True
            _, p = ttest_ind(loss_1, loss_2)
            # A NaN p-value (e.g. constant losses) compares False; keep sampling.
            if p < alpha:
                return loss_2, conf_2, encoded_conf_2, False
        return loss_2, conf_2, encoded_conf_2, True

    def __getitem__(self, list_idx):
        """Build one batch of ranking examples for the indices in ``list_idx``.

        Examples are dropped when the first configuration has zero mean loss,
        or when ``get_second_hp`` reports failure. The batch can therefore be
        smaller than ``len(list_idx)``, and possibly empty.

        Returns:
            Seven float tensors, in order:

            1. the sub-sampled ``X`` of each kept example;
            2. one-hot labels, zero-padded to the largest class count in the
               batch;
            3. the configurations in slot 1;
            4. the configurations in slot 2;
            5. the ranking labels from ``add_scores``;
            6. the scaled hand-crafted meta-features;
            7. the ``[slot 1, slot 2]`` mean CV losses.
        """
        # Drawn once per batch so every kept example has the same shape.
        nb_col = np.random.randint(low=3, high=11)
        nb_points = np.random.randint(low=700, high=900)

        list_X, list_y, list_params_1, list_params_2, list_loss, list_mf_hc, list_score = [], [], [], [], [], [], []
        for idx in list_idx:
            index = self.list_index[idx]
            X, y = self.list_X[index], self.list_y[index]

            # Class-balanced row sampling: each of the K classes gets total probability 1/K.
            count_class = collections.Counter(y)
            prob = [1 / (count_class[v] * len(count_class)) for v in y]

            selected_row = np.random.choice(X.shape[0], nb_points, p=prob)
            # NOTE(review): columns are drawn with replacement, so duplicates are possible.
            selected_col = np.random.choice(X.shape[1], nb_col, replace=True)

            select_X = X[selected_row[:, None], selected_col]
            select_y = y[selected_row]
            hc_mf = extract(X=select_X, y=select_y, list_columns=self.list_cols_metafeatures)

            # NOTE(review): configurations are looked up by ``idx``, while the data
            # uses ``self.list_index[idx]``. Confirm this pairing is intended.
            conf_1 = self.list_configurations[idx].get_dictionary()
            encoded_conf_1 = self.list_encoded_configurations[idx]

            loss_1 = self.objective_cv(conf_1, select_X, select_y, n_folds=3)
            if np.mean(loss_1) == 0:
                continue

            loss_2, conf_2, encoded_conf_2, failed = self.get_second_hp(select_X, select_y, loss_1)
            if failed:
                continue

            list_X.append(select_X)
            list_y.append(select_y)
            list_mf_hc.append(hc_mf)
            self.add_scores(list_params_1, list_params_2, np.mean(loss_1),
                            np.mean(loss_2), encoded_conf_1, encoded_conf_2,
                            list_loss, list_score)

        # One-hot encode the labels, zero-padding to the largest class count in the batch.
        max_y = max((len(collections.Counter(y)) for y in list_y), default=0)
        new_list = []
        for y in list_y:
            one_hot = self._dense_one_hot_encoder().fit_transform(y.reshape(-1, 1))
            padded = np.zeros((y.shape[0], max_y))
            padded[:, :one_hot.shape[1]] = one_hot
            new_list.append(padded)

        # Keep refitting the scaler on the pooled meta-features until ~5000 rows
        # are collected. Skip this when the batch is empty: fitting or transforming
        # an empty list raises inside scikit-learn.
        if list_mf_hc and len(self.overall_hc_mf) < 5000:
            self.overall_hc_mf.extend(list_mf_hc)
            self.mm_scaler.fit(self.overall_hc_mf)
        if list_mf_hc:
            list_mf_hc = self.mm_scaler.transform(list_mf_hc)

        return (self._to_float_tensor(list_X),
                self._to_float_tensor(new_list),
                self._to_float_tensor(list_params_1),
                self._to_float_tensor(list_params_2),
                self._to_float_tensor(list_loss),
                self._to_float_tensor(list_mf_hc),
                self._to_float_tensor(list_score))



class ImprovedDataset(DatasetOpenML_variable_ranking):
    """Ranking dataset whose first configuration may come from Bayesian optimisation.

    Data sub-samples and meta-features are precomputed. For a training item,
    ``list_X[i]``, ``list_y[i]``, ``list_dida_mf[i]`` and ``list_hc_mf[i]`` are
    iterated in lockstep, one ranking example per entry.
    """

    def __init__(self, list_X, list_y, list_dida_mf, list_hc_mf,
                 config_space, encode_hp,
                 mm_scaler, surrogate,
                 x_surrogate, seed, classifier,
                 list_cols_metafeatures,
                 size_training):
        # NOTE(review): the parent ``__init__`` is not called; every attribute
        # this class uses is set here. ``seed`` is stored but not used in the
        # visible code.
        self.seed = seed
        self.list_X = list_X
        self.list_y = list_y
        self.list_hc_mf = list_hc_mf
        self.list_dida_mf = list_dida_mf
        self.surrogate = surrogate
        self.x_surrogate = x_surrogate
        self.classifier = classifier
        self.encode_hp = encode_hp
        self.mm_scaler = mm_scaler
        self.config_space = config_space
        self.list_cols_metafeatures = list_cols_metafeatures
        # ``size_training`` item indices drawn uniformly, with replacement, from
        # the global NumPy RNG.
        self.list_index = np.random.randint(len(self.list_X), size=size_training)

    def __len__(self):
        return len(self.list_index)

    @staticmethod
    def _to_float_tensor(values):
        """Convert a (possibly nested) sequence of numbers/arrays to a float32 tensor."""
        # Going through a single NumPy array avoids torch's very slow
        # list-of-ndarrays construction path.
        return torch.from_numpy(np.asarray(values, dtype=np.float32))

    def get_std_surrogate(self, X):
        """Return the mean and standard deviation of per-estimator predictions on ``X``.

        Assumes ``self.surrogate`` exposes ``estimators_`` (e.g. a scikit-learn
        ensemble). TODO confirm.
        """
        list_res = [m.predict(X) for m in self.surrogate.estimators_]
        list_res = np.stack(list_res, axis=1)
        return np.mean(list_res, axis=1), np.std(list_res, axis=1)

    def get_first_hp(self, metafeatures, random_prob=0.65, n_candidates=1000):
        """Propose the first configuration of a ranking pair.

        With probability ``random_prob`` a configuration is sampled at random.
        Otherwise, ``n_candidates`` random configurations are scored by expected
        improvement (EI) under the surrogate, conditioned on ``metafeatures``,
        and the highest-EI candidate is returned.

        Returns:
            ``(config_dict, encoded_config)``. ``encoded_config`` is always the
            ``convert_hp_dict_to_array`` representation, which is the same form
            ``get_second_hp`` returns for the other half of the pair.
            ``encode_hp`` is used only to build the surrogate's input.
        """
        if np.random.rand() < random_prob:
            config = self.config_space.sample_configuration()
            return config.get_dictionary(), self.convert_hp_dict_to_array(config)

        list_configurations = [self.config_space.sample_configuration() for _ in range(n_candidates)]
        raw_configurations = [self.convert_hp_dict_to_array(c) for c in list_configurations]
        list_encoded_configurations = self.encode_hp.transform(raw_configurations)
        metafeatures = np.tile(metafeatures, (n_candidates, 1))

        assert metafeatures.shape[0] == list_encoded_configurations.shape[0]
        X_candidate = np.concatenate([metafeatures, list_encoded_configurations], axis=1)

        # Expected improvement for minimisation, adapted from
        # http://krasserm.github.io/2018/03/21/bayesian-optimization/
        mu, sigma = self.get_std_surrogate(X_candidate)
        mu_sample = self.surrogate.predict(self.x_surrogate)

        sigma = sigma.reshape(-1)
        mu_sample_opt = np.min(mu_sample)

        # sigma == 0 produces inf/nan below. Those entries are overwritten with 0
        # right after, so the floating-point warnings are expected and silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            imp = mu_sample_opt - mu
            Z = imp / sigma
            ei = imp * norm.cdf(Z) + sigma * norm.pdf(Z)
        ei[sigma == 0.0] = 0.0

        assert len(ei) == list_encoded_configurations.shape[0]
        selected_configuration = int(np.argmax(ei))
        return (list_configurations[selected_configuration].get_dictionary(),
                raw_configurations[selected_configuration])

    def __getitem__(self, idx):
        """Build the ranking examples for training item ``idx``.

        Entries are dropped when the first configuration has zero mean loss,
        or when ``get_second_hp`` reports failure (its 4th return value is
        True).

        Returns:
            Seven float tensors, in order: X, y, slot-1 configurations, slot-2
            configurations, ranking labels, scaled hand-crafted meta-features,
            and ``[slot 1, slot 2]`` mean losses.
        """
        list_X, list_y, list_params_1, list_params_2, list_loss, list_mf_hc, list_score = [], [], [], [], [], [], []
        data_idx = self.list_index[idx]
        entries = zip(self.list_X[data_idx], self.list_y[data_idx],
                      self.list_dida_mf[data_idx], self.list_hc_mf[data_idx])
        for select_X, select_y, dida_mf, hc_mf in entries:
            conf_1, encoded_conf_1 = self.get_first_hp(dida_mf)
            loss_1 = self.objective_cv(conf_1, select_X, select_y, n_folds=3)
            if np.mean(loss_1) == 0:
                continue

            loss_2, conf_2, encoded_conf_2, failed = self.get_second_hp(select_X, select_y, loss_1)
            if failed:
                continue

            list_X.append(select_X)
            list_y.append(select_y)
            list_mf_hc.append(hc_mf)
            self.add_scores(list_params_1, list_params_2, np.mean(loss_1),
                            np.mean(loss_2), encoded_conf_1, encoded_conf_2,
                            list_loss, list_score)

        # Transforming an empty list raises inside scikit-learn; leave it empty instead.
        if list_mf_hc:
            list_mf_hc = self.mm_scaler.transform(list_mf_hc)

        return (self._to_float_tensor(list_X),
                self._to_float_tensor(list_y),
                self._to_float_tensor(list_params_1),
                self._to_float_tensor(list_params_2),
                self._to_float_tensor(list_loss),
                self._to_float_tensor(list_mf_hc),
                self._to_float_tensor(list_score))
