import torch
import torch.nn as nn
import torch.optim as optim
import time
import torch.nn.functional as F
import numpy as np
import random

class INIT(object):
    """Train and evaluate ``args.network`` on the loaders held by ``args.Loader``.

    Attributes read from ``args`` in this class: ``Loader`` (``trainloader``,
    ``testloader``, ``lentrain``, ``lentest``), ``device``, ``network``,
    ``optimizer``, ``lr``, ``total_epochs`` (SGD only), ``n_class``, ``mode``,
    ``div``, ``obj_fcn_corr``, ``posterior_corr`` and, depending on the method
    called, ``noise_type``, ``noisy_ratio``, ``dataset``, ``true_T`` and ``N``.

    Both loaders must yield 4-tuples ``(index, images, classes, labels)``.
    ``labels`` is what the loss is computed against and ``classes`` is only
    used for accuracy bookkeeping (presumably noisy vs. clean targets --
    confirm against the loader).
    """

    def __init__(self, args):
        """Build losses, optimizer and bookkeeping buffers from ``args``."""
        print('\n===> Initialization')
        self.args = args
        self.trainloader, self.testloader = self.args.Loader.trainloader, self.args.Loader.testloader
        self.lentrain, self.lentest = self.args.Loader.lentrain, self.args.Loader.lentest
        self.criterion = nn.CrossEntropyLoss().to(self.args.device)
        self.nll = nn.NLLLoss().to(self.args.device)
        self.softmax = nn.Softmax(dim=1).to(self.args.device)
        self.mode = args.mode
        self.div = args.div
        self.obj_fcn_corr = args.obj_fcn_corr
        self.posterior_corr = args.posterior_corr

        # Always define these so that later ``is None`` checks and attribute
        # reads cannot fail with AttributeError, whatever mode/optimizer is used.
        self.transition = None
        self.lr_scheduler = None

        if self.mode == "f-PML" and self.obj_fcn_corr:
            print("Performing objective function correction")
            self.Floss = FLossBias(args.div, num_classes=args.n_class, T=self.transition, noise_type=self.args.noise_type)
        elif self.mode == "f-PML":
            self.Floss = FLoss(args.div, num_classes=args.n_class, T=None)

        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.args.network.parameters(), lr=self.args.lr, momentum=0.9, weight_decay=1e-3)
            self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.total_epochs, eta_min=0.0002)
            print("lr_scheduler: ", self.lr_scheduler)
        else:
            self.optimizer = optim.Adam(self.args.network.parameters(), lr=self.args.lr)

        # Last softmax output seen for every training sample, indexed by the
        # ``index`` element yielded by the train loader.
        self.proba = torch.zeros(self.lentrain, self.args.n_class)
        self.loss_train, self.train_class_acc, self.train_label_acc, self.test_acc = [], [], [], []

        self.time = time.time()

    def _train_one_epoch(self, compute_loss, record_proba):
        """Run one optimisation pass over the train loader.

        Args:
            compute_loss: callable ``(outputs, labels) -> scalar loss tensor``.
            record_proba: when true, store the softmax of the outputs in
                ``self.proba`` at the rows given by the batch ``index``.

        Returns:
            ``(epoch_loss, epoch_class_accuracy, epoch_label_accuracy)`` as
            un-normalised sums over the epoch (see ``report_result``).
        """
        self.args.network.train()
        epoch_loss, epoch_class_accuracy, epoch_label_accuracy = 0, 0, 0
        for index, images, classes, labels in self.trainloader:
            images = images.to(self.args.device)
            labels = labels.to(self.args.device)
            outputs = self.args.network(images)

            loss = compute_loss(outputs, labels)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Weight by batch size so the sum can later be divided by lentrain.
            epoch_loss += loss.item() * len(labels)

            _, model_label = torch.max(outputs.detach(), dim=1)
            # ``classes`` is never moved to the device, so compare on CPU.
            epoch_class_accuracy += (classes == model_label.cpu()).sum().item()
            epoch_label_accuracy += (labels == model_label).cpu().sum().item()

            if record_proba:
                self.proba[index] = self.softmax(outputs).detach().cpu()

        return epoch_loss, epoch_class_accuracy, epoch_label_accuracy

    def update_model(self):
        """Train for one epoch with plain cross-entropy on ``labels``.

        Returns:
            ``(epoch_loss, epoch_class_accuracy, epoch_label_accuracy,
            time_elapse)``; the first three are un-normalised sums and
            ``time_elapse`` is the number of seconds since construction.
        """
        epoch_loss, epoch_class_accuracy, epoch_label_accuracy = self._train_one_epoch(
            self.criterion, record_proba=True)
        time_elapse = time.time() - self.time
        return epoch_loss, epoch_class_accuracy, epoch_label_accuracy, time_elapse

    def _estimator_correction(self, output):
        """Subtract the vector ``e_js`` from the softmax of ``output``.

        ``e_js`` is row 0 of the transition matrix with its first entry
        replaced by ``transition[1][0]``. It is built on a copy, so the stored
        transition matrix is left untouched.

        Raises:
            RuntimeError: if no transition matrix has been set yet.
        """
        if self.transition is None:
            raise RuntimeError("transition matrix is not set; call generate_transition_matrix() first")
        transition = torch.as_tensor(self.transition)
        # Clone: indexing returns a view, and writing into it would silently
        # overwrite transition[0][0] for every later user of the matrix.
        e_js = transition[0].detach().clone()
        e_js[0] = transition[1][0]
        e_js = e_js.to(output.device)
        return torch.softmax(output, dim=-1) - e_js

    def evaluate_model(self):
        """Count correct predictions on the test loader.

        In ``"f-PML"`` mode with ``posterior_corr`` set, predictions are taken
        from the corrected posterior instead of the raw outputs.

        Returns:
            ``(epoch_accuracy, time_elapse)`` where ``epoch_accuracy`` is the
            number (not the fraction) of samples whose prediction equals
            ``classes``.
        """
        self.args.network.eval()
        epoch_accuracy = 0
        apply_correction = self.mode == "f-PML" and self.posterior_corr
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for _, images, classes, _ in self.testloader:
                images = images.to(self.args.device)
                classes = classes.to(self.args.device)
                outputs = self.args.network(images)
                if apply_correction:
                    outputs = self._estimator_correction(outputs)
                _, model_label = torch.max(outputs.detach(), dim=1)
                epoch_accuracy += (classes == model_label).cpu().sum().item()

        time_elapse = time.time() - self.time
        return epoch_accuracy, time_elapse

    def report_result(self, epoch_loss, epoch_class_acc, epoch_label_acc, epoch_test_acc):
        """Normalise the epoch sums, append them to the history lists and print them.

        Training quantities are divided by ``lentrain`` and the test count by
        ``lentest``.
        """
        epoch_loss /= self.lentrain
        epoch_class_acc /= self.lentrain
        epoch_label_acc /= self.lentrain
        epoch_test_acc /= self.lentest

        self.loss_train.append(epoch_loss)
        self.train_class_acc.append(epoch_class_acc)
        self.train_label_acc.append(epoch_label_acc)
        self.test_acc.append(epoch_test_acc)

        print('Train', epoch_loss, epoch_class_acc, epoch_label_acc)
        print('Test', epoch_test_acc)

    def generate_transition_matrix(self):
        """Generate true transition matrix for True Forward.

        * ``noise_type == 'sym'``: ``1 - noisy_ratio`` on the diagonal and
          ``noisy_ratio / (n_class - 1)`` everywhere else.
        * ``noise_type == 'asym'``: taken from ``args.true_T`` (CIFAR10 and
          CIFAR100 only).
        * anything else: identity (clean labels).

        The result is stored in ``self.transition`` on ``args.device``.

        Raises:
            NotImplementedError: for ``'asym'`` noise on an unsupported dataset.
        """
        n_class = self.args.n_class
        if self.args.noise_type == 'sym':
            noisy_ratio = float(self.args.noisy_ratio)
            n_rate = 1 - noisy_ratio * (n_class / (n_class - 1))
            self.transition = n_rate * torch.eye(n_class) + \
                              (noisy_ratio / (n_class - 1)) * torch.ones(n_class, n_class)
        elif self.args.noise_type == 'asym':
            if self.args.dataset not in ("CIFAR10", "CIFAR100"):
                # Previously this fell through with the matrix unset or stale.
                raise NotImplementedError(
                    "[-] Not Implemented asym transition matrix for dataset %s" % self.args.dataset)
            self.transition = torch.tensor(self.args.true_T)
        else:  # clean
            self.transition = torch.eye(n_class)

        self.transition = self.transition.to(self.args.device)

    def sample_batch(self, outputs, labels):
        """Resample a batch with replacement, weighted by ``P(y|x) / P(y~|x)``.

        Both probabilities come from the detached softmax of ``outputs``; the
        denominator mixes it with the transition rows selected by ``labels``.
        ``int(args.N * len(labels))`` indices are drawn.

        Returns:
            ``(outputs[idx], labels[idx])`` for the sampled indices; the
            returned outputs keep their autograd history.
        """
        f_prob = self.softmax(outputs.detach())
        # squeeze(1) rather than squeeze() so a batch of one stays 1-D.
        true_prob = torch.gather(f_prob, 1, labels.reshape(-1, 1)).squeeze(1)  # P(y|x)
        noisy_prob = torch.sum(self.transition[labels] * f_prob, dim=1)  # P(y\tilde|x)=TP(y|x)
        weighting = true_prob / (noisy_prob + 1e-12)
        smpl_idx = torch.multinomial(weighting, num_samples=int(self.args.N * len(labels)), replacement=True)

        return outputs[smpl_idx], labels[smpl_idx]

    def update_model_sir(self):
        """Train for one epoch with cross-entropy on batches resampled by ``sample_batch``.

        Returns:
            Same 4-tuple as ``update_model``. ``self.proba`` is not updated.
        """
        def resampled_loss(outputs, labels):
            s_preds, s_lbls = self.sample_batch(outputs, labels)
            return self.criterion(s_preds, s_lbls)

        epoch_loss, epoch_class_accuracy, epoch_label_accuracy = self._train_one_epoch(
            resampled_loss, record_proba=False)
        time_elapse = time.time() - self.time

        return epoch_loss, epoch_class_accuracy, epoch_label_accuracy, time_elapse

    def update_model_fpml(self):
        """Train for one epoch with ``self.Floss`` and step the LR scheduler if any.

        When ``obj_fcn_corr`` is set, the current transition matrix is pushed
        into the loss before the epoch starts.

        Returns:
            Same 4-tuple as ``update_model``.
        """
        print("f-PML")
        if self.obj_fcn_corr:
            self.Floss.update_T(self.transition)

        epoch_loss, epoch_class_accuracy, epoch_label_accuracy = self._train_one_epoch(
            self.Floss, record_proba=True)

        # Only the SGD configuration creates a scheduler.
        if self.lr_scheduler is not None:
            print("Updating the scheduler...")
            self.lr_scheduler.step()
        time_elapse = time.time() - self.time
        return epoch_loss, epoch_class_accuracy, epoch_label_accuracy, time_elapse


class ProbLossStable(nn.Module):
    """Per-sample negated softmax probability of the target class.

    ``forward`` applies a softmax over the last dimension and feeds the
    result to an un-reduced ``NLLLoss``, so each element of the output is
    ``-softmax(outputs)[i, labels[i]]``.

    The ``reduction`` and ``eps`` constructor arguments are accepted but the
    loss is always computed with ``reduction='none'``; ``eps`` is only stored.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super().__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)
        self._nllloss = nn.NLLLoss(reduction='none')

    def forward(self, outputs, labels):
        """Return one value per sample: minus the probability assigned to ``labels``."""
        probabilities = self._softmax(outputs)
        per_sample = self._nllloss(probabilities, labels)
        return per_sample

class ProbLossStable2Term(nn.Module):
    """Softmax over the last dimension, wrapped as a module.

    The ``reduction`` argument is accepted but unused; ``eps`` is only stored.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super().__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)

    def forward(self, outputs):
        """Return the class probabilities for ``outputs`` (same shape as the input)."""
        probabilities = self._softmax(outputs)
        return probabilities


class FLoss(nn.Module):
    """f-divergence based loss built from an ``activation`` and a ``conjugate`` term.

    ``div`` selects the pair of callables and must be one of ``'KL'``,
    ``'Jensen-Shannon'`` or ``'SL'``; anything else raises
    ``NotImplementedError``. ``T`` is accepted for signature compatibility
    with ``FLossBias`` and is not used by this class.
    """

    def __init__(self, div, num_classes=10, T=None):
        super(FLoss, self).__init__()
        self.div = div
        self.num_classes = num_classes
        self.USE_CUDA = torch.cuda.is_available()
        self.criterion_prob = ProbLossStable()
        # forward() needs the full softmax as well; this attribute was missing,
        # which made every forward pass fail with AttributeError.
        self.criterion_prob2 = ProbLossStable2Term()
        self.loss_fn_2 = nn.BCELoss(reduction='none')
        self.loss_fn_3 = nn.CrossEntropyLoss()

        if div == 'KL':
            self.activation = lambda x: -torch.mean(x)
            self.conjugate = lambda x: -torch.mean(torch.exp(x - 1.))
        elif div == 'Jensen-Shannon':
            self.activation = lambda x: -torch.mean(- torch.log(1. + torch.exp(-x))) - torch.log(torch.tensor(2.))
            self.conjugate = lambda x: -torch.mean(x + torch.log(1. + torch.exp(-x))) + torch.log(torch.tensor(2.))
        elif div == 'SL':
            self.activation = lambda x: -torch.mean(1 / (1 + torch.exp(-x)) - 1)
            self.conjugate = lambda x: -torch.mean(x + torch.log(1 + torch.exp(-x) + torch.exp(-x)/(1 + torch.exp(-x))))
        else:
            raise NotImplementedError("[-] Not Implemented f-divergence %s" % div)

    def forward(self, outputs, label):
        """Return the scalar loss ``activation(p_label) - conjugate(p_all)``.

        ``p_label`` is the softmax probability of ``label`` for every sample
        and ``p_all`` is the full softmax of ``outputs``.
        """
        # criterion_prob returns minus the label probability; undo the sign.
        prob_reg = -self.criterion_prob(outputs, label.long())
        loss_1 = self.activation(prob_reg)
        loss_2 = self.conjugate(self.criterion_prob2(outputs))
        loss = loss_1 - loss_2
        return loss


class FLossBias(nn.Module):
    """f-divergence loss with a bias term derived from a transition matrix.

    ``div`` must be one of ``'KL'``, ``'Jensen-Shannon'`` or ``'SL'``. The
    transition matrix can be given at construction time through ``T`` or
    later through ``update_T``. Only ``noise_type == "asym"`` is supported by
    ``forward``.
    """

    def __init__(self, div, num_classes=10, T=None, noise_type="asym"):
        super(FLossBias, self).__init__()
        self.div = div
        self.num_classes = num_classes
        self.USE_CUDA = torch.cuda.is_available()
        self.criterion_prob = ProbLossStable()
        self.criterion_prob2 = ProbLossStable2Term()
        self.loss_fn_2 = nn.BCELoss(reduction='none')
        self.loss_fn_3 = nn.CrossEntropyLoss()
        self._nllloss = nn.NLLLoss(reduction='none')
        self.noise_type = noise_type
        # Always define T so a missing matrix is reported clearly instead of
        # surfacing as an AttributeError inside the bias computation.
        self.T = None
        if T is not None:
            self.update_T(T)
        if div == 'KL':
            self.activation = lambda x: -torch.mean(x)
            self.conjugate = lambda x: -torch.mean(torch.exp(x - 1.))
        elif div == 'Jensen-Shannon':
            self.activation = lambda x: -torch.mean(- torch.log(1. + torch.exp(-x))) - torch.log(torch.tensor(2.))
            self.conjugate = lambda x: -torch.mean(x + torch.log(1. + torch.exp(-x))) + torch.log(torch.tensor(2.))
        elif div == 'SL':
            self.activation = lambda x: -torch.mean(1 / (1 + torch.exp(-x)) - 1)
            self.conjugate = lambda x: -torch.mean(x + torch.log(1 + torch.exp(-x) + torch.exp(-x)/(1 + torch.exp(-x))))
        else:
            raise NotImplementedError("[-] Not Implemented f-divergence %s" % div)

    def update_T(self, T):
        """Store a private, detached copy of the transition matrix ``T``.

        ``T`` may be a tensor or anything ``torch.tensor`` accepts; ``None``
        clears the stored matrix.
        """
        if T is None:
            self.T = None
        elif torch.is_tensor(T):
            # torch.tensor(tensor) warns; clone().detach() is the documented copy.
            self.T = T.detach().clone()
        else:
            self.T = torch.tensor(T)

    def forward(self, outputs, label):
        """Return ``activation(p_label) - conjugate(p_all) - bias``.

        Raises:
            NotImplementedError: if ``noise_type`` is not ``"asym"``.
        """
        # criterion_prob returns minus the label probability; undo the sign.
        prob_reg = -self.criterion_prob(outputs, label.long())
        loss_1 = self.activation(prob_reg)
        loss_2 = self.conjugate(self.criterion_prob2(outputs))
        loss = loss_1 - loss_2
        if self.noise_type == "asym":
            bias = self.bias_estimate_asym(outputs, label)
        else:
            raise NotImplementedError("[-] Not Implemented noise model %s" % self.noise_type)
        return loss - bias

    def bias_estimate_asym(self, output, label):
        """Accumulate the bias term over all classes.

        For every class ``k`` the term
        ``e_js[k] * activation(p_k) - sum(e_js) * conjugate(p_k)`` is added,
        where ``p_k`` is the softmax probability of class ``k`` for each
        sample and ``e_js`` is row 0 of ``T`` with its first entry replaced by
        ``T[1][0]``.

        Raises:
            RuntimeError: if no transition matrix has been provided.
        """
        if self.T is None:
            raise RuntimeError("transition matrix is not set; call update_T() first")
        # Work on a copy: T[0] is a view and writing into it would overwrite
        # the stored matrix on every forward pass.
        e_js = self.T[0].detach().clone()
        e_js[0] = self.T[1][0]
        e_js = e_js.to(output.device)
        sum_e = torch.sum(e_js)

        bias = 0
        for k in range(self.num_classes):
            class_k = torch.full_like(label, k, dtype=torch.long)
            # Evaluate the class-k probability once and reuse it for both terms.
            prob_k = -self.criterion_prob(output, class_k)
            first_term = e_js[k] * self.activation(prob_k)
            second_term = sum_e * self.conjugate(prob_k)
            bias += first_term - second_term
        return bias

