import sys
sys.path.append("/linkhome/rech/genini01/uvp29is/Code/metanal_v2/configurations")
###################################################################################
import hparams
import pandas as pd
import pickle
from dataloader_train import DatasetOpenML
from d2v_dataloader import D2V
from dataloader_train_variable import DatasetOpenML_variable_recommend, DatasetOpenML_variable
from dataloader_train_variable_ranking import DatasetOpenML_variable_ranking, ImprovedDataset

from torch.utils.data import DataLoader, sampler


def get_dataloader(list_X_train, list_y_train, list_X_test, list_y_test, cfg):
    """Build the train and test loaders backed by ``DatasetOpenML_variable``.

    Args:
        list_X_train: Forwarded as ``list_X`` to the training dataset.
        list_y_train: Forwarded as ``list_y`` to the training dataset.
        list_X_test: Forwarded as ``list_X`` to the test dataset.
        list_y_test: Forwarded as ``list_y`` to the test dataset.
        cfg: Configuration object. Reads ``cfg.training.seed``,
            ``cfg.training.classifier`` and, under ``cfg.training.dataloader``,
            ``nb_sample_dataset``, ``nb_sample_dataset_test``, ``batch_size``,
            ``drop_last`` and ``num_workers``.

    Returns:
        Tuple ``(dataloader_train, dataloader_test)``.
    """
    search_space = hparams.get_search_sapce(cfg.training.seed, cfg.training.classifier)
    # Hand-crafted meta-features table, loaded from a fixed absolute path.
    mf_table = pd.read_csv(
        "/linkhome/rech/genini01/uvp29is/Code/metanal_v2/checkpoints/metafeatures_handcrafted_cc18.csv",
        index_col='index',
    )

    def _make_dataset(list_X, list_y, nb_samples):
        # Both splits share the config, search space and meta-feature table.
        return DatasetOpenML_variable(list_X=list_X, list_y=list_y,
                                      configs=cfg,
                                      size_training=nb_samples,
                                      config_space=search_space,
                                      handcrafted_mf=mf_table)

    # The training dataset is created before the test dataset, as originally.
    train_set = _make_dataset(list_X_train, list_y_train,
                              cfg.training.dataloader.nb_sample_dataset)
    test_set = _make_dataset(list_X_test, list_y_test,
                             cfg.training.dataloader.nb_sample_dataset_test)

    def _batch_sampler(ds):
        # Random order, grouped into lists of ``batch_size`` indices.
        return sampler.BatchSampler(
            sampler.RandomSampler(ds),
            batch_size=cfg.training["dataloader"]["batch_size"],
            drop_last=cfg.training["dataloader"]["drop_last"])

    train_batches = _batch_sampler(train_set)
    test_batches = _batch_sampler(test_set)

    # NOTE(review): the BatchSampler is handed over as ``sampler`` (not
    # ``batch_sampler``), so the dataset is presumably indexed with a whole
    # list of indices per item -- confirm against the dataset's __getitem__.
    n_workers = cfg.training["dataloader"]["num_workers"]
    train_loader = DataLoader(dataset=train_set,
                              num_workers=n_workers,
                              sampler=train_batches)
    test_loader = DataLoader(dataset=test_set,
                             num_workers=cfg.training["dataloader"]["num_workers"],
                             sampler=test_batches)
    return train_loader, test_loader


def get_recommend_dataloader(list_X_test, list_y_test, cfg):
    """Build a single loader backed by ``DatasetOpenML_variable_recommend``.

    Args:
        list_X_test: Forwarded as ``list_X`` to the dataset.
        list_y_test: Forwarded as ``list_y`` to the dataset.
        cfg: Configuration object. Reads ``cfg.training.seed``,
            ``cfg.training.classifier`` and, under ``cfg.training.dataloader``,
            ``nb_sample_dataset_test``, ``batch_size``, ``drop_last`` and
            ``num_workers``.

    Returns:
        The ``DataLoader`` over the recommendation dataset.
    """
    search_space = hparams.get_search_sapce(cfg.training.seed, cfg.training.classifier)

    recommend_set = DatasetOpenML_variable_recommend(
        list_X=list_X_test,
        list_y=list_y_test,
        configs=cfg,
        size_training=cfg.training.dataloader.nb_sample_dataset_test,
        config_space=search_space,
    )

    # Random order, grouped into lists of ``batch_size`` indices.
    shuffled_indices = sampler.RandomSampler(recommend_set)
    index_batches = sampler.BatchSampler(
        shuffled_indices,
        batch_size=cfg.training["dataloader"]["batch_size"],
        drop_last=cfg.training["dataloader"]["drop_last"],
    )

    # NOTE(review): the BatchSampler is passed as ``sampler`` rather than
    # ``batch_sampler``; presumably the dataset accepts a list of indices.
    return DataLoader(
        dataset=recommend_set,
        num_workers=cfg.training["dataloader"]["num_workers"],
        sampler=index_batches,
    )


def get_dataloader_ranking(list_X_train, list_y_train, list_X_test, list_y_test, cfg):
    """Build train, test and surrogate loaders for the ranking datasets.

    Args:
        list_X_train: Forwarded as ``list_X`` to the training dataset.
        list_y_train: Forwarded as ``list_y`` to the training dataset.
        list_X_test: Forwarded as ``list_X`` to the test dataset.
        list_y_test: Forwarded as ``list_y`` to the test dataset.
        cfg: Configuration object. Reads ``cfg.training.seed``,
            ``cfg.training.classifier`` and, under ``cfg.training.dataloader``,
            ``nb_sample_dataset``, ``nb_sample_dataset_test``, ``batch_size``,
            ``drop_last`` and ``num_workers``.

    Returns:
        Tuple ``(dataloader_train, dataloader_test, dataloader_surrogate)``.
        Note the order: the surrogate loader comes last.
    """
    search_space = hparams.get_search_sapce(cfg.training.seed, cfg.training.classifier)
    # Hand-crafted meta-features table, loaded from a fixed absolute path.
    mf_table = pd.read_csv(
        "/linkhome/rech/genini01/uvp29is/Code/metanal_v2/checkpoints/metafeatures_handcrafted_cc18.csv",
        index_col='index',
    )

    # Arguments that are identical for the train and test datasets.
    shared_kwargs = dict(configs=cfg,
                         config_space=search_space,
                         handcrafted_mf=mf_table)
    train_set = DatasetOpenML_variable_ranking(
        list_X=list_X_train, list_y=list_y_train,
        size_training=cfg.training.dataloader.nb_sample_dataset,
        **shared_kwargs)
    test_set = DatasetOpenML_variable_ranking(
        list_X=list_X_test, list_y=list_y_test,
        size_training=cfg.training.dataloader.nb_sample_dataset_test,
        **shared_kwargs)

    def _batch_sampler(ds):
        # Random order, grouped into lists of ``batch_size`` indices.
        return sampler.BatchSampler(
            sampler.RandomSampler(ds),
            batch_size=cfg.training["dataloader"]["batch_size"],
            drop_last=cfg.training["dataloader"]["drop_last"])

    train_batches = _batch_sampler(train_set)
    test_batches = _batch_sampler(test_set)

    n_workers = cfg.training["dataloader"]["num_workers"]
    train_loader = DataLoader(dataset=train_set,
                              num_workers=n_workers,
                              sampler=train_batches)
    # NOTE(review): the surrogate loader iterates the TRAINING dataset but
    # reuses the sampler built over the TEST dataset, so its indices range
    # over len(test_set). This looks unintended -- confirm before relying on
    # it; behaviour is kept as-is here.
    surrogate_loader = DataLoader(dataset=train_set,
                                  num_workers=cfg.training["dataloader"]["num_workers"],
                                  sampler=test_batches)
    test_loader = DataLoader(dataset=test_set,
                             num_workers=cfg.training["dataloader"]["num_workers"],
                             sampler=test_batches)
    return train_loader, test_loader, surrogate_loader

def get_BO_dataloader(list_X, list_y, list_dida_mf, list_hc_mf,
                            previous_dataloader,
                            surrogate,
                            x_surrogate, cfg):
    """Build a shuffled, batch-size-1 loader over an ``ImprovedDataset``.

    Args:
        list_X: Forwarded as ``list_X`` to ``ImprovedDataset``.
        list_y: Forwarded as ``list_y`` to ``ImprovedDataset``.
        list_dida_mf: Forwarded as ``list_dida_mf``.
        list_hc_mf: Forwarded as ``list_hc_mf``.
        previous_dataloader: Loader whose ``.dataset`` supplies
            ``config_space``, ``encode_hp``, ``mm_scaler``, ``classifier``
            and ``seed`` for the new dataset.
        surrogate: Forwarded as ``surrogate``.
        x_surrogate: Forwarded as ``x_surrogate``.
        cfg: Mapping read as ``cfg["dataloader"]["num_workers"]``.
            NOTE(review): the sibling factories read
            ``cfg.training["dataloader"]``; presumably callers pass the
            training sub-config here -- verify against call sites.

    Returns:
        The ``DataLoader`` over the newly built dataset.
    """
    # Only the column names of the hand-crafted meta-features table are used.
    mf_table = pd.read_csv(
        "/linkhome/rech/genini01/uvp29is/Code/metanal_v2/checkpoints/metafeatures_handcrafted_cc18.csv",
        index_col='index',
    )

    # Settings are inherited from the dataset behind the previous loader.
    prev = previous_dataloader.dataset
    bo_set = ImprovedDataset(
        list_X=list_X,
        list_y=list_y,
        list_dida_mf=list_dida_mf,
        list_hc_mf=list_hc_mf,
        config_space=prev.config_space,
        encode_hp=prev.encode_hp,
        mm_scaler=prev.mm_scaler,
        surrogate=surrogate,
        classifier=prev.classifier,
        size_training=1000,  # fixed dataset size
        x_surrogate=x_surrogate,
        seed=prev.seed,
        list_cols_metafeatures=mf_table.columns,
    )

    n_workers = cfg["dataloader"]["num_workers"]
    return DataLoader(dataset=bo_set,
                      batch_size=1, shuffle=True,
                      num_workers=n_workers)


def get_d2v_dataloader(cfg):
    """Build the train and test loaders backed by the ``D2V`` dataset.

    Args:
        cfg: Configuration object. Reads ``cfg.training.dataset``,
            ``cfg.training.seed``, ``cfg.training.extractor`` and, under
            ``cfg.training["dataloader"]``, ``batch_size``, ``drop_last`` and
            ``num_workers``.

    Returns:
        Tuple ``(dataloader_train, dataloader_test)``.
    """
    def _make_split(is_test):
        # ``only_handcrafted`` is on exactly when the extractor is "handcrafted".
        return D2V(cfg.training.dataset,
                   cfg.training.seed,
                   test=is_test,
                   only_handcrafted=bool(cfg.training.extractor == "handcrafted"))

    # The training split is created before the test split, as originally.
    train_set = _make_split(False)
    test_set = _make_split(True)

    def _batch_sampler(ds):
        # Random order, grouped into lists of ``batch_size`` indices.
        return sampler.BatchSampler(
            sampler.RandomSampler(ds),
            batch_size=cfg.training["dataloader"]["batch_size"],
            drop_last=cfg.training["dataloader"]["drop_last"])

    train_batches = _batch_sampler(train_set)
    test_batches = _batch_sampler(test_set)

    # NOTE(review): the BatchSampler is passed as ``sampler`` rather than
    # ``batch_sampler``; presumably D2V accepts a list of indices per lookup.
    train_loader = DataLoader(dataset=train_set,
                              num_workers=cfg.training["dataloader"]["num_workers"],
                              sampler=train_batches)
    test_loader = DataLoader(dataset=test_set,
                             num_workers=cfg.training["dataloader"]["num_workers"],
                             sampler=test_batches)
    return train_loader, test_loader
