import torch
from models import networks
import itertools
import numpy as np
from models.base_model import BaseModel
from util.image_pool import ImagePool
from util.radam import RAdam

class GAN3D(BaseModel):
    """Cycle GAN in 3D"""
    def __init__(self, config):
        BaseModel.__init__(self, config)
        
        self.lambda_idt = config['lambda_identity_A']
        self.lambda_pixel = config['lambda_pixel']
        self.lambda_gp = config['lambda_gp']
        self.direction = config['direction']
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D', 'G', 'G_Pixel', 'idt']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B']
        visual_names_B = ['real_B']
        if self.isTrain and self.lambda_idt > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
            visual_names_B.append('idt')
            self.loss_names.append('idt')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load Gs
            self.model_names = ['G']
        self.inChannels = [3,1]
        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G (G), D (D)
        self.netG = networks.define_G(3, 1, config['network_G'], config['norm_layer'], config['num_downs'], config['ngf'], config['upmode'], use_sigmoid=config['use_sigmoid_G'], gpu_ids=self.gpu_ids)
        if self.isTrain:  # define discriminators
            self.netD = networks.define_D(1, config['network_D'], ndf=config['ndf'], n_layers_D=config['n_layer_D'], norm_layer=config['norm_layer'], gpu_ids=self.gpu_ids)
        if self.isTrain:
            self.RaSGAN = config['RaSGAN']
            self.fake_B_pool = ImagePool(config['pool_size'])  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(config['GANLoss']).to(self.device)  # define GAN loss.
            if config['pixel_loss'] == 'L1':
                self.criterionPixel = torch.nn.L1Loss()
                self.criterionIdt = torch.nn.L1Loss()
            elif config['pixel_loss'] == 'MSE':
                self.criterionPixel = torch.nn.MSELoss()
                self.criterionIdt = torch.nn.MSELoss()
            elif config['pixel_loss'] == 'BCE':
                self.criterionPixel = torch.nn.BCELoss()
                self.criterionIdt = torch.nn.BCELoss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=config['learning_rate'], betas=config['betas'])
            # self.optimizer_G = RAdam(self.netG.parameters(), lr=config['learning_rate'], betas=config['betas'])
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=config['learning_rate'], betas=config['betas'])
            # self.optimizer_D = RAdam(self.netD.parameters(), lr=config['learning_rate'], betas=config['betas'])
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            self.setup(config)

    def set_input(self, input):
        """Unpack input data from the dataloader.

        Parameters:
            input (dict): include the data itself and its metadata information.

        """
        AtoB = self.direction == 'AtoB'
        self.real_A = input[1 if AtoB else 0].to(self.device)
        self.real_B = input[0 if AtoB else 1].to(self.device)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G_A(A)
    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Get predictions
        pred_real = netD(real)
        pred_fake = netD(fake.detach())
        if self.RaSGAN:
            self.loss_D_real = self.criterionGAN(pred_real - torch.mean(pred_fake), True)
            self.loss_D_fake = self.criterionGAN(pred_fake - torch.mean(pred_real), False)
        else:
            self.loss_D_fake = self.criterionGAN(pred_fake, False)
            self.loss_D_real = self.criterionGAN(pred_real, True)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        loss_D_fake = self.criterionGAN(pred_fake, False)
        gradient_penalty, gradients = networks.cal_gradient_penalty(netD, real, fake, self.device, lambda_gp=self.lambda_gp)
        if self.lambda_gp > 0:
            gradient_penalty.backward(retain_graph=True)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D
    def backward_D(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D = self.backward_D_basic(self.netD, self.real_B, fake_B)
    def backward_G(self):
        """Calculate the loss for generators G_A"""
        # Identity loss
        if self.lambda_idt > 0 :
            # G should be identity if real_B is fed: ||G(B) - B||
            self.idt = self.netG_A(torch.cat((self.real_B,self.real_B,self.real_B),1))
            self.loss_idt = self.criterionIdt(self.idt, self.real_B) * self.lambda_idt
        else:
            self.loss_idt = 0
        
        if self.RaSGAN:
            # Get predictions
            pred_real = self.netD(self.real_B)
            pred_fake = self.netD(self.fake_B)
            # GAN loss D(G(A)))
            self.loss_G = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) + self.criterionGAN(pred_fake - torch.mean(pred_real), True))/2
        else:
            # GAN loss D(G(A))
            self.loss_G = self.criterionGAN(self.netD(self.fake_B), True)
        self.loss_G_Pixel = self.criterionPixel(self.fake_B, self.real_B) * self.lambda_pixel
        # combined loss and calculate gradients
        self.loss_G = self.loss_G + self.loss_idt + self.loss_G_Pixel
        
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images.
        # G
        self.set_requires_grad(self.netD, False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G gradients to zero
        self.backward_G()             # calculate gradients for G
        self.optimizer_G.step()       # update G weights
        # D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()   # set D gradients to zero
        self.backward_D()      # calculate gradients for D
        self.optimizer_D.step()  # update D weights
