#!/usr/bin/env python
# coding: utf-8
import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
#########################################################################################

import torch
import numpy as np
import collections

from extract_metafeatures import extract
from dataloader_train import DatasetOpenML

class DatasetOpenML_variable_recommend(DatasetOpenML):
    """Dataset of random sub-datasets scored against many configurations.

    Each ``__getitem__`` call resamples every requested dataset to a common,
    randomly drawn shape and evaluates the first ``NB_CONFIGURATIONS`` entries
    of ``self.list_configurations`` on it.

    Besides what is defined here, the class reads ``list_index``, ``list_X``,
    ``list_y``, ``list_cols_metafeatures``, ``list_configurations``,
    ``objective`` and ``mm_scaler`` from ``self``; these are presumably set up
    by ``DatasetOpenML`` (not visible from this file).
    """

    # NOTE(review): this list lives on the class and is mutated in place by
    # __getitem__, so the pool is shared by every instance of this class --
    # confirm that sharing is intended.
    overall_hc_mf = []

    # (low, high) bounds handed to np.random.randint; ``high`` is exclusive.
    NB_COL_RANGE = (3, 11)
    NB_POINTS_RANGE = (700, 900)
    # Number of leading entries of self.list_configurations scored per dataset.
    NB_CONFIGURATIONS = 1000
    # The scaler is refit on every call until the pool holds this many rows.
    MAX_SCALER_POOL = 5000

    def __getitem__(self, list_idx):
        """Build one batch of resampled datasets with per-configuration losses.

        Args:
            list_idx: iterable of positions into ``self.list_index``; one
                sub-dataset is produced per position.

        Returns:
            A 5-tuple of ``torch.FloatTensor``: the sampled feature matrices,
            the sampled targets, the configuration indices
            ``0 .. NB_CONFIGURATIONS - 1`` (repeated for every sub-dataset),
            the loss of each of those configurations on each sub-dataset, and
            the meta-features after ``self.mm_scaler.transform``.
        """
        # Drawn once per call so that all sub-datasets of the batch share the
        # same shape and can be packed into a single tensor.
        nb_col = np.random.randint(
            low=self.NB_COL_RANGE[0], high=self.NB_COL_RANGE[1])
        nb_points = np.random.randint(
            low=self.NB_POINTS_RANGE[0], high=self.NB_POINTS_RANGE[1])
        nb_configurations = self.NB_CONFIGURATIONS

        list_X, list_y, list_params, list_loss, list_mf_hc = [], [], [], [], []
        for idx in list_idx:
            index = self.list_index[idx]
            X, y = self.list_X[index], self.list_y[index]
            nb_rows, nb_features = X.shape

            # Weight each row by 1 / (size of its class * number of classes):
            # every class then carries the same total probability mass.
            count_class = collections.Counter(y)
            nb_classes = len(count_class)
            prob = [1 / (count_class[v] * nb_classes) for v in y]

            # Rows and columns are both drawn with replacement, so duplicates
            # are possible.
            selected_row = np.random.choice(nb_rows, nb_points, p=prob)
            selected_col = np.random.choice(nb_features, nb_col, replace=True)

            select_X = X[selected_row[:, None], selected_col]
            select_y = y[selected_row]

            list_X.append(select_X)
            list_y.append(select_y)
            list_mf_hc.append(
                extract(X=select_X, y=select_y,
                        list_columns=self.list_cols_metafeatures))

            # The "parameters" returned here are the configuration indices.
            list_params.append(list(range(nb_configurations)))

            list_loss.append([
                self.objective(
                    self.list_configurations[i].get_dictionary(),
                    select_X, select_y)['loss']
                for i in range(nb_configurations)
            ])

        # Keep feeding the scaler until enough meta-feature rows were pooled;
        # afterwards it is only applied.
        if len(self.overall_hc_mf) < self.MAX_SCALER_POOL:
            self.overall_hc_mf.extend(list_mf_hc)
            self.mm_scaler.fit(self.overall_hc_mf)

        list_mf_hc = self.mm_scaler.transform(list_mf_hc)

        # Pack each list into one float32 ndarray first: handing a list of
        # ndarrays straight to the tensor constructor converts it element by
        # element, which is very slow.
        return tuple(
            torch.FloatTensor(np.asarray(values, dtype=np.float32))
            for values in (list_X, list_y, list_params, list_loss, list_mf_hc)
        )



class DatasetOpenML_variable(DatasetOpenML):
    """Dataset of random sub-datasets, each scored against one configuration.

    Each ``__getitem__`` call resamples every requested dataset to a common,
    randomly drawn shape and evaluates a single configuration on it.

    Besides what is defined here, the class reads ``list_index``, ``list_X``,
    ``list_y``, ``list_cols_metafeatures``, ``list_configurations``,
    ``list_encoded_configurations``, ``objective`` and ``mm_scaler`` from
    ``self``; these are presumably set up by ``DatasetOpenML`` (not visible
    from this file).
    """

    # NOTE(review): this list lives on the class and is mutated in place by
    # __getitem__, so the pool is shared by every instance of this class --
    # confirm that sharing is intended.
    overall_hc_mf = []

    # (low, high) bounds handed to np.random.randint; ``high`` is exclusive.
    NB_COL_RANGE = (3, 11)
    NB_POINTS_RANGE = (700, 900)
    # The scaler is refit on every call until the pool holds this many rows.
    MAX_SCALER_POOL = 5000

    def __getitem__(self, list_idx):
        """Build one batch of resampled datasets with one loss per dataset.

        Args:
            list_idx: iterable of indices. Each ``idx`` selects the dataset
                through ``self.list_index[idx]`` and also selects the
                configuration through ``self.list_configurations[idx]`` and
                ``self.list_encoded_configurations[idx]``.

        Returns:
            A 5-tuple of ``torch.FloatTensor``: the sampled feature matrices,
            the sampled targets, the encoded configurations, the loss of each
            configuration on its sub-dataset, and the meta-features after
            ``self.mm_scaler.transform``.
        """
        # Drawn once per call so that all sub-datasets of the batch share the
        # same shape and can be packed into a single tensor.
        nb_col = np.random.randint(
            low=self.NB_COL_RANGE[0], high=self.NB_COL_RANGE[1])
        nb_points = np.random.randint(
            low=self.NB_POINTS_RANGE[0], high=self.NB_POINTS_RANGE[1])

        list_X, list_y, list_params, list_loss, list_mf_hc = [], [], [], [], []
        for idx in list_idx:
            index = self.list_index[idx]
            X, y = self.list_X[index], self.list_y[index]
            nb_rows, nb_features = X.shape

            # Weight each row by 1 / (size of its class * number of classes):
            # every class then carries the same total probability mass.
            count_class = collections.Counter(y)
            nb_classes = len(count_class)
            prob = [1 / (count_class[v] * nb_classes) for v in y]

            # Rows and columns are both drawn with replacement, so duplicates
            # are possible.
            selected_row = np.random.choice(nb_rows, nb_points, p=prob)
            selected_col = np.random.choice(nb_features, nb_col, replace=True)

            select_X = X[selected_row[:, None], selected_col]
            select_y = y[selected_row]

            list_X.append(select_X)
            list_y.append(select_y)
            list_mf_hc.append(
                extract(X=select_X, y=select_y,
                        list_columns=self.list_cols_metafeatures))

            list_params.append(self.list_encoded_configurations[idx])

            configuration = self.list_configurations[idx].get_dictionary()
            list_loss.append(
                self.objective(configuration, select_X, select_y)['loss'])

        # Keep feeding the scaler until enough meta-feature rows were pooled;
        # afterwards it is only applied.
        if len(self.overall_hc_mf) < self.MAX_SCALER_POOL:
            self.overall_hc_mf.extend(list_mf_hc)
            self.mm_scaler.fit(self.overall_hc_mf)

        list_mf_hc = self.mm_scaler.transform(list_mf_hc)

        # Pack each list into one float32 ndarray first: handing a list of
        # ndarrays straight to the tensor constructor converts it element by
        # element, which is very slow.
        return tuple(
            torch.FloatTensor(np.asarray(values, dtype=np.float32))
            for values in (list_X, list_y, list_params, list_loss, list_mf_hc)
        )
