from functools import reduce
import torch
from torch import nn, autograd
from torch.autograd import Variable
import os
import os.path
from models.generator import gan
import numpy as np
from torch.nn import functional as F
import torchvision

# Small positive constant; it is not referenced anywhere in the visible part of this file.
EPSILON = 1e-16
# Default compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class WGAN(nn.Module):
    def __init__(self, z_size,
                 image_size,
                 image_channel_size,
                 c_channel_size,
                 g_channel_size,
                 num_classes
                 ):
        """Build the generator/critic pair and the loss modules used for training.

        Args:
            z_size: length of the generator input vector; forwarded to
                ``gan.Generator`` as ``z_size``.
            image_size: forwarded to both networks as ``image_size``.
            image_channel_size: forwarded to both networks as
                ``image_channel_size``.
            c_channel_size: forwarded to ``gan.Critic`` as ``channel_size``.
            g_channel_size: forwarded to ``gan.Generator`` as ``channel_size``.
            num_classes: forwarded to ``gan.Critic`` as ``num_classes``; also
                the width of the one-hot block written into the noise vector
                by ``generate_noise_with_classes``.
        """
        super().__init__()

        # Loss functions. The real/fake head is scored with a logits-based BCE,
        # the class head with NLL (callers pass ``torch.log`` of the class output).
        self.dis_criterion = nn.BCEWithLogitsLoss()
        self.aux_criterion = nn.NLLLoss()

        # Plain configuration values.
        self.z_size = z_size
        self.image_size = image_size
        self.image_channel_size = image_channel_size
        self.c_channel_size = c_channel_size
        self.g_channel_size = g_channel_size

        # KL divergence used by the distillation terms of the training methods.
        self.ensemble_loss = nn.KLDivLoss(reduction="batchmean")
        self.num_classes = num_classes

        # Building the backbone networks.
        self.generator = gan.Generator(
            z_size=self.z_size,
            image_size=self.image_size,
            image_channel_size=self.image_channel_size,
            channel_size=self.g_channel_size
        )

        self.critic = gan.Critic(
            image_size=self.image_size,
            image_channel_size=self.image_channel_size,
            channel_size=self.c_channel_size,
            num_classes=self.num_classes
        )

        # Training-related components that must be set (through the set_*
        # methods) before training starts.
        self.generator_optimizer = None
        self.critic_optimizer = None
        self.lamda = None
        # Initialised here so reading it before the setter is called yields
        # None instead of raising AttributeError.
        self.critic_updates_per_generator_update = None

    def kd_loss(self, logits_a, logits_b):
        """Return the mean squared error between two logit tensors."""
        squared_gap = F.mse_loss(input=logits_a, target=logits_b)
        return squared_gap

    def train_a_batch(self, x, y,
                      classes_so_far,
                      x_=None, y_=None,
                      importance_of_new_task=.5,
                      x_g=None, y_g=None,
                      ):
        """Run one critic update followed by one generator update.

        The critic loss comes from ``_c_loss`` on ``(x, y)``. When both ``x_``
        and ``y_`` are given, a second ``_c_loss`` on the replayed pair is
        blended in, weighted by ``importance_of_new_task``. ``x_g`` and ``y_g``
        are accepted but not used.

        Returns:
            dict with ``'c_loss'``, ``'g_loss'``, ``'aux_f'``, ``'aux_r'``
            (Python floats) and ``'features'``. The three ``aux`` entries come
            from the last ``_c_loss`` call, i.e. the replay call when replay
            data was supplied.
        """
        assert x_ is None or x.size() == x_.size()
        assert y_ is None or y.size() == y_.size()

        # ---------------- critic step ----------------
        loss_on_real, _, _, aux = self._c_loss(
            x, y, classes_so_far, return_g=True, return_aux=True)

        if x_ is None or y_ is None:
            c_loss = loss_on_real
        else:
            loss_on_replay, _, _, aux = self._c_loss(
                x_, y_, classes_so_far, return_g=True, return_aux=True)
            c_loss = (
                    importance_of_new_task * loss_on_real +
                    (1 - importance_of_new_task) * loss_on_replay
            )

        self.critic_optimizer.zero_grad()
        c_loss.backward()
        self.critic_optimizer.step()

        # ---------------- generator step ----------------
        self.generator_optimizer.zero_grad()
        g_loss, _ = self._g_loss(x, y, classes_so_far)
        g_loss.backward()
        self.generator_optimizer.step()

        aux_fake, aux_real, features = aux[0], aux[1], aux[2]
        report = {'c_loss': c_loss.item(), 'g_loss': g_loss.item()}
        report['aux_f'] = aux_fake.item()
        report['aux_r'] = aux_real.item()
        report['features'] = features
        return report

    def sample(self, size, classes_so_far):
        """Generate ``size`` images with classes drawn from ``classes_so_far``.

        Args:
            size: number of images to generate.
            classes_so_far: candidate classes, passed straight to
                ``generate_noise_with_classes``.

        Returns:
            ``(fake, aux_label)``: the generator output and the class label
            tensor that was encoded into each noise vector.
        """
        noise, aux_label, _ = self.generate_noise_with_classes(size, classes_so_far)

        # Use the module-level device (which falls back to the CPU) instead of
        # a hard-coded 'cuda:0'.
        fake = self.generator(noise.to(device))

        return fake, aux_label

    def set_generator_optimizer(self, optimizer):
        """Store the optimizer that the training methods step for the generator."""
        self.generator_optimizer = optimizer

    def set_critic_optimizer(self, optimizer):
        """Store the optimizer that the training methods step for the critic."""
        self.critic_optimizer = optimizer

    def set_critic_updates_per_generator_update(self, k):
        """Store ``k`` as ``critic_updates_per_generator_update``.

        The value is only recorded here; no method in this class reads it.
        """
        self.critic_updates_per_generator_update = k

    def set_lambda(self, l):
        """Store ``l`` on the instance under the attribute name ``lamda``."""
        self.lamda = l

    def _noise(self, size):
        """Return a ``(size, z_size)`` tensor of N(0, 1) samples scaled by 0.1.

        The samples are drawn on the CPU (so the CPU random stream is used, as
        before) and then moved to the module-level ``device``, which falls back
        to the CPU when CUDA is unavailable.
        """
        # ``Variable`` is a deprecated no-op wrapper; a plain tensor behaves the same.
        z = torch.randn(size, self.z_size) * .1
        return z.to(device)

    ############################################################

    def _c_loss(self, x, y, classes_so_far, return_g=False, return_aux=False, return_feature=False):
        """Compute the critic loss on one real batch plus one generated batch.

        For each of the two batches the loss is the sum of a real/fake term
        (``dis_criterion``) and a class term (``aux_criterion`` applied to the
        log of the critic's class output).

        Args:
            x: real inputs, fed to the critic as-is.
            y: class labels for ``x``; moved to the module-level ``device``.
            classes_so_far: classes the fake batch may be conditioned on.
            return_g: when truthy, include the generated batch in the result.
            return_aux: when truthy, append
                ``(aux_loss_fake, aux_loss_real, feature)`` to the result.
            return_feature: accepted for interface compatibility; unused.

        Returns:
            A tuple ``(loss, [fake], logits_real, [aux])`` where the bracketed
            items are present only when the matching flag is set.
        """
        batch_size = x.size(0)

        # Real samples are labelled 1 for the real/fake head.
        real_target = torch.ones(batch_size, dtype=torch.float, device=device)
        y = y.to(device)

        # ---- real data ----
        dis_output, aux_output, logits_real, feature = self.critic(x, if_features=True)
        dis_errD_real = self.dis_criterion(dis_output, real_target)
        aux_errD_real = self.aux_criterion(torch.log(aux_output), y)
        loss_c_real = dis_errD_real + aux_errD_real

        # ---- generated data ----
        # The returned dis_label is all zeros, i.e. the "fake" target.
        noise, aux_label, fake_target = self.generate_noise_with_classes(batch_size, classes_so_far)
        fake = self.generator(noise.to(device))

        # detach(): the generator must not receive gradients from the critic loss.
        dis_output, aux_output, _ = self.critic(fake.detach())
        dis_errD_fake = self.dis_criterion(dis_output, fake_target)
        aux_errD_fake = self.aux_criterion(torch.log(aux_output), aux_label)
        loss_c_fake = dis_errD_fake + aux_errD_fake

        loss_c_all = loss_c_fake + loss_c_real

        # Assemble the result once instead of enumerating all four flag combinations.
        result = [loss_c_all]
        if return_g:
            result.append(fake)
        result.append(logits_real)
        if return_aux:
            result.append((aux_errD_fake, aux_errD_real, feature))
        return tuple(result)

    def generate_noise_with_classes(self, batch_size, classes_so_far, label=None):
        """Draw class-conditional generator inputs together with their labels.

        Each noise row is a standard-normal vector of length ``z_size`` whose
        first ``num_classes`` entries are overwritten with the one-hot encoding
        of that row's class.

        Args:
            batch_size: number of noise vectors to create.
            classes_so_far: population handed to ``np.random.choice`` when
                ``label`` is None; ignored otherwise.
            label: optional per-sample class indices (array-like). When given,
                these classes are used instead of random ones.

        Returns:
            ``(noise, aux_label, dis_label)``:
            ``noise`` is a float32 CPU tensor of shape ``(batch_size, z_size)``;
            ``aux_label`` is a long tensor of the class indices on ``device``;
            ``dis_label`` is a float tensor of zeros (the "fake" target) on
            ``device``.

        Raises:
            ValueError: if ``num_classes`` exceeds ``z_size``, because the
                one-hot block would not fit into the noise vector.
        """
        if self.num_classes > self.z_size:
            raise ValueError(
                "num_classes ({}) must not exceed z_size ({})".format(self.num_classes, self.z_size)
            )

        # NumPy draws happen in the same order as before: labels first, then noise.
        if label is None:
            label = np.random.choice(classes_so_far, batch_size)
        label = np.asarray(label)

        latent = np.random.normal(0, 1, (batch_size, self.z_size))

        class_onehot = np.zeros((batch_size, self.num_classes))
        class_onehot[np.arange(batch_size), label] = 1
        latent[:, :self.num_classes] = class_onehot

        # Keep an explicit (batch_size, z_size) shape. The previous
        # torch.squeeze() also removed the batch axis when batch_size == 1
        # (and the latent axis when z_size == 1).
        noise = torch.from_numpy(latent).float().view(batch_size, self.z_size)

        # copy_ into a fresh tensor so the result never aliases the caller's array.
        aux_label = torch.empty(batch_size, dtype=torch.long, device=device)
        aux_label.copy_(torch.from_numpy(label))

        # Generated samples are labelled 0 for the real/fake head.
        dis_label = torch.zeros(batch_size, dtype=torch.float, device=device)

        return noise, aux_label, dis_label

    def _g_loss(self, x, y, classes_so_far, return_g=False):
        """Compute the generator loss for one batch of freshly generated samples.

        Args:
            x: only used for its batch size (``x.size(0)``).
            y: unused.
            classes_so_far: classes the generated samples may be conditioned on.
            return_g: accepted for interface compatibility; unused.

        Returns:
            ``(loss_g, logits)``: the sum of the real/fake and class losses, and
            the third output of the critic on the generated batch.
        """
        batch_size = x.size(0)

        # The dis_label returned here is the "fake" target, so it is discarded.
        noise, aux_label, _ = self.generate_noise_with_classes(batch_size=batch_size, classes_so_far=classes_so_far)

        # The generator is trained to have its samples scored as real (label 1).
        dis_label = torch.ones(batch_size, dtype=torch.float, device=device)

        fake = self.generator(noise.to(device))

        # No detach here: gradients must flow through the critic into the generator.
        dis_output, aux_output, logits = self.critic(fake)

        dis_errG = self.dis_criterion(dis_output, dis_label)
        aux_errG = self.aux_criterion(torch.log(aux_output), aux_label)

        loss_g = dis_errG + aux_errG

        return loss_g, logits

    def visualize(self, sample_size=16, path='./images'):
        """Generate ``sample_size`` images and save them as one grid image.

        Args:
            sample_size: number of images to generate.
            path: output path without extension; ``'.jpg'`` is appended.
        """
        # os.path.dirname() returns '' for a bare file name, and
        # os.makedirs('') raises, so only create the directory when there is one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # NOTE: classes_so_far=1 makes np.random.choice draw from arange(1),
        # so every sample is conditioned on class 0.
        data, _ = self.sample(sample_size, 1)

        torchvision.utils.save_image(
            data,
            path + '.jpg',
            nrow=6,
        )

        print('image is saved!')

    def train_a_batch_critic_only(self,
                                  x, y,
                                  classes_so_far,
                                  x_=None, y_=None,
                                  importance_of_new_task=.5):
        """Update the critic with the classification loss only.

        Args:
            x, y: current inputs and their class labels.
            classes_so_far: accepted for interface compatibility; unused.
            x_, y_: optional replayed inputs and labels. Both must be given
                for the replay term to be used.
            importance_of_new_task: weight of the current-data loss; the replay
                loss is weighted by ``1 - importance_of_new_task``.

        Returns:
            dict with the single key ``'c_loss'`` (Python float).
        """
        assert x_ is None or x.size() == x_.size()
        assert y_ is None or y.size() == y_.size()

        # Prediction loss on the current data.
        _, aux_output, _ = self.critic(x)
        c_loss_real = self.aux_criterion(torch.log(aux_output), y)

        # Prediction loss on the replayed data, when supplied.
        if x_ is not None and y_ is not None:
            _, aux_output_replay, _ = self.critic(x_)
            # The closing parenthesis used to sit after ``y_``, which passed the
            # labels to torch.log() and raised a TypeError on every replay batch.
            c_loss_replay = self.aux_criterion(torch.log(aux_output_replay), y_)

            c_loss = (importance_of_new_task * c_loss_real + (1 - importance_of_new_task) * c_loss_replay)

        else:
            c_loss = c_loss_real

        # Parameter update.
        self.critic_optimizer.zero_grad()
        c_loss.backward()
        self.critic_optimizer.step()

        return {'c_loss': c_loss.item()}

    def train_a_batch_lwf(self,
                          current_task,
                          server_generator,
                          last_copy,
                          if_last_copy,
                          x, y,
                          classes_so_far,
                          x_=None, y_=None,
                          importance_of_new_task=.5):
        """Update the critic with a classification loss plus two distillation terms.

        The loss is
        ``alpha * classification + beta * KD(last_copy) + (1 - alpha - beta) * KD(server)``
        with ``alpha = beta = 0.33`` when ``current_task > 0`` and
        ``alpha = 1, beta = 0`` otherwise. Both KD terms are KL divergences
        between temperature-softened (T = 2) class distributions.

        Args:
            current_task: task index; server distillation is active only when
                it is greater than 0.
            server_generator: object whose ``critic`` provides the server
                teacher logits.
            last_copy: object whose ``critic`` provides the previous-copy
                teacher logits.
            if_last_copy: when truthy, the ``last_copy`` KD term is computed.
            x, y: current inputs and their class labels.
            classes_so_far, x_, y_, importance_of_new_task: accepted for
                interface compatibility; only ``x_``/``y_`` are shape-checked.

        Returns:
            dict with the single key ``'loss_all'`` (Python float).
        """
        assert x_ is None or x.size() == x_.size()
        assert y_ is None or y.size() == y_.size()

        T = 2  # distillation temperature
        use_server = current_task > 0

        # 1. prediction loss
        _, aux_output, logits_real = self.critic(x)
        c_loss_real = self.aux_criterion(torch.log(aux_output), y)

        # Student log-probabilities are needed by either KD term. They used to
        # be defined only when current_task > 0, so task 0 combined with
        # if_last_copy raised a NameError. log_softmax is the numerically
        # stable form of log(softmax(.)).
        user_logit_logp = None
        if use_server or if_last_copy:
            user_logit_logp = F.log_softmax(logits_real / T, dim=1)

        # 2. KD: server -> client
        if use_server:
            # Teacher outputs are constants for this update; skip building their graph.
            with torch.no_grad():
                _, _, server_logit = server_generator.critic(x)
                server_logit_p = F.softmax(server_logit / T, dim=1)
            kd_loss_server = self.ensemble_loss(user_logit_logp, server_logit_p)
        else:
            kd_loss_server = 0

        # 3. KD: last copy -> client
        if if_last_copy:
            with torch.no_grad():
                _, _, copy_logit = last_copy.critic(x)
                copy_logit_p = F.softmax(copy_logit / T, dim=1)
            kd_loss_copy = self.ensemble_loss(user_logit_logp, copy_logit_p)
        else:
            kd_loss_copy = 0

        if use_server:
            alpha = 0.33
            beta = 0.33
        else:
            alpha = 1
            beta = 0

        loss_all = alpha * c_loss_real + beta * kd_loss_copy + (1 - alpha - beta) * kd_loss_server

        # Parameter update.
        self.critic_optimizer.zero_grad()
        loss_all.backward()
        self.critic_optimizer.step()

        return {'loss_all': loss_all.item()}

    def train_a_batch_all(self, x, y,
                          available_labels,
                          classes_so_far,
                          generator_server,
                          glob_iter_,
                          x_=None, y_=None,
                          importance_of_new_task=.5,
                          x_g=None, y_g=None,
                          ):
        """Run one critic update (with server-based KD terms) and one generator update.

        The critic loss is the ``_c_loss`` term (optionally blended with a
        replay term) plus, when ``glob_iter_ != 0``, three auxiliary terms built
        from images produced by ``generator_server.generator``.

        Args:
            x, y: current inputs and their class labels.
            available_labels: classes used to condition the server-generated
                batch of the second KD stage.
            classes_so_far: classes used by ``_c_loss`` and ``_g_loss``.
            generator_server: object whose ``generator`` produces the server
                images.
            glob_iter_: the auxiliary KD terms are skipped when this equals 0.
            x_, y_: optional replayed inputs/labels; both must be given.
            importance_of_new_task: weight of the current-data critic loss.
            x_g, y_g: accepted for interface compatibility; unused.

        Returns:
            dict with ``'features'``, ``'aux_f'``, ``'aux_r'``, ``'c_loss'``,
            ``'g_loss'``, ``'kd_loss_d'`` and ``'kd_loss_g'``. The ``aux``
            entries come from the last ``_c_loss`` call (the replay call when
            replay data was supplied); the two KD entries are 0 when
            ``glob_iter_ == 0``.
        """
        assert x_ is None or x.size() == x_.size()
        assert y_ is None or y.size() == y_.size()

        # =============== update D ==============
        # 1. AC-GAN loss on the current data, optionally mixed with replayed data.
        c_loss_real, _, _, aux = self._c_loss(x, y, classes_so_far, return_g=True, return_aux=True)

        if x_ is not None and y_ is not None:
            c_loss_replay, _, _, aux = self._c_loss(x_, y_, classes_so_far, return_g=True,
                                                    return_aux=True)
            c_loss = (
                    importance_of_new_task * c_loss_real +
                    (1 - importance_of_new_task) * c_loss_replay
            )
        else:
            c_loss = c_loss_real

        batch_size = y.size(0)
        use_server = glob_iter_ != 0
        kd_loss_d = kd_loss_g = kd_loss_g_1 = kd_loss_g_2 = 0

        if use_server:
            # 2. KD loss "D": the critic's class output on the real batch is
            # pulled towards its output on server images of the same labels.
            noise, _, _ = self.generate_noise_with_classes(batch_size, classes_so_far=None,
                                                           label=y.cpu().detach().numpy())
            # These images were detached right after creation, so no autograd
            # graph is needed for the generator forward pass.
            with torch.no_grad():
                fake_server = generator_server.generator(noise.to(device))

            _, p, _ = self.critic(fake_server)
            _, p_real, _ = self.critic(x)
            kd_loss_d = self.ensemble_loss(torch.log(p_real), p)

            # 3. KD loss "G": compare the critic's output on the client's own
            # generated images with its output on server images of the same
            # (randomly drawn) labels.
            noise, aux_label, _ = self.generate_noise_with_classes(batch_size, classes_so_far=available_labels)
            with torch.no_grad():
                fake_server = generator_server.generator(noise.to(device))

            noise, aux_label, _ = self.generate_noise_with_classes(batch_size, classes_so_far=None,
                                                                   label=aux_label.cpu().detach().numpy())
            with torch.no_grad():
                fake_own = self.generator(noise.to(device))

            _, p_server, _ = self.critic(fake_server)
            _, p, _ = self.critic(fake_own)

            kd_loss_g_1 = self.ensemble_loss(torch.log(p), p_server)
            # Classification loss of the critic on the server images.
            kd_loss_g_2 = self.aux_criterion(torch.log(p_server), aux_label)

            # Unweighted sum; it is reported in the result only.
            kd_loss_g = kd_loss_g_1 + kd_loss_g_2

        c_loss = c_loss + kd_loss_d * 0.2 + kd_loss_g_1 * 0.3 + kd_loss_g_2 * 0.2  # ablation

        # Parameter update.
        self.critic_optimizer.zero_grad()
        c_loss.backward()
        self.critic_optimizer.step()

        # =============== update G ==============
        self.generator_optimizer.zero_grad()

        g_loss, _ = self._g_loss(x, y, classes_so_far)

        g_loss.backward()
        self.generator_optimizer.step()

        return {'features': aux[2], 'aux_f': aux[0].item(), 'aux_r': aux[1].item(), 'c_loss': c_loss.item(),
                'g_loss': g_loss.item(), 'kd_loss_d': kd_loss_d.item() if use_server else 0,
                'kd_loss_g': kd_loss_g.item() if use_server else 0}