import torch
from .base_model import BaseModel
import itertools
from . import networks
import torch.nn.functional as F
from pytorch_wavelets import DWTForward, DWTInverse
from util.kd_utils import *
import torchvision

class Pix2PixModel(BaseModel):
    """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.

    The model training requires '--dataset_mode aligned' dataset.
    By default, it uses a '--netG unet256' U-Net generator,
    a '--netD basic' discriminator (PatchGAN),
    and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).

    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For pix2pix, we do not use image buffer
        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
        """
        # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt, teacher=None):
        """Initialize the pix2pix class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        if teacher is not None:
            self.teacher = teacher
        self.opt = opt
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        if opt.distill:
            self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake', 'kd']
            self.visual_names = ['real_A', 'fake_B', 'real_B', "teacher_fakeB"]
        else:
            self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
            self.visual_names = ['real_A', 'fake_B', 'real_B' ,'fake_B2']
        if opt.distill and opt.kd_wavelet_distillation !=0:
            self.xfm = DWTForward(J=3, mode='zero', wave='db3').cuda()

        if opt.distill:
            student_channel_size = [3, self.opt.ngf, self.opt.ngf*2, self.opt.ngf*4, self.opt.ngf*8, self.opt.ngf*8, self.opt.ngf*8, self.opt.ngf*8,
                                    self.opt.ngf*16, self.opt.ngf*16, self.opt.ngf*16, self.opt.ngf*16, self.opt.ngf*8, self.opt.ngf*4, self.opt.ngf*2, self.opt.ngf, 3]
            teacher_channel_size = [3, self.teacher.opt.ngf, self.teacher.opt.ngf*2, self.teacher.opt.ngf*4, self.teacher.opt.ngf*8, self.teacher.opt.ngf*8, self.teacher.opt.ngf*8, self.teacher.opt.ngf*8,
                                    self.teacher.opt.ngf*16, self.teacher.opt.ngf*16, self.teacher.opt.ngf*16, self.teacher.opt.ngf*16, self.teacher.opt.ngf*8, self.teacher.opt.ngf*4, self.teacher.opt.ngf*2, self.teacher.opt.ngf, 3]
        if opt.distill and opt.kd_feature_distillation != 0:
            layer_list = []
            for index in self.opt.choice_of_feature:
                layer_list.append(torch.nn.Conv2d(student_channel_size[index], teacher_channel_size[index], 1))
            self.feature_distillation_adaptation_layer = torch.nn.ModuleList(layer_list).cuda()

        if opt.distill and opt.kd_non_local_distillation:
            student_non_local_list = []
            teacher_non_local_list = []
            non_local_adaptation_list = []
            for index in self.opt.choice_of_feature:
                student_non_local_list.append(NonLocalBlockND(in_channels=student_channel_size[index]))
                teacher_non_local_list.append(NonLocalBlockND(in_channels=teacher_channel_size[index]))
                non_local_adaptation_list.append(torch.nn.Conv2d(student_channel_size[index], teacher_channel_size[index], 1))
                self.student_non_local = nn.ModuleList(student_non_local_list).cuda()
                self.teacher_non_local = nn.ModuleList(teacher_non_local_list).cuda()
                self.non_local_adaptation = nn.ModuleList(non_local_adaptation_list).cuda()

        if opt.distill and opt.kd_channel_attention_distillation:
            layer_list = []
            for index in self.opt.choice_of_feature:
                layer_list.append(torch.nn.Linear(student_channel_size[index], teacher_channel_size[index]))
            self.channel_attention_adaptation = torch.nn.ModuleList(layer_list).cuda()

        if opt.distill and opt.kd_perceptual_distillation:
            self.vgg = torchvision.models.vgg19_bn(pretrained=True).cuda()
            self.vgg.eval()

        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>

        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf * 4, 'unet_64', opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)


        if self.isTrain:
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionL2 = torch.nn.MSELoss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            g_param_list = [self.netG.parameters(), self.netG2.parameters()]
            d_param_list = [self.netD.parameters(), self.netD2.parameters()]

            if self.opt.distill and self.opt.kd_feature_distillation:
                g_param_list.append(self.feature_distillation_adaptation_layer.parameters())
            if self.opt.distill and self.opt.kd_channel_attention_distillation:
                g_param_list.append(self.channel_attention_adaptation.parameters())
            if self.opt.distill and self.opt.kd_non_local_distillation:
                g_param_list.append(self.non_local_adaptation.parameters())

            self.optimizer_G = torch.optim.Adam(itertools.chain(*g_param_list), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(*d_param_list), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

        if self.opt.distill:
            AtoB = self.teacher.opt.direction == 'AtoB'
            self.teacher.real_A = input['A' if AtoB else 'B'].to(self.device)
            self.teacher.real_B = input['B' if AtoB else 'A'].to(self.device)
            self.teacher.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)
        self.fake_B2 = self.netG2(self.fake_B.detach())  # G(G(A))
        if self.evaluate_teacher:
            self.fake_B = self.fake_B2

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        fake_AB2 = torch.cat((self.real_A, self.fake_B2), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake2 = self.netD2(fake_AB2.detach())
        self.loss_D_fake2 = self.criterionGAN(pred_fake2, False)
        # Real
        real_AB2 = torch.cat((self.real_A, self.real_B), 1)
        pred_real2 = self.netD2(real_AB2)
        self.loss_D_real2 = self.criterionGAN(pred_real2, True)
        # combine loss and calculate gradients
        self.loss_D += (self.loss_D_fake2 + self.loss_D_real2) * 0.5
        self.loss_D.backward()


    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        fake_AB2 = torch.cat((self.real_A, self.fake_B2), 1)
        pred_fake2 = self.netD2(fake_AB2)
        self.loss_G_GAN2 = self.criterionGAN(pred_fake2, True)
        # Second, G(A) = B
        self.loss_G_L12 = self.criterionL1(self.fake_B2, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G += self.loss_G_GAN2 + self.loss_G_L12 
        self.loss_G += self.opt.kd_self * self.criterionL1(self.fake_B2.detach(), self.fake_B)

        if self.opt.distill:
            self.loss_G += self.loss_kd
            #   self.loss_G += self.loss_feat_kd
            #   self.loss_G += self.loss_spkd
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()                   # compute fake images: G(A)
        if self.opt.distill:
            with torch.no_grad():
                self.teacher.forward()
            self.loss_kd = 0.0
            self.teacher_fakeB = self.teacher.fake_B
            student_feature = [self.netG.module.feature_buffer[index] for index in self.opt.choice_of_feature]
            teacher_feature = [self.teacher.netG.module.feature_buffer[index] for index in self.opt.choice_of_feature]
            #   prediction kd
            if self.opt.kd_prediction_distillation != 0:
                self.loss_kd += torch.nn.functional.l1_loss(self.fake_B, self.teacher.fake_B) * self.opt.kd_prediction_distillation

            #   frequency kd
            if self.opt.kd_wavelet_distillation != 0:
                student_l, student_h = self.xfm(self.fake_B)
                teacher_l, teacher_h = self.xfm(self.teacher.fake_B)
                for index in range(len(student_h)):
                    self.loss_kd += torch.nn.functional.l1_loss(teacher_h[index], student_h[index]) * self.opt.kd_wavelet_distillation

            #   feature kd
            if self.opt.kd_feature_distillation != 0:
                for index in range(len(student_feature)):
                    self.loss_kd += torch.nn.functional.mse_loss(
                        self.feature_distillation_adaptation_layer[index](student_feature[index]),
                        teacher_feature[index]) * self.opt.kd_feature_distillation

            #   similarity kd
            if self.opt.kd_similarity_distillation != 0:
                for index in range(len(student_feature)):
                    student_feat = student_feature[index]
                    student_feat = student_feat.view(student_feat.size(0), student_feat.size(1), -1)  # 1 x c x wh
                    teacher_feat = teacher_feature[index]
                    student_feat_transpose = torch.transpose(student_feat, 1, 2)
                    student_relation = F.normalize(torch.bmm(student_feat_transpose, student_feat), dim=2)
                    teacher_feat = teacher_feat.view(teacher_feat.size(0), teacher_feat.size(1), -1)
                    teacher_feat_transpose = torch.transpose(teacher_feat, 1, 2)
                    teacher_relation = F.normalize(torch.bmm(teacher_feat_transpose, teacher_feat), dim=2)
                    self.loss_kd += torch.nn.functional.mse_loss(student_relation, teacher_relation) * self.opt.kd_similarity_distillation

            #   non local kd
            if self.opt.kd_non_local_distillation != 0:
                for index in range(len(student_feature)):
                    student_non_local_results = self.student_non_local[index](student_feature[index])
                    teacher_non_local_results = self.teacher_non_local[index](teacher_feature[index])
                    self.loss_kd += torch.nn.functional.mse_loss(
                        self.non_local_adaptation[index](student_non_local_results),
                        teacher_non_local_results) * self.opt.kd_non_local_distillation

            #   perceptual kd
            if self.opt.kd_perceptual_distillation != 0:
                student_vgg_feature = self.vgg.features(self.fake_B)
                teacher_vgg_feature = self.vgg.features(self.teacher.fake_B)
                self.loss_kd += torch.nn.functional.mse_loss(student_vgg_feature, teacher_vgg_feature) * self.opt.kd_perceptual_distillation

            #   channel attention
            if self.opt.kd_channel_attention_distillation != 0:
                for index in range(len(student_feature)):
                    student_channel_attention = torch.mean(student_feature[index], [2, 3], keepdim=False)   # b x c tensor
                    teacher_channel_attention = torch.mean(teacher_feature[index], [2, 3], keepdim=False)  # b x c tensor
                    self.loss_kd += torch.nn.functional.mse_loss(self.channel_attention_adaptation[index](student_channel_attention),
                                                                 teacher_channel_attention) * self.opt.kd_channel_attention_distillation

            #   spatial attention
            if self.opt.kd_spatial_attention_distillation != 0:
                for index in range(len(student_feature)):
                    student_spatial_attention = torch.mean(student_feature[index], [1], keepdim=False)   # b x 1 x w x h tensor
                    teacher_spatial_attention = torch.mean(teacher_feature[index], [1], keepdim=False)  # b x 1 x w x h tensor
                    self.loss_kd += torch.nn.functional.mse_loss(student_spatial_attention, teacher_spatial_attention) * self.opt.kd_spatial_attention_distillation


        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.set_requires_grad(self.netD2, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.set_requires_grad(self.netD2, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate graidents for G
        self.optimizer_G.step()             # udpate G's weights


# NOTE: the module-level string literal below is not executed as code; it only
# preserves an earlier draft of the distillation losses. It refers to attributes
# that Pix2PixModel.__init__ does not create (e.g. self.non_local_adaptation_layers,
# self.adaptation_layers, self.spatial_wise_adaptation, self.channel_wise_adaptation),
# so it cannot be pasted back into the class as-is.
'''
self.loss_kd +=  self.criterionL1(self.fake_B, self.teacher.fake_B) * self.opt.kd1
            self.teacher_fakeB = self.teacher.fake_B
            self.loss_feat_kd = 0.0
            self.loss_feat_kd *= self.opt.kd2
            self.loss_spkd = 0.0
            self.loss_non_local = 0.0
            for index in range(7):
                student_feat = self.netG.module.feature_buffer[index]
                teacher_feat = self.teacher.netG.module.feature_buffer[index]
                student_relation = self.student_non_local[index](student_feat)
                teacher_relation = self.teacher_non_local[index](teacher_feat)
                self.loss_non_local += self.criterionL1(self.non_local_adaptation_layers[index](student_relation), teacher_relation)
            self.loss_non_local *= self.opt.kd2
            self.loss_attn_mask = 0.0
            for index in range(7):
                student_feat = self.netG.module.feature_buffer[index]
                teacher_feat = self.teacher.netG.module.feature_buffer[index]
                t_attention_mask = torch.mean(torch.abs(teacher_feat), [1], keepdim=True)
                size = t_attention_mask.size()
                t_attention_mask = t_attention_mask.view(teacher_feat.size(0), -1)
                t_attention_mask = torch.softmax(t_attention_mask, dim=1) * size[-1] * size[-2]
                t_attention_mask = t_attention_mask.view(size)
                s_attention_mask = torch.mean(torch.abs(student_feat), [1], keepdim=True)
                size = s_attention_mask.size()
                s_attention_mask = s_attention_mask.view(student_feat.size(0), -1)
                s_attention_mask = torch.softmax(s_attention_mask, dim=1) * size[-1] * size[-2]
                s_attention_mask = s_attention_mask.view(size)
                sum_attention_mask = (t_attention_mask + s_attention_mask) / 2
                sum_attention_mask = sum_attention_mask.detach()
                self.loss_attn_mask += dist2(teacher_feat, self.adaptation_layers[index](student_feat), attention_mask=sum_attention_mask)

            self.loss_attn_mask *= self.opt.kd3
            self.loss_attn = 0.0
            for index in range(7):
                student_feat = self.netG.module.feature_buffer[index]
                teacher_feat = self.teacher.netG.module.feature_buffer[index]
                student_spatial_attention = torch.mean(student_feat, [1], keepdim=True)
                teacher_spatial_attention = torch.mean(teacher_feat, [1], keepdim=True)
                student_channel_attention = torch.mean(student_feat, [2, 3])
                teacher_channel_attention = torch.mean(teacher_feat, [2, 3])
                self.loss_attn += torch.dist(self.spatial_wise_adaptation[index](student_spatial_attention), teacher_spatial_attention, p=2)
                self.loss_attn += torch.dist(self.channel_wise_adaptation[index](student_channel_attention), teacher_channel_attention, p=2)

        self.loss_attn *= self.opt.kd4
'''