import argparse
import numpy as np
import math
import torch.autograd as autograd
# import .h
from .mlp import Recognizer_mlp
# from mlp_cifar import Recognizer_mlp_cifar
import torch
import torch.nn as nn
import torch.nn.functional as F
# import pdb
from sklearn.metrics import auc, roc_auc_score
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """Dataset serving ``(sample, target)`` pairs from two parallel containers.

    ``X`` is stored as ``self.data`` and ``y`` as ``self.targets``. Items are
    looked up with the same index in both, so they must be indexable alike.
    """

    def __init__(self, X, y):
        """Keep references to the samples ``X`` and their targets ``y``."""
        self.data, self.targets = X, y

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(torch tensor of data[idx], targets[idx])``.

        A tensor index is first converted to plain Python via ``tolist()``.
        ``self.data[idx]`` must yield a NumPy array because it is wrapped
        with ``torch.from_numpy``; the target is returned unconverted.
        """
        index = idx.tolist() if torch.is_tensor(idx) else idx
        sample = torch.from_numpy(self.data[index])
        return sample, self.targets[index]

class HRN:
    """Scoring network trained with a log-sigmoid loss plus a gradient penalty.

    The network is a ``Recognizer_mlp``. Training minimises
    ``-mean(log(sigmoid(net(x)) + 1e-2)) + lamb * penalty`` where the penalty
    is the mean of the input-gradient norm raised to the power ``n``.
    Training labels are never used.
    """

    def __init__(self, train_x, test_x, test_y, latent_dim, lamb, n, weight_decay, device='cuda'):
        """Build data loaders, the network and its SGD optimizer.

        Args:
            train_x: Training samples; needs ``.shape`` and must be indexable
                by ``CustomDataset``. ``train_x.shape[1]`` is the input width
                handed to the network.
            test_x: Test samples, wrapped in ``self.test_loader``.
            test_y: Test targets, wrapped in ``self.test_loader``.
            latent_dim: Passed to ``Recognizer_mlp`` as ``latent``.
            lamb: Weight of the gradient penalty in the loss.
            n: Exponent applied to the gradient norm in the penalty.
            weight_decay: Weight decay of the SGD optimizer.
            device: Device on which the network is trained and evaluated.
        """
        self.train_x, self.test_x, self.test_y = train_x, test_x, test_y
        self.latent_dim, self.lamb, self.n, self.weight_decay = \
            latent_dim, lamb, n, weight_decay

        self.device = device

        # fit() ignores the targets, so zeros serve as placeholders.
        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True,
                                       num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False,
                                      num_workers=0)

        self.net = Recognizer_mlp(n_dim=train_x.shape[1], latent=latent_dim)

        self.optimizer = torch.optim.SGD(
            self.net.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)

    def fit(self, epochs=100):
        """Train the network on ``self.train_loader``.

        Args:
            epochs: Number of passes over the training loader (default 100).
        """
        self.net = self.net.to(self.device)
        self.net.train()
        for _ in range(epochs):
            for inputs, _ in self.train_loader:
                # Cast to float32 so e.g. float64 NumPy data matches the network.
                self.calculate_loss(inputs.to(self.device).float())

    def decision_function(self, test_x):
        """Return ``sigmoid(net(x))`` for every row of ``test_x`` as a NumPy array.

        Args:
            test_x: Array-like of samples, converted to a float32 tensor.

        Returns:
            NumPy array with one entry per sample along the first axis; an
            empty float32 array when ``test_x`` holds no samples.
        """
        model = self.net.to(self.device)
        model.eval()

        features = torch.as_tensor(test_x, dtype=torch.float32)
        test_set = torch.utils.data.TensorDataset(features)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=1024, shuffle=False, num_workers=0)

        score_all = []
        with torch.no_grad():
            for (batch,) in test_loader:
                scores = torch.sigmoid(model(batch.to(self.device)))
                # Drop a trailing singleton output axis but always keep the batch
                # axis: a bare squeeze() turns a one-sample batch into a 0-dim
                # tensor, which torch.cat cannot concatenate.
                scores = scores.reshape(scores.shape[0], -1).squeeze(1)
                score_all.append(scores.to('cpu'))

        if not score_all:
            # torch.cat rejects an empty sequence.
            return np.empty(0, dtype=np.float32)
        return torch.cat(score_all, 0).numpy()

    def calc_gradient_penalty(self, real_data, fake_data):
        """Return the mean of ``||d net(x) / d x||_2 ** n`` at interpolated inputs.

        Inputs are random per-sample convex combinations of ``real_data`` and
        ``fake_data``. All tensors are created on the device of ``real_data``.

        Args:
            real_data: Batch tensor with the batch on axis 0.
            fake_data: Tensor of the same shape as ``real_data``.

        Returns:
            Scalar tensor, still attached to the graph so it can be
            back-propagated through.
        """
        batch_size = real_data.shape[0]
        # One weight per sample, broadcast over all remaining axes.
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        # Sampled with the default (CPU) generator, then moved next to the data.
        alpha = torch.rand(alpha_shape).to(real_data.device)

        interpolates = alpha * real_data + (1 - alpha) * fake_data
        # Fresh leaf tensor so gradients are taken w.r.t. the inputs themselves.
        interpolates = interpolates.detach().requires_grad_(True)

        disc_interpolates = self.net(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        # Norm is taken over axis 1 only.
        return (gradients.norm(2, dim=1) ** self.n).mean()

    def calculate_loss(self, imgs):
        """Run one optimisation step on the batch ``imgs``.

        Args:
            imgs: Input batch, already on the network's device and dtype.

        Returns:
            Tuple ``(gradient_penalty, finished_epoch, mean_log_sigmoid)``.
            Both tensors are detached; ``finished_epoch`` is always ``0``.
        """
        finished_epoch = 0  # constant; kept so the return tuple keeps its shape

        # The penalty is evaluated on the batch itself (real == fake).
        loss_pen = self.calc_gradient_penalty(imgs, imgs)
        score_temp_0 = self.net(imgs)
        # The 1e-2 offset keeps the log finite when the sigmoid underflows to 0.
        mainloss_p = torch.log(torch.sigmoid(score_temp_0) + 1e-2).mean()

        loss = -mainloss_p + self.lamb * loss_pen
        self.optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)

        self.optimizer.step()

        return loss_pen.detach(), finished_epoch, mainloss_p.detach()
