import torch
import models.networks as networks
import util.util as util
import cv2 as cv
import torch.nn.functional as F
import numpy as np

class SmisModel(torch.nn.Module):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Let the networks module adjust ``parser`` and hand the parser back.

        Both arguments are forwarded unchanged to
        ``networks.modify_commandline_options``; the same ``parser`` object
        that was passed in is returned.
        """
        networks.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        """Store the options, pick tensor constructors and build the networks.

        The loss modules are only created when ``opt.isTrain`` is set.
        """
        super().__init__()
        self.opt = opt

        # Legacy tensor constructors chosen according to opt.gpu_ids.
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        # Loss functions (training only).
        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(
                self.opt.gpu_ids, self.opt.vgg_path)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
    # Entry point for all calls involving a forward pass of the deep
    # networks. Because the DataParallel module can't parallelize custom
    # functions, every routine is reached through forward() and we branch
    # to the different routines based on |mode|.
    def forward(self, data, mode):
        # input_semantics, real_image  = self.preprocess_input(data)
        input_semantics, real_image = self.preprocess_input(data)

        if mode == 'generator':
            g_loss, generated = self.compute_generator_loss(
                input_semantics, real_image)
            return g_loss, generated
        elif mode == 'discriminator':
            d_loss = self.compute_discriminator_loss(
                input_semantics, real_image)
            return d_loss
        elif mode == 'encode_only':
            z, mu, logvar = self.encode_z(real_image)
            return mu, logvar
        elif mode == 'inference':
            with torch.no_grad():
                if self.opt.test_mask != -1:
                    fake_image = self.vis_test(input_semantics, times=self.opt.test_times, test_mask=self.opt.test_mask)
                else:
                    fake_image = self.vis_test(input_semantics, times=self.opt.test_times)
            return fake_image
        else:
            raise ValueError("|mode| is invalid")

    def create_optimizers(self, opt):
        G_params = list(self.netG.parameters())
        if opt.use_vae:
            G_params += list(self.netE.parameters())
            # G_params += list(self.netE_edge.parameters())
        if opt.isTrain:
            D_params = list(self.netD.parameters())

        if opt.no_TTUR:
            beta1, beta2 = opt.beta1, opt.beta2
            G_lr, D_lr = opt.lr, opt.lr
        else:
            beta1, beta2 = 0, 0.9
            G_lr, D_lr = opt.lr / 2, opt.lr * 2

        optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
        optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2))

        return optimizer_G, optimizer_D

    def save(self, epoch):
        """Checkpoint G and D, plus E when ``opt.use_vae``, via ``util.save_network``."""
        for net, label in ((self.netG, 'G'), (self.netD, 'D')):
            util.save_network(net, label, epoch, self.opt)

        if not self.opt.use_vae:
            return
        util.save_network(self.netE, 'E', epoch, self.opt)

    ############################################################################
    # Private helper methods
    ############################################################################

    def initialize_networks(self, opt):
        """Create G, D and E and optionally restore them from a checkpoint.

        D is only created for training, and E only when both ``opt.use_vae``
        and ``opt.isTrain`` are set; otherwise they are ``None``. Weights are
        loaded for epoch ``opt.which_epoch`` when not training or when
        ``opt.continue_train`` is set. D and E are only restored in training.

        Returns:
            ``(netG, netD, netE)``
        """
        training = opt.isTrain

        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if training else None
        netE = networks.define_E(opt) if opt.use_vae and training else None

        # Short-circuit on purpose: continue_train is only read when training.
        restore = (not training) or opt.continue_train
        if not restore:
            return netG, netD, netE

        netG = util.load_network(netG, 'G', opt.which_epoch, opt)
        if training:
            netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE

    # preprocess the input, such as moving the tensors to GPUs and
    # transforming the label map to one-hot encoding
    # |data|: dictionary of the input data
    def preprocess_input(self, data):
        # move to GPU and change data types
        data['label'] = data['label'].long()
        if self.use_gpu():
            data['label'] = data['label'].cuda()
            data['instance'] = data['instance'].cuda()
            data['image'] = data['image'].cuda()
            # data['edge'] = data['edge'].cuda()

        # create one-hot label map
        label_map = data['label']
        bs, _, h, w = label_map.size()
        nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
            else self.opt.label_nc
        input_label = self.FloatTensor(bs, nc, h, w).zero_()
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        # concatenate instance map if it exists
        if not self.opt.no_instance:
            inst_map = data['instance']
            instance_edge_map = self.get_edges(inst_map)
            input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)

        return input_semantics, data['image']

    def compute_generator_loss(self, input_semantics, real_image):
        """Generate a fake image and collect the generator-side losses.

        Returns:
            ``(G_losses, fake_image)`` where ``G_losses`` holds 'GAN', plus
            'KLD' (when ``opt.use_vae``), 'GAN_Feat' (unless
            ``opt.no_ganFeat_loss``) and 'VGG' (unless ``opt.no_vgg_loss``).
        """
        fake_image, KLD_loss, _ = self.generate_fake(
            input_semantics, real_image, compute_kld_loss=self.opt.use_vae)

        G_losses = {}
        if self.opt.use_vae:
            G_losses['KLD'] = KLD_loss

        pred_fake, pred_real = self.discriminate(
            input_semantics, fake_image, real_image)

        G_losses['GAN'] = self.criterionGAN(
            pred_fake, True, for_discriminator=False)

        if not self.opt.no_ganFeat_loss:
            num_D = len(pred_fake)
            feat_loss = self.FloatTensor(1).fill_(0)
            for d in range(num_D):  # one entry per discriminator
                fake_outputs, real_outputs = pred_fake[d], pred_real[d]
                # The last output is the final prediction, not a feature map.
                for layer in range(len(fake_outputs) - 1):
                    layer_loss = self.criterionFeat(
                        fake_outputs[layer], real_outputs[layer].detach())
                    feat_loss += layer_loss * self.opt.lambda_feat / num_D
            G_losses['GAN_Feat'] = feat_loss

        if not self.opt.no_vgg_loss:
            vgg_loss = self.criterionVGG(fake_image, real_image)
            G_losses['VGG'] = vgg_loss * self.opt.lambda_vgg

        return G_losses, fake_image

    def compute_discriminator_loss(self, input_semantics, real_image):
        """Return the discriminator losses on a generated and a real image.

        The fake image is produced without gradient tracking, detached, and
        then flagged as requiring grad before it is shown to D.

        Returns:
            dict with keys 'D_Fake' and 'D_real'.
        """
        with torch.no_grad():
            generated = self.generate_fake(input_semantics, real_image)[0]
            fake_image = generated.detach()
            fake_image.requires_grad_()

        pred_fake, pred_real = self.discriminate(
            input_semantics, fake_image, real_image)

        return {
            'D_Fake': self.criterionGAN(
                pred_fake, False, for_discriminator=True),
            'D_real': self.criterionGAN(
                pred_real, True, for_discriminator=True),
        }

    def encode_z(self, real_image):
        mu, logvar = self.netE(real_image)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def trans_img(self, input_semantics, real_image):
        images = None
        seg_range = input_semantics.size()[1]
        if self.opt.dataset_mode == 'cityscapes':
            seg_range -= 1
        for i in range(input_semantics.size(0)):
            resize_image = None
            for n in range(0, seg_range):
                seg_image = real_image[i] * input_semantics[i][n]
                # resize seg_image
                c_sum = seg_image.sum(dim=0)
                y_seg = c_sum.sum(dim=0)
                x_seg = c_sum.sum(dim=1)
                y_id = y_seg.nonzero()
                if y_id.size()[0] == 0:
                    seg_image = seg_image.unsqueeze(dim=0)
                    # resize_image = torch.cat((resize_image, seg_image), dim=0)
                    if resize_image is None:
                        resize_image = seg_image
                    else:
                        resize_image = torch.cat((resize_image, seg_image), dim=1)
                    continue
                # print(y_id)
                y_min = y_id[0][0]
                y_max = y_id[-1][0]
                x_id = x_seg.nonzero()
                x_min = x_id[0][0]
                x_max = x_id[-1][0]
                seg_image = seg_image.unsqueeze(dim=0)
                # print(x_min, x_max, y_min, y_max)
                if self.opt.dataset_mode == 'cityscapes':
                    seg_image = F.interpolate(seg_image[:, :, x_min:x_max + 1, y_min:y_max + 1], size=[256, 512])
                else:
                    seg_image = F.interpolate(seg_image[:, :, x_min:x_max + 1, y_min:y_max + 1], size=[256, 256])
                # seg_image = F.interpolate(seg_image[:, :, x_min:x_max + 1, y_min:y_max + 1], scale_factor=256 / max(y_max-y_min, x_max-x_min))
                # seg_image = F.interpolate(seg_image[:, :, x_min:x_max + 1, y_min:y_max + 1], size=[256, 256])
                if resize_image is None:
                    resize_image = seg_image
                else:
                    resize_image = torch.cat((resize_image, seg_image), dim=1)
            if images is None:
                images = resize_image
            else:
                images = torch.cat((images, resize_image), dim=0)
        return images

    def generate_fake(self, input_semantics, real_image, compute_kld_loss=False):
        z = None
        KLD_loss = None
        if self.opt.use_vae:
            images = self.trans_img(input_semantics, real_image)
            z, mu, logvar = self.encode_z(images)
            CODE_loss = None
            if compute_kld_loss:
                KLD_loss = self.KLDLoss(mu, logvar) * self.opt.lambda_kld
        fake_image = self.netG(input_semantics, z=z)

        assert (not compute_kld_loss) or self.opt.use_vae, \
            "You cannot compute KLD loss if opt.use_vae == False"

        return fake_image, KLD_loss, CODE_loss

    def vis_test(self, input_semantics, times=1, test_mask=None):
        fake_image = []
        if self.opt.dataset_mode == 'cityscapes':
            z = torch.randn(input_semantics.size(0), self.opt.label_nc, 8, 4 * 8).cuda()
            for i in range(times):
                if test_mask is not None:
                    z[:, test_mask, :, :] = torch.randn(input_semantics.size(0), 8, 4*8)
                else:
                    z = torch.randn(input_semantics.size(0), self.opt.label_nc, 8, 4 * 8).cuda()
                fake_image.append(
                    self.netG(input_semantics, z=z.view(input_semantics.size(0), self.opt.label_nc * 8, 4, 8)))
        else:
            z = torch.randn(input_semantics.size(0), self.opt.semantic_nc, 8, 4 * 4)
            for i in range(times):
                if test_mask is not None:
                    z[:, test_mask, :, :] = torch.randn(input_semantics.size(0), 8, 16)
                else:
                    z = torch.randn(input_semantics.size(0), self.opt.semantic_nc, 8, 4 * 4)
                fake_image.append(self.netG(input_semantics,
                                            z=z.view(input_semantics.size(0), self.opt.semantic_nc * 8, 4, 4).cuda()))
        return fake_image

    # Given fake and real image, return the prediction of discriminator
    # for each fake and real image.
    def discriminate(self, input_semantics, fake_image, real_image):
        fake_concat = torch.cat([input_semantics, fake_image], dim=1)
        real_concat = torch.cat([input_semantics, real_image], dim=1)

        # In Batch Normalization, the fake and real images are
        # recommended to be in the same batch to avoid disparate
        # statistics in fake and real images.
        # So both fake and real images are fed to D all at once.
        fake_and_real = torch.cat([fake_concat, real_concat], dim=0)

        discriminator_out = self.netD(fake_and_real)

        pred_fake, pred_real = self.divide_pred(discriminator_out)

        return pred_fake, pred_real

    # Take the prediction of fake and real images from the combined batch
    def divide_pred(self, pred):
        # the prediction contains the intermediate outputs of multiscale GAN,
        # so it's usually a list
        if type(pred) == list:
            fake = []
            real = []
            for p in pred:
                fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
                real.append([tensor[tensor.size(0) // 2:] for tensor in p])
        else:
            fake = pred[:pred.size(0) // 2]
            real = pred[pred.size(0) // 2:]

        return fake, real

    def get_edges(self, t):
        edge = self.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]).byte()
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]).byte()
        return edge.float()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std) + mu

    def use_gpu(self):
        return len(self.opt.gpu_ids) > 0