import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

from .train import SCADTrainer



class CustomDataset(Dataset):
    """Map-style dataset pairing entries of ``X`` with entries of ``y``.

    Args:
        X: Sample container indexed along its first axis. numpy ndarray rows
            are wrapped without copying and keep their dtype (as with
            ``torch.from_numpy``); torch tensors and numpy scalars (e.g. from
            a 1-D ``X``) are accepted as well.
        y: Targets indexed in parallel with ``X``; returned unchanged.
    """

    def __init__(self,
                 X,
                 y):
        self.data = X
        self.targets = y

    def __len__(self):
        """Return the number of samples, i.e. ``len(X)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample_tensor, target)`` pair at ``idx``.

        A tensor index is converted to a plain Python int/list first so that
        numpy arrays and lists can be indexed with it.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # as_tensor reuses ndarray memory exactly like from_numpy, but unlike
        # from_numpy it also accepts tensors and numpy scalars.
        sample = torch.as_tensor(self.data[idx])
        return sample, self.targets[idx]


class SCAD:
    """Wire training/test data into DataLoaders and a ``SCADTrainer``.

    Args:
        train_x: Training samples. Indexed along the first axis by
            ``CustomDataset``; its last axis size becomes ``n_dim``.
        test_x: Held-out samples, wrapped with ``test_y`` into
            ``self.test_loader``.
        test_y: Labels for ``test_x``.
        k, tau, weight_decay, hidden_dims: Forwarded unchanged to
            ``SCADTrainer``; their meaning is defined there.
        device: Device string forwarded to ``SCADTrainer``.
        batch_size: Mini-batch size shared by the train/test loaders and the
            trainer (default 256).
    """

    def __init__(self, train_x, test_x, test_y, k, tau, weight_decay, hidden_dims,
                 device='cuda:0', batch_size=256):
        self.train_x = train_x
        self.test_x = test_x
        self.test_y = test_y
        self.k = k
        self.tau = tau
        self.weight_decay = weight_decay
        self.hidden_dims = hidden_dims
        self.device = device
        self.batch_size = batch_size

        # Training is unsupervised here: the dataset needs targets, so use zeros.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                      shuffle=False, num_workers=0)

        self.n_dim = train_x.shape[-1]

        # One shared batch_size keeps the trainer consistent with train_loader.
        self.trainer = SCADTrainer(self.n_dim, tau, k, weight_decay, hidden_dims,
                                   batch_size=batch_size, device=self.device)

    def fit(self):
        """Train via ``self.trainer`` on ``self.train_loader``; return ``self``."""
        self.trainer.train(self.n_dim, self.train_loader)
        return self

    def decision_function(self, test_x, batch_size=1024):
        """Score ``test_x`` with the trainer and return ``SCADTrainer.test``'s result.

        Args:
            test_x: Samples to score; converted to a float32 tensor.
            batch_size: Inference mini-batch size (default 1024).

        Returns:
            Whatever ``self.trainer.test`` returns for the loader (presumably
            one score per sample — confirm against ``SCADTrainer``).
        """
        features = torch.as_tensor(test_x, dtype=torch.float32)
        # Placeholder labels so each batch is an (x, y) pair like the other loaders.
        placeholder_labels = torch.zeros(features.shape[0])
        test_set = torch.utils.data.TensorDataset(features, placeholder_labels)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                                  shuffle=False, num_workers=0)
        return self.trainer.test(test_loader)


