import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.network import *
from utils import *

class mclgan(object):
    """GAN trainer whose discriminator scores every sample with ``n_disc`` heads.

    The output of ``netD`` is read as ``(n_disc, batch)`` and transposed so
    each sample gets one score per head. Per sample the heads are ranked by
    score.

    * The top ``n_expert`` heads ("experts") receive the usual adversarial
      loss.
    * The remaining heads ("non-experts") get an extra term weighted by
      ``lambda_ne`` that involves ``nonexpert_label``.
    * Optional KL terms are built from ``softmax_with_temperature`` (a helper
      defined elsewhere), averaged over the batch. They pull the
      discriminator's head distribution on real data toward uniform, and the
      generator's head distribution toward the discriminator's.

    ``hist_real`` / ``hist_fake`` count how often each head was an expert
    for real / generated samples.

    Supported ``gan_type`` values: ``'dcgan'``, ``'lsgan'``, ``'hinge'``.
    """

    def name(self):
        return 'mclgan'

    def __init__(self, device, args):
        """Copy hyper-parameters from ``args`` and build all training state.

        Args:
            device: torch device that networks and tensors are moved to.
            args: namespace with the attributes read below, plus ``z_prior``
                and ``kld_decay`` (read later) and whatever ``Generator`` /
                ``Discriminator`` need.
        """
        self.device = device
        self.n_disc = args.n_disc
        self.n_expert = args.n_expert
        self.gan_type = args.gan_type
        self.g_batch = args.g_batch_size
        self.d_batch = args.d_batch_size
        self.fixed_z_batch = args.fixed_z_batch_size
        self.nz = args.nz
        self.d_lr = args.d_learning_rate
        self.g_lr = args.g_learning_rate
        self.lr_gamma = args.lr_gamma
        self.d_wd = args.d_weight_decay
        self.g_wd = args.g_weight_decay
        self.beta1 = args.beta1
        self.beta2 = args.beta2
        self.g_lambda_kld = args.g_lambda_kld
        self.d_lambda_kld = args.d_lambda_kld
        self.lambda_ne = args.lambda_ne
        self.temperature = args.temperature
        self.nonexpert_label = args.nonexpert_label
        self.lambda_l1 = args.lambda_l1
        self.args = args
        self.initialize()

    def initialize(self):
        """Build networks, optimizers, criterion, fixed latents and buffers.

        Raises:
            ValueError: if ``gan_type`` is not one of the supported types.
        """
        # Fail fast: with an unknown type the adversarial losses would silently
        # stay at constant zero.
        if self.gan_type == 'dcgan':
            self.criterion_adv = nn.BCEWithLogitsLoss(reduction='mean')
        elif self.gan_type == 'lsgan':
            self.criterion_adv = nn.MSELoss(reduction='mean')
        elif self.gan_type != 'hinge':
            raise ValueError('unsupported gan_type: {!r}'.format(self.gan_type))

        self.optimizers = []
        # Generator
        self.netG = Generator(self.args).to(self.device)
        self.optimizer_G = torch.optim.Adam(
            self.netG.parameters(), lr=self.g_lr,
            betas=(self.beta1, self.beta2), weight_decay=self.g_wd)
        self.optimizers.append(self.optimizer_G)

        # Discriminator: a single network whose output holds all n_disc heads.
        self.netD = Discriminator(self.args).to(self.device)
        self.optimizer_D = torch.optim.Adam(
            self.netD.parameters(), lr=self.d_lr,
            betas=(self.beta1, self.beta2), weight_decay=self.d_wd)
        self.optimizers.append(self.optimizer_D)

        # Fixed latent batch rendered by get_visuals().
        self.fixed_z = self._sample_z(self.fixed_z_batch)

        # Per-head expert counters, updated by accum_gen_stats().
        self.hist_real = np.zeros(self.n_disc)
        self.hist_fake = np.zeros(self.n_disc)

        # Constant targets sized from the configured batch sizes. The loss code
        # now builds targets from the prediction shapes (so a short final batch
        # works); these attributes are kept so existing accesses still work.
        self.d_real_label = torch.ones(self.d_batch * self.n_disc, device=self.device)
        self.d_fake_label = torch.zeros(self.d_batch * (self.n_disc - self.n_expert), device=self.device)
        self.g_real_label = torch.ones(self.g_batch * self.n_expert, device=self.device)
        self.g_fake_label = torch.zeros(self.g_batch * self.n_disc, device=self.device)
        # Uniform distribution over heads: target of the discriminator KL term.
        self.p_uniform = torch.ones(1, self.n_disc, device=self.device) / self.n_disc
        self.epoch = 0

    def _sample_z(self, n):
        """Draw ``n`` latent vectors of size ``nz`` from the configured prior.

        ``args.z_prior == 'n'`` selects a standard normal; anything else a
        uniform distribution on [-1, 1). Sampling happens on the default
        generator and is then moved, preserving the original RNG stream.
        """
        if self.args.z_prior == 'n':
            return torch.randn(n, self.nz).to(self.device)
        return torch.rand(n, self.nz).to(self.device) * 2.0 - 1.0

    def feed_data(self, data, label):
        """Store a real batch on the device and draw a fresh latent batch.

        Args:
            data: real samples; dimension 0 is the batch size.
            label: unused; accepted for interface compatibility.
        """
        self.real = data.to(self.device)
        # `!=`, not `is not`: integer identity is a CPython caching detail.
        if data.size(0) != self.d_batch:
            self.d_batch = data.size(0)
        self.z = self._sample_z(self.g_batch)

    def train(self):
        self.netG.train()
        self.netD.train()

    def eval(self):
        self.netG.eval()
        self.netD.eval()

    def forward_G(self):
        """Generate ``self.fake`` from the current latent batch ``self.z``."""
        self.fake = self.netG(self.z)

    def forward_D(self, update_D):
        """Score samples with the discriminator.

        Args:
            update_D: if True, score the real batch and the *detached* fakes
                (the D loss does not reach netG). If False, score only the
                fakes, with gradients flowing into netG.
        """
        if update_D:
            self.pred_real = self.netD(self.real)
            self.pred_fake = self.netD(self.fake.detach())
        else:
            self.pred_fake = self.netD(self.fake)

    def backward_G(self):
        """Compute the generator loss, record expert stats and backprop.

        Expects ``pred_fake`` from ``forward_D(update_D=False)``. When
        ``g_lambda_kld > 0`` it also needs ``p_disc`` from a prior
        ``backward_D``.

        Side effects:
            * replaces ``pred_fake`` with its per-sample sorted scores;
            * sets ``p_gen`` (when the KL weight is positive) and ``expert_id``;
            * sets ``loss_G`` = (total, expert, non-expert, kld) as floats;
            * updates ``hist_fake`` and accumulates gradients.
        """
        expert_loss = torch.zeros(1, device=self.device)
        nonexpert_loss = torch.zeros(1, device=self.device)
        kld_loss = torch.zeros(1, device=self.device)

        # netD output read as (n_disc, batch), transposed to (batch, n_disc).
        scores = self.pred_fake.view(self.n_disc, -1).permute(1, 0).contiguous()
        # Rank heads per sample; the first n_expert columns are the experts.
        self.pred_fake, sorted_idx = torch.sort(scores, dim=1, descending=True)
        expert = self.pred_fake[:, :self.n_expert]
        nonexpert = self.pred_fake[:, self.n_expert:]

        if self.gan_type in ('dcgan', 'lsgan'):
            expert_flat = expert.reshape(-1)
            expert_loss = self.criterion_adv(expert_flat, torch.ones_like(expert_flat))
        elif self.gan_type == 'hinge':
            expert_loss = -expert.mean()

        if self.lambda_ne > 0.0 and self.n_expert < self.n_disc:
            if self.gan_type in ('dcgan', 'lsgan'):
                nonexpert_flat = nonexpert.reshape(-1)
                nonexpert_loss = self.criterion_adv(
                    nonexpert_flat, torch.full_like(nonexpert_flat, self.nonexpert_label))
            elif self.gan_type == 'hinge':
                nonexpert_loss = -nonexpert.mean() * self.nonexpert_label

        if self.g_lambda_kld > 0.0:
            # Pull the generator's head distribution toward the (frozen) one
            # observed on real data.
            self.p_gen = softmax_with_temperature(scores, self.temperature).mean(0).unsqueeze(0)
            kld_loss = F.kl_div(self.p_gen.log(), self.p_disc.detach(), reduction='batchmean')

        loss = expert_loss + self.lambda_ne * nonexpert_loss + self.g_lambda_kld * kld_loss

        self.expert_id = sorted_idx[:, :self.n_expert].reshape(-1)
        self.accum_gen_stats(is_real=False)

        self.loss_G = (loss.item(), expert_loss.item(), nonexpert_loss.item(), kld_loss.item())

        loss.backward()

    def backward_D(self):
        """Compute the discriminator loss, record expert stats and backprop.

        Expects ``pred_real`` / ``pred_fake`` from ``forward_D(update_D=True)``.

        Side effects:
            * replaces ``pred_real`` with its per-sample sorted scores;
            * sets ``p_disc`` (when either KL weight is positive) and
              ``expert_id``;
            * sets ``loss_D`` = (total, real, non-expert, fake, kld) as floats;
            * updates ``hist_real`` and accumulates gradients.
        """
        real_loss = torch.zeros(1, device=self.device)
        x_fake_loss = torch.zeros(1, device=self.device)
        g_fake_loss = torch.zeros(1, device=self.device)
        kld_loss = torch.zeros(1, device=self.device)
        l1_loss = torch.zeros(1, device=self.device)

        # netD output read as (n_disc, batch), transposed to (batch, n_disc).
        scores = self.pred_real.view(self.n_disc, -1).permute(1, 0).contiguous()
        self.pred_real, sorted_idx = torch.sort(scores, dim=1, descending=True)
        expert = self.pred_real[:, :self.n_expert]
        nonexpert = self.pred_real[:, self.n_expert:]

        if self.gan_type in ('dcgan', 'lsgan'):
            # Targets follow the prediction shapes, so a batch whose size
            # differs from d_batch_size cannot cause a size mismatch.
            expert_flat = expert.reshape(-1)
            fake_flat = self.pred_fake.reshape(-1)
            real_loss = self.criterion_adv(expert_flat, torch.ones_like(expert_flat))
            g_fake_loss = self.criterion_adv(fake_flat, torch.zeros_like(fake_flat))
        elif self.gan_type == 'hinge':
            real_loss = F.relu(1.0 - expert).mean()
            g_fake_loss = F.relu(1 + self.pred_fake).mean()

        # Non-expert heads on real samples: target nonexpert_label for
        # dcgan/lsgan; real-side hinge weighted by nonexpert_label for hinge.
        if self.lambda_ne > 0.0 and self.n_expert < self.n_disc:
            if self.gan_type in ('dcgan', 'lsgan'):
                nonexpert_flat = nonexpert.reshape(-1)
                x_fake_loss = self.criterion_adv(
                    nonexpert_flat, torch.full_like(nonexpert_flat, self.nonexpert_label))
            elif self.gan_type == 'hinge':
                x_fake_loss = F.relu(1.0 - nonexpert).mean() * self.nonexpert_label

        # p_disc is also the (detached) target of the generator's KL term, so
        # it is needed whenever either KL weight is active.
        if self.d_lambda_kld > 0.0 or self.g_lambda_kld > 0.0:
            self.p_disc = softmax_with_temperature(scores, self.temperature).mean(0).unsqueeze(0)
        if self.d_lambda_kld > 0.0:
            # NOTE(review): the L1 penalty on raw scores is only applied while
            # the D-KL term is enabled; confirm this coupling is intended.
            if self.lambda_l1 > 0.0:
                l1_loss = torch.norm(scores, p=1)
            kld_loss = F.kl_div(self.p_disc.log(), self.p_uniform, reduction='batchmean')

        self.expert_id = sorted_idx[:, :self.n_expert].reshape(-1)
        self.accum_gen_stats(is_real=True)

        loss = (real_loss + self.lambda_ne * x_fake_loss + self.d_lambda_kld * kld_loss
                + g_fake_loss + self.lambda_l1 * l1_loss)
        loss.backward()
        self.loss_D = (loss.item(), real_loss.item(), x_fake_loss.item(),
                       g_fake_loss.item(), kld_loss.item())

    def accum_gen_stats(self, is_real=True):
        """Add one count per entry of ``expert_id`` to a per-head histogram.

        Args:
            is_real: update ``hist_real`` if True, otherwise ``hist_fake``.
        """
        hist = self.hist_real if is_real else self.hist_fake
        # A single bincount plus one device->host copy. A per-element Python
        # loop would convert every index individually (a sync per element on
        # GPU).
        counts = torch.bincount(self.expert_id.reshape(-1), minlength=self.n_disc)
        hist += counts.cpu().numpy()

    def optimize_parameters(self):
        """Run one discriminator step, then one generator step.

        Call ``feed_data`` first. The same ``self.fake`` is used for both
        steps.
        """
        # update discriminator
        self.forward_G()
        self.netD.zero_grad()
        self.netG.zero_grad()
        self.forward_D(update_D=True)
        self.backward_D()
        self.optimizer_D.step()

        # update generator (D grads produced here are cleared before the
        # next D step)
        self.netG.zero_grad()
        self.netD.zero_grad()
        self.forward_D(update_D=False)
        self.backward_G()
        self.optimizer_G.step()

    def update_learning_rate(self):
        """Multiply every optimizer's learning rate by ``lr_gamma`` if it is > 0."""
        if self.lr_gamma > 0.0:
            for optimizer in self.optimizers:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= self.lr_gamma

    def update_kld(self):
        """Decay the discriminator KL weight (only) by ``args.kld_decay``."""
        self.d_lambda_kld = self.d_lambda_kld * self.args.kld_decay

    def update_epoch(self):
        """Advance the epoch counter; reset histograms when epoch % 10 == 1."""
        self.epoch += 1
        if self.epoch % 10 == 1:
            self.hist_real = np.zeros(self.n_disc)
            self.hist_fake = np.zeros(self.n_disc)

    def save_model(self, epoch, save_dir):
        """Save netD / netG state dicts as ``D_<epoch>.ckpt`` / ``G_<epoch>.ckpt``."""
        torch.save(self.netD.state_dict(), '{}/D_{}.ckpt'.format(save_dir, epoch))
        torch.save(self.netG.state_dict(), '{}/G_{}.ckpt'.format(save_dir, epoch))

    def get_visuals(self):
        """Generate samples from ``fixed_z`` without building an autograd graph."""
        with torch.no_grad():
            return self.netG(self.fixed_z)

    def get_loss(self):
        """Return ``(loss_D, loss_G)`` as recorded by the last backward passes."""
        return self.loss_D, self.loss_G
