import os

import torch
import torch.autograd
import torch.nn as nn
import torch.optim as optim

import kpn.utils as kpn_utils

from .networks import Generator, EdgeGenerator, Discriminator
from .loss import AdversarialLoss, PerceptualLoss, StyleLoss


class BaseModel(nn.Module):
    """Common bookkeeping and checkpoint I/O for the GAN models in this module.

    Subclasses are expected to register ``generator`` and ``discriminator``
    sub-modules; ``load`` and ``save`` operate on those two attributes.
    """

    def __init__(self, name, config, logger):
        """
        Args:
            name: model name, used in log messages and checkpoint file names.
            config: configuration object; this class reads ``DEVICE``,
                ``MODE`` and ``CHECKPOINT_DEST`` from it.
            logger: any object exposing a ``log(str)`` method.
        """
        super(BaseModel, self).__init__()

        self.name = name
        self.config = config
        self.logger = logger
        # step counter; persisted in / restored from the generator checkpoint
        self.iteration = 0

    @staticmethod
    def _unwrap(module):
        """Return the inner network of a (Distributed)DataParallel wrapper, else *module*."""
        if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return module.module
        return module

    def load(self, gen_ckpt, dis_ckpt, amend_loading=False):
        """Restore generator (and, when training, discriminator) weights.

        Args:
            gen_ckpt: path to a generator checkpoint as written by ``save``
                (a dict with ``'generator'`` and ``'iteration'`` entries).
            dis_ckpt: path to a discriminator checkpoint (a dict with a
                ``'discriminator'`` entry). Only read when
                ``config.MODE == 'train'`` and the file exists; may be None.
            amend_loading: if True, the stored ``branch_1.decoder_2.1``
                weight and bias are doubled along dim 0 before loading: the
                weight gets extra rows drawn from N(0, 0.02) and the bias
                extra zeros (presumably to fit a checkpoint from a narrower
                layer into the current generator — confirm with the
                ``Generator`` definition).

        Raises:
            AssertionError: if ``gen_ckpt`` does not exist.
        """
        assert os.path.exists(gen_ckpt), \
            'Provided MISF generator checkpoint does not exist.'
        self.logger.log('Loading %s generator...' % self.name)

        device = self.config.DEVICE
        data = torch.load(gen_ckpt, map_location=device)

        generator_ckpt = data['generator']
        if amend_loading:
            weight_key = 'branch_1.decoder_2.1.weight'
            bias_key = 'branch_1.decoder_2.1.bias'

            # amend weights: append an equally-sized, randomly initialised block
            weight = generator_ckpt[weight_key]
            extra_weight = torch.normal(mean=0.0, std=0.02, size=weight.size()).to(device)
            generator_ckpt[weight_key] = torch.cat([weight, extra_weight], dim=0)

            # amend bias: append an equally-sized block of zeros
            bias = generator_ckpt[bias_key]
            generator_ckpt[bias_key] = torch.cat(
                [bias, torch.zeros(bias.size(), device=device)]
            )

        # ``save`` strips DataParallel wrappers, so load into the bare network
        # to keep the keys matching under strict=True.
        self._unwrap(self.generator).load_state_dict(generator_ckpt, strict=True)
        self.iteration = data['iteration']

        # discriminator is utilized only for training
        if self.config.MODE == 'train' and dis_ckpt and os.path.exists(dis_ckpt):
            self.logger.log('Loading %s discriminator...' % self.name)

            data = torch.load(dis_ckpt, map_location=device)
            self._unwrap(self.discriminator).load_state_dict(data['discriminator'])

    def save(self):
        """Write generator and discriminator checkpoints to ``config.CHECKPOINT_DEST``.

        Files are ``{iteration}_{name}_gen.pth`` (dict with ``'iteration'`` and
        ``'generator'``) and ``{iteration}_{name}_dis.pth`` (dict with
        ``'discriminator'``). DataParallel wrappers are stripped, so stored
        keys carry no ``module.`` prefix.
        """
        generator = self._unwrap(self.generator)
        discriminator = self._unwrap(self.discriminator)

        # Decide by the actual module type, not config.GPU: a model whose
        # networks were never wrapped has no ``.module`` attribute.
        if generator is not self.generator or discriminator is not self.discriminator:
            self.logger.log('save...multiple GPU')
        else:
            self.logger.log('save...single GPU')

        torch.save({
            'iteration': self.iteration,
            'generator': generator.state_dict()
        }, os.path.join(self.config.CHECKPOINT_DEST, '{}_{}_gen.pth'.format(self.iteration, self.name)))

        torch.save({
            'discriminator': discriminator.state_dict()
        }, os.path.join(self.config.CHECKPOINT_DEST, '{}_{}_dis.pth'.format(self.iteration, self.name)))

        self.logger.log('\nsaving %s...\n' % self.name)


class EdgeModel(BaseModel):
    """Edge-completion GAN.

    Generator input: [grayscale(1) + edge(1) + mask(1)].
    Discriminator input: [grayscale(1) + edge(1)].
    """

    def __init__(self, config, logger=None):
        """
        Args:
            config: configuration object; reads GPU, GAN_LOSS, LR, D2G_LR,
                BETA1 and BETA2 here, and FM_LOSS_WEIGHT in ``process``.
            logger: object with a ``log(str)`` method used by ``load``/``save``.
                Optional so existing ``EdgeModel(config)`` calls keep working;
                falls back to printing to stdout.
        """
        if logger is None:
            import types
            logger = types.SimpleNamespace(log=print)
        super(EdgeModel, self).__init__('EdgeModel', config, logger)

        # generator input: [grayscale(1) + edge(1) + mask(1)]
        # discriminator input: (grayscale(1) + edge(1))
        generator = EdgeGenerator(use_spectral_norm=True)
        discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')
        if isinstance(self.config.GPU, list):
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)
        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    def process(self, images, edges, masks):
        """Run one forward pass and compute generator/discriminator losses.

        Zeroes both optimizers and increments ``self.iteration``; the caller
        is expected to pass the returned losses to ``backward``.

        Returns:
            (outputs, gen_loss, dis_loss, logs) where ``logs`` is a list of
            (name, float) pairs.
        """
        self.iteration += 1

        # zero optimizers
        self.gen_optimizer.zero_grad()
        self.dis_optimizer.zero_grad()

        # process outputs
        outputs = self(images, edges, masks)
        gen_loss = 0
        dis_loss = 0

        # discriminator loss (fake input detached so D's loss does not reach G)
        dis_input_real = torch.cat((images, edges), dim=1)
        dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
        dis_real, dis_real_feat = self.discriminator(dis_input_real)        # in: (grayscale(1) + edge(1))
        dis_fake, dis_fake_feat = self.discriminator(dis_input_fake)        # in: (grayscale(1) + edge(1))
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss
        gen_input_fake = torch.cat((images, outputs), dim=1)
        gen_fake, gen_fake_feat = self.discriminator(gen_input_fake)        # in: (grayscale(1) + edge(1))
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
        gen_loss += gen_gan_loss

        # generator feature matching loss: L1 between D features of fake and real
        gen_fm_loss = 0
        for fake_feat, real_feat in zip(gen_fake_feat, dis_real_feat):
            gen_fm_loss += self.l1_loss(fake_feat, real_feat.detach())
        gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
        gen_loss += gen_fm_loss

        # create logs
        logs = [
            ("l_d1", dis_loss.item()),
            ("l_g1", gen_gan_loss.item()),
            ("l_fm", gen_fm_loss.item()),
            ("l_gen", gen_loss.item()),
        ]

        return outputs, gen_loss, dis_loss, logs

    def forward(self, images, edges, masks):
        """Predict edges from masked inputs.

        Where ``masks`` is 1, ``edges`` is zeroed and ``images`` is set to 1
        (assuming binary masks); the three tensors are then concatenated on
        dim 1 and fed to the generator.
        """
        edges_masked = (edges * (1 - masks))
        images_masked = (images * (1 - masks)) + masks
        inputs = torch.cat((images_masked, edges_masked, masks), dim=1)
        outputs = self.generator(inputs)                                    # in: [grayscale(1) + edge(1) + mask(1)]
        return outputs

    def backward(self, gen_loss=None, dis_loss=None):
        """Back-propagate the given losses and step both optimizers."""
        if gen_loss is not None:
            gen_loss.backward()
        self.gen_optimizer.step()

        # gen_loss back-propagates through the discriminator (gen_fake), which
        # leaves generator-objective gradients on D's weights; clear them so
        # the discriminator step follows dis_loss only. The generator is
        # stepped first because dis_loss only uses detached generator outputs.
        self.dis_optimizer.zero_grad()
        if dis_loss is not None:
            dis_loss.backward()
        self.dis_optimizer.step()


class InpaintingModel(BaseModel):
    """Two-branch inpainting GAN with optional confidence prediction.

    ``forward`` feeds two (masked image, mask) pairs through a single
    ``Generator`` call; the discriminator judges RGB images.
    """

    def __init__(self, config, logger):
        super(InpaintingModel, self).__init__('InpaintingModel', config, logger)

        # generator input per branch: [masked rgb(3) + mask] (see forward)
        # discriminator input: [rgb(3)]
        generator = Generator(config=config, logger=logger)
        discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')

        l1_loss = nn.L1Loss()
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(
            params=generator.parameters(),
            lr=float(config.LR),
            betas=(config.BETA1, config.BETA2)
        )

        self.dis_optimizer = optim.Adam(
            params=discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(config.BETA1, config.BETA2)
        )

    def get_loss_with_diff(self, images, diff_output, misf_output, masks):
        """Compute adversarial losses on ``misf_output`` and reconstruction
        losses (L1 / perceptual / style) on ``diff_output``.

        Returns:
            (misf_gen_loss, diff_gen_loss, dis_loss, logs) with ``logs`` a
            list of formatted strings.
        """
        # NOTE(review): both the adversarial input and every diff_output term
        # are detached, so none of these generator losses can send gradients
        # to the networks that produced them — confirm this is intended.
        dis_loss = 0
        misf_gen_loss = 0
        diff_gen_loss = 0

        # discriminator loss
        dis_input_real = images
        dis_input_fake = misf_output.detach()
        dis_real, _ = self.discriminator(dis_input_real)  # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)  # in: [rgb(3)]
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss
        gen_input_fake = misf_output.detach()
        gen_fake, _ = self.discriminator(gen_input_fake)  # in: [rgb(3)]
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
        misf_gen_loss += gen_gan_loss

        # generator l1 loss, normalised by the mean mask value
        gen_l1_loss = self.l1_loss(diff_output.detach(), images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
        diff_gen_loss += gen_l1_loss

        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(diff_output.detach(), images)
        gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
        diff_gen_loss += gen_content_loss

        # generator style loss, restricted to masked pixels
        gen_style_loss = self.style_loss(diff_output.detach() * masks, images * masks)
        gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
        diff_gen_loss += gen_style_loss

        # create logs
        logs = [
            "d:{:.4f}".format(dis_loss.item()),
            "g:{:.4f}".format(gen_gan_loss.item()),
            "l1:{:.4f}".format(gen_l1_loss.item()),
            "per:{:.4f}".format(gen_content_loss.item()),
            "sty:{:.4f}".format(gen_style_loss.item())
        ]

        return misf_gen_loss, diff_gen_loss, dis_loss, logs

    def get_wo_gan_loss(self, images, outputs, masks):
        """Reconstruction-only (no adversarial term) generator loss.

        Returns:
            (gen_loss, logs) with ``logs`` a list of formatted strings.
        """
        gen_loss = 0

        # generator l1 loss, normalised by the mean mask value
        gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
        gen_loss += gen_l1_loss

        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(outputs, images)
        gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
        gen_loss += gen_content_loss

        # generator style loss, restricted to masked pixels
        gen_style_loss = self.style_loss(outputs * masks, images * masks)
        gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
        gen_loss += gen_style_loss

        # create logs
        logs = [
            "diff_l1:{:.4f}".format(gen_l1_loss.item()),
            "diff_per:{:.4f}".format(gen_content_loss.item()),
            "diff_sty:{:.4f}".format(gen_style_loss.item())
        ]

        return gen_loss, logs

    def get_loss(self, images, masks, outputs, learn_confidence):
        """Full GAN loss for one branch.

        When ``learn_confidence`` is True the L1 term is not divided by the
        mean mask value and the style loss is computed on whole images
        instead of masked regions.

        Returns:
            (gen_loss, dis_loss, logs) with ``logs`` a list of formatted strings.
        """
        gen_loss = 0
        dis_loss = 0

        # discriminator loss (fake input detached so D's loss does not reach G)
        dis_input_real = images
        dis_input_fake = outputs.detach()
        dis_real, _ = self.discriminator(dis_input_real)  # in: [rgb(3)]
        dis_fake, _ = self.discriminator(dis_input_fake)  # in: [rgb(3)]
        dis_real_loss = self.adversarial_loss(dis_real, True, True)
        dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
        dis_loss += (dis_real_loss + dis_fake_loss) / 2

        # generator adversarial loss (not detached: gradients flow through D into G)
        gen_input_fake = outputs
        gen_fake, _ = self.discriminator(gen_input_fake)  # in: [rgb(3)]
        gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
        gen_loss += gen_gan_loss

        # generator l1 loss
        gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT
        gen_l1_loss = gen_l1_loss if learn_confidence else (gen_l1_loss / torch.mean(masks))
        gen_loss += gen_l1_loss

        # generator perceptual loss
        gen_content_loss = self.perceptual_loss(outputs, images)
        gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
        gen_loss += gen_content_loss

        # generator style loss
        if learn_confidence:
            gen_style_loss = self.style_loss(outputs, images)
        else:
            gen_style_loss = self.style_loss(outputs * masks, images * masks)
        gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
        gen_loss += gen_style_loss

        # create logs
        logs = [
            "d:{:.4f}".format(dis_loss.item()),
            "g:{:.4f}".format(gen_gan_loss.item()),
            "l1:{:.4f}".format(gen_l1_loss.item()),
            "per:{:.4f}".format(gen_content_loss.item()),
            "sty:{:.4f}".format(gen_style_loss.item())
        ]

        return gen_loss, dis_loss, logs

    def process(self, images1, images2, masks1, masks2, learn_confidence):
        """Run both branches and sum their losses.

        With ``learn_confidence`` each branch additionally gets a loss that
        treats the residual ``images - outputs.detach()`` as the target for
        the predicted confidence map.

        Returns:
            (result, gen_loss, dis_loss, logs) where ``result`` is the list
            returned by ``forward``.
        """
        self.iteration += 1

        result = self(images1, images2, masks1, masks2, learn_confidence=learn_confidence)
        if learn_confidence:
            outputs1, outputs2, confidence1, confidence2 = result

            inpaint_gen_loss_1, inpaint_dis_loss_1, inpaint_logs_1 = self.get_loss(images1, masks1, outputs1, False)
            inpaint_gen_loss_2, inpaint_dis_loss_2, inpaint_logs_2 = self.get_loss(images2, masks2, outputs2, False)
            conf_gen_loss_1, conf_dis_loss_1, conf_logs_1 = self.get_loss(images1 - outputs1.detach(), masks1, confidence1, True)
            conf_gen_loss_2, conf_dis_loss_2, conf_logs_2 = self.get_loss(images2 - outputs2.detach(), masks2, confidence2, True)

            gen_loss_1 = inpaint_gen_loss_1 + conf_gen_loss_1
            dis_loss_1 = inpaint_dis_loss_1 + conf_dis_loss_1
            logs_1 = inpaint_logs_1 + conf_logs_1
            gen_loss_2 = inpaint_gen_loss_2 + conf_gen_loss_2
            dis_loss_2 = inpaint_dis_loss_2 + conf_dis_loss_2
            logs_2 = inpaint_logs_2 + conf_logs_2
        else:
            outputs1, outputs2 = result
            gen_loss_1, dis_loss_1, logs_1 = self.get_loss(images1, masks1, outputs1, False)
            gen_loss_2, dis_loss_2, logs_2 = self.get_loss(images2, masks2, outputs2, False)

        gen_loss = gen_loss_1 + gen_loss_2
        dis_loss = dis_loss_1 + dis_loss_2

        logs = logs_1 + logs_2

        return result, gen_loss, dis_loss, logs

    def forward(self, images_1, images_2, masks_1, masks_2, learn_confidence, return_kernel=False, return_feature=False):
        """Zero masked pixels, append the mask as extra channels, run the generator.

        Returns:
            list ``[output_1, output_2]``, extended with
            ``[confidence_1, confidence_2]`` if ``learn_confidence``, then
            ``[kernels_f, kernels_i, kernels_c]`` if ``return_kernel``, then
            ``[features]`` if ``return_feature``.
        """
        images_masked_1 = images_1 * (1 - masks_1)
        images_masked_2 = images_2 * (1 - masks_2)

        inputs_1 = torch.cat((images_masked_1, masks_1), dim=1)
        inputs_2 = torch.cat((images_masked_2, masks_2), dim=1)

        output_1, output_2, confidence_1, confidence_2, kernels_f, kernels_i, kernels_c, features = self.generator(inputs_1, inputs_2)

        return_list = [output_1, output_2]
        if learn_confidence:
            return_list += [confidence_1, confidence_2]
        if return_kernel:
            return_list += [kernels_f, kernels_i, kernels_c]
        if return_feature:
            return_list += [features]

        return return_list

    def backward(self, gen_loss=None, dis_loss=None):
        """Back-propagate and step the generator, then the discriminator.

        Either loss may be None, in which case that network is not updated.
        """
        self.gen_optimizer.zero_grad()
        if gen_loss is not None:
            gen_loss.backward()
            self.gen_optimizer.step()

        # gen_loss back-propagates through the discriminator (the adversarial
        # term uses non-detached outputs), leaving generator-objective
        # gradients on D's weights; clear them so D follows dis_loss only.
        # dis_loss only sees detached generator outputs, so the generator
        # step above does not invalidate its graph.
        self.dis_optimizer.zero_grad()
        if dis_loss is not None:
            dis_loss.backward()
            self.dis_optimizer.step()