import os
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict

from .trainer.PLAD_trainer_fmnist import PLADTrainer
from .VAE_fmnist import VAE
import itertools
import scipy.io
from .networks import mlp



class CustomDataset(Dataset):
    """Map-style dataset that pairs rows of ``X`` with entries of ``y``.

    Args:
        X: indexable sample container; ``X[idx]`` must yield a NumPy array
            because it is passed to ``torch.from_numpy``.
        y: labels, indexed in parallel with ``X`` and returned unchanged.
    """

    def __init__(self, X, y):
        self.data, self.targets = X, y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample_tensor, label)`` for position ``idx``."""
        # Samplers may pass tensor indices; convert them to plain Python
        # ints/lists so they can index the underlying arrays.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.from_numpy(self.data[key])
        label = self.targets[key]
        return sample, label


class PLAD:
    """Build, train and score a PLAD model behind a ``fit``/``decision_function`` API.

    A ``mlp.VanilaMLP`` (``self.model``) and a ``VAE`` (``self.e_ae``) share
    a single Adam optimiser (AMSGrad variant) and are driven by
    ``PLADTrainer``.

    Args:
        train_x: training samples; rows must be NumPy arrays (they go through
            ``torch.from_numpy``) and the last axis is the input dimension.
        test_x: test samples, wrapped into ``self.test_loader`` (not used by
            ``decision_function``, which builds its own loader).
        test_y: labels paired with ``test_x`` in ``self.test_loader``.
        lamb: forwarded unchanged to ``PLADTrainer``.
        weight_decay: Adam weight decay applied to both networks.
        hidden_dims1: ``hidden_dim`` of the MLP.
        hidden_dims2: used as both ``h_dim`` and ``z_dim`` of the VAE.
        device: torch device both networks are moved to.
        lr: Adam learning rate (default keeps the previous hard-coded 0.001).
        batch_size: batch size of the train/test loaders (default 256).
    """

    def __init__(self, train_x, test_x, test_y, lamb, weight_decay, hidden_dims1, hidden_dims2,
                 device='cuda:0', lr=0.001, batch_size=256):
        self.train_x = train_x
        self.test_x = test_x
        self.test_y = test_y
        self.lamb = lamb
        self.weight_decay = weight_decay
        self.hidden_dims1 = hidden_dims1
        self.hidden_dims2 = hidden_dims2
        self.device = device

        # Training is unsupervised: labels are only zero placeholders so the
        # dataset yields (x, y) pairs like the test dataset does.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                                       num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
                                      num_workers=0)

        n_dim = train_x.shape[-1]

        self.model = mlp.VanilaMLP(input_dim=n_dim, hidden_dim=hidden_dims1).to(device)
        self.e_ae = VAE(input_dim=n_dim, h_dim=hidden_dims2, z_dim=hidden_dims2).to(device)
        # One optimiser over both networks so they are updated jointly.
        self.optimizer = optim.Adam(itertools.chain(self.model.parameters(), self.e_ae.parameters()),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    amsgrad=True)
        self.trainer = PLADTrainer(self.model, self.e_ae, self.optimizer, self.lamb, device)

    def fit(self, total_epochs=200):
        """Train the model on ``self.train_loader`` for ``total_epochs`` epochs."""
        self.trainer.train(self.train_loader, total_epochs=total_epochs)

    def decision_function(self, test_x, batch_size=1024):
        """Score ``test_x`` with the trained model.

        Args:
            test_x: samples convertible by ``torch.Tensor`` (converted to float32).
            batch_size: scoring batch size (default keeps the previous 1024).

        Returns:
            Whatever ``PLADTrainer.test`` returns for the built loader.
        """
        # Ensure the classifier is on the configured device before scoring.
        self.model = self.model.to(self.device)
        features = torch.Tensor(test_x)
        # Zero placeholder labels: the trainer's loader expects (x, y) pairs.
        test_set = torch.utils.data.TensorDataset(features, torch.zeros(test_x.shape[0]))
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
                                                  num_workers=0)
        return self.trainer.test(test_loader)
