import argparse
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from .config.base import Grid, Config
from .evaluation.Experiments import runExperiment
from .evaluation.Kvariants_Eval import KVariantEval
from torch.utils.data import Sampler, Dataset


class CustomDataset(Dataset):
    """Map-style dataset pairing a feature array with per-sample targets.

    Args:
        X: Feature container. Indexing it must yield a NumPy array (e.g. a
            2-D array indexed by row), because items are wrapped with
            ``torch.from_numpy`` and therefore share memory with ``X``.
        y: Targets, indexed with the same key as ``X`` and returned unchanged.
    """

    def __init__(self, X, y):
        self.data, self.targets = X, y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Accept tensor indices by converting them to a Python int / list first.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        features = torch.from_numpy(self.data[key])
        return features, self.targets[key]

class NeuTraLAD:
    """Fit/score wrapper around the NeuTraLAD ``KVariantEval`` pipeline.

    The constructor loads a YAML model configuration through ``Grid``,
    overrides a handful of hyper-parameters on its first entry, builds the
    ``KVariantEval`` risk assessor, and prepares train/test ``DataLoader``
    objects. ``fit`` then runs the assessor's risk-assessment helper, and
    ``decision_function`` scores new samples with it.

    Args:
        dataset_name: Dataset identifier forwarded to ``Grid``.
        train_x: Training features. Must be a NumPy array with a 2-D shape
            (``shape[0]`` samples, ``shape[1]`` features); rows are turned into
            tensors with ``torch.from_numpy`` by ``CustomDataset``.
        test_x: Evaluation features (same constraints as ``train_x``).
        test_y: Evaluation labels, indexed per sample alongside ``test_x``.
        k: Written to the ``num_trans`` config key (presumably the number of
            learned transformations).
        tau: Written to the ``loss_temp`` config key.
        hidden_dims: Written to the ``latent_dim`` config key.
        enc_hdim: Written to the ``enc_hdim`` config key.
        trans_hdim: Written to the ``trans_hdim`` config key.
        config_file: Path of the YAML configuration to load. ``None`` (the
            default) keeps the previously hard-coded thyroid configuration.
        batch_size: Batch size of the train/test loaders built here
            (default 256, as before).
    """

    # The configuration that used to be loaded unconditionally, whatever
    # ``dataset_name`` was. It is a relative path, so it is presumably resolved
    # against the current working directory by ``Grid`` -- TODO confirm.
    DEFAULT_CONFIG_FILE = './NeuTraLAD/config_files/config_thyroid.yml'

    def __init__(self, dataset_name, train_x, test_x, test_y, k, tau, hidden_dims,
                 enc_hdim, trans_hdim, config_file=None, batch_size=256):
        if config_file is None:
            config_file = self.DEFAULT_CONFIG_FILE

        model_configurations = Grid(config_file, dataset_name)

        # Only the first configuration is overridden, but the whole collection
        # is still handed to KVariantEval below.
        overrides = {
            'num_trans': k,
            'loss_temp': tau,
            'latent_dim': hidden_dims,
            'enc_hdim': enc_hdim,
            'trans_hdim': trans_hdim,
        }
        for key, value in overrides.items():
            model_configurations[0][key] = value

        model_configuration = Config(**model_configurations[0])
        dataset = model_configuration.dataset

        # Plain string concatenation: ``result_folder`` presumably ends with a
        # path separator in the config -- TODO confirm.
        result_folder = model_configuration.result_folder + model_configuration.exp_name

        self.risk_assesser = KVariantEval(dataset, result_folder, model_configurations)

        self.train_x = train_x
        self.test_x = test_x
        self.test_y = test_y

        # Training samples get all-zero placeholder labels; presumably they are
        # not used by the one-class ('normal') training -- TODO confirm.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                      shuffle=False, num_workers=0)

    def fit(self, exp_path='./NeuTraLAD-exp'):
        """Run KVariantEval's risk-assessment helper on the loaders from ``__init__``.

        Args:
            exp_path: Value forwarded as ``exp_path`` to the helper (default
                keeps the previously hard-coded ``'./NeuTraLAD-exp'``).

        Returns:
            ``self``, so calls can be chained
            (``model.fit().decision_function(x)``).
        """
        n_dim = self.train_x.shape[1]
        # NOTE(review): this calls a private KVariantEval method; class 0 is
        # treated as the 'normal' class.
        self.risk_assesser._risk_assessment_helper(
            self.train_loader, self.test_loader, n_dim,
            cls=0, cls_type='normal',
            experiment_class=runExperiment, exp_path=exp_path)
        return self

    def decision_function(self, test_x, batch_size=1024):
        """Score ``test_x`` with the risk assessor prepared by :meth:`fit`.

        Args:
            test_x: Samples to score; converted with ``torch.Tensor`` (default
                float dtype).
            batch_size: Scoring batch size (default keeps the previous 1024).

        Returns:
            Whatever ``KVariantEval.get_score`` returns for these samples.
        """
        features = torch.Tensor(test_x)
        # Pair the samples with all-zero placeholder labels so batches are
        # (x, y) tuples, like the loaders built in __init__.
        test_set = torch.utils.data.TensorDataset(features, torch.zeros(features.shape[0]))
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
        return self.risk_assesser.get_score(test_loader)
