# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import sys
from typing import Iterable

import torch
from torch import nn

import util.misc as misc
import util.lr_sched as lr_sched
from torch.nn import functional as F
from eval import add_imgs, imgs_grid
from torch.distributions.categorical import Categorical

# Shared cross-entropy criterion used by the 'classify' branch of train_one_epoch
# (CE_loss(logits, y)). `.cuda()` is executed at import time; no class weights are
# passed, so there is no weight tensor for it to move.
CE_loss = nn.CrossEntropyLoss().cuda()


def update_average(model_tgt, model_src, beta=0.99):
    """Blend ``model_src``'s state into ``model_tgt`` as an exponential moving average.

    Every floating-point ``state_dict`` entry (parameters and buffers) of the
    target becomes ``beta * target + (1 - beta) * source``. Non-floating entries
    (e.g. BatchNorm's integer ``num_batches_tracked``) cannot be meaningfully
    averaged, so they are taken from the source unchanged.

    Args:
        model_tgt: module receiving the averaged state (updated via ``load_state_dict``).
        model_src: module providing the current state; must contain every
            ``state_dict`` key of ``model_tgt``.
        beta: decay factor; values closer to 1 make the average move more slowly.

    Returns:
        ``model_tgt`` itself.

    Raises:
        KeyError: if ``model_src`` lacks a key present in ``model_tgt``.
    """
    with torch.no_grad():
        param_dict_src = model_src.state_dict()
        param_dict_tgt = model_tgt.state_dict()

        # All new values are computed before loading, so tied entries that share
        # storage are averaged exactly once.
        for p_name, p_tgt in param_dict_tgt.items():
            p_src = param_dict_src[p_name]
            if p_tgt.is_floating_point():
                param_dict_tgt[p_name] = beta * p_tgt + (1. - beta) * p_src
            else:
                # A blended float would be cast back to the integer dtype on load,
                # truncating it (with beta close to 1 a counter would effectively
                # never advance), so copy the source value instead.
                param_dict_tgt[p_name] = p_src

        model_tgt.load_state_dict(param_dict_tgt)
    return model_tgt


def compute_loss(d_out, target, gan_type='mse'):
    """Compare every element of ``d_out`` against the constant label ``target``.

    Args:
        d_out: discriminator output tensor (interpreted as logits for ``'std'``).
        target: scalar label broadcast to ``d_out``'s shape, dtype and device.
        gan_type: ``'mse'`` (least-squares, the historical default) or ``'std'``
            (binary cross-entropy on logits).

    Returns:
        Scalar loss tensor (mean reduction).

    Raises:
        NotImplementedError: for an unknown ``gan_type`` (previously this left the
            result unbound and raised ``UnboundLocalError``).
    """
    targets = d_out.new_full(size=d_out.size(), fill_value=target)
    if gan_type == 'mse':
        return F.mse_loss(d_out, targets)
    if gan_type == 'std':
        return F.binary_cross_entropy_with_logits(d_out, targets)
    raise NotImplementedError(gan_type)


def GAN_D_loss(logi_real, logi_fake, real_label, fake_label, mask_s=None, mask_t=None,
               dis_net=None, real_imgs=None, fake_imgs=None, lossType=None, phi=1., label=None, is_GP=True):
    """Discriminator loss for the supported GAN objectives.

    Args:
        logi_real: discriminator outputs on real images.
        logi_fake: discriminator outputs on generated images.
        real_label / fake_label: target tensors for ``'standard'`` and ``'lsgan'``
            (ignored by ``'hinge'``).
        real_imgs / fake_imgs: the inputs that produced ``logi_real`` /
            ``logi_fake``; they must require grad when ``lossType == 'standard'``
            and ``is_GP`` is True.
        lossType: ``'hinge'``, ``'standard'`` (BCE on logits) or ``'lsgan'`` (MSE).
        is_GP: for ``'standard'`` only, add
            ``10 * 0.5 * mean_over_batch(||dD/dx_real||^2 + ||dD/dx_fake||^2)``.
        mask_s, mask_t, dis_net, phi, label: accepted for interface compatibility; unused.

    Returns:
        Scalar discriminator loss tensor (differentiable, including the penalty).

    Raises:
        NotImplementedError: for an unknown ``lossType``.
    """
    if lossType == 'hinge':
        return torch.mean(F.relu(1.0 - logi_real)) + torch.mean(F.relu(1.0 + logi_fake))

    if lossType == 'standard':
        d_real_loss = F.binary_cross_entropy_with_logits(logi_real, real_label)
        d_fake_loss = F.binary_cross_entropy_with_logits(logi_fake, fake_label)
        if not is_GP:
            return d_real_loss + d_fake_loss
        # ones_like matches the logits' shape, dtype and device (CPU included);
        # create_graph keeps the penalty differentiable w.r.t. D's parameters.
        gradients_real = torch.autograd.grad(
            outputs=logi_real,
            inputs=real_imgs,
            grad_outputs=torch.ones_like(logi_real),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients_fake = torch.autograd.grad(
            outputs=logi_fake,
            inputs=fake_imgs,
            grad_outputs=torch.ones_like(logi_fake),
            create_graph=True,
            retain_graph=True,
        )[0]
        # per-sample squared gradient norm, averaged over the batch
        gradient_penalty = 0.5 * (gradients_real.pow(2) + gradients_fake.pow(2)).reshape(
            gradients_real.size(0), -1).sum(-1).mean()
        return d_real_loss + d_fake_loss + 10. * gradient_penalty

    if lossType == 'lsgan':
        return F.mse_loss(logi_real, real_label) + F.mse_loss(logi_fake, fake_label)

    raise NotImplementedError(lossType)


def GAN_G_loss(logi_fake, real_label, lossType=None, is_fake=False):
    """Generator loss for the supported GAN objectives.

    Args:
        logi_fake: discriminator outputs on generated images.
        real_label: target tensor for ``'standard'``, ``'vanilla'`` and ``'lsgan'``.
        lossType: one of ``'hinge'``, ``'wgangp'``, ``'standard'``, ``'vanilla'``,
            ``'lsgan'``.
        is_fake: for ``'hinge'`` / ``'wgangp'`` only, return ``+mean(logi_fake)``
            instead of ``-mean(logi_fake)`` (train_one_epoch passes True for its
            "communicate" terms).

    Returns:
        Scalar generator loss tensor.

    Raises:
        NotImplementedError: for an unknown ``lossType`` (previously an obscure
            ``UnboundLocalError``).
    """
    if lossType in ('hinge', 'wgangp'):
        # both objectives score the critic output directly; only the sign differs
        mean_score = torch.mean(logi_fake)
        return mean_score if is_fake else -mean_score
    if lossType == "standard":
        return F.binary_cross_entropy_with_logits(logi_fake.reshape(-1), real_label.reshape(-1))
    if lossType == "vanilla":
        # saturating minimax form: -BCE(D(G(z)), 1 - real_label)
        return -F.binary_cross_entropy_with_logits(logi_fake.reshape(-1), 1. - real_label.reshape(-1))
    if lossType == "lsgan":
        return F.mse_loss(logi_fake, real_label)
    raise NotImplementedError(lossType)


def set_requires_grad(net, requires_grad=False):
    """Set ``requires_grad`` on every parameter yielded by ``net.parameters()``.

    Args:
        net: module whose parameters are modified in place.
        requires_grad: flag assigned to each parameter.
    """
    all_params = net.parameters()
    for weight in all_params:
        weight.requires_grad = requires_grad


def set_requires_grad_D(net, requires_grad=False):
    """Set ``requires_grad`` on ``net``'s parameters, except positional embeddings.

    Any parameter whose qualified name contains ``'pos_embed'`` keeps whatever
    ``requires_grad`` flag it already had; all others receive ``requires_grad``.
    """
    for qual_name, weight in net.named_parameters():
        if 'pos_embed' in qual_name:
            continue  # leave pos_embed parameters untouched
        weight.requires_grad = requires_grad


# lambdaGP = 10.  # 10.


def _log_scalar(log_writer, tag, value, step):
    """Write ``value`` under ``tag`` at ``step``; a no-op when ``log_writer`` is None."""
    if log_writer is not None:
        log_writer.add_scalar(tag, value, step)


def train_one_epoch(gen_net: torch.nn.Module, dis_net: torch.nn.Module, data_loader: Iterable,
                    optimizer_G: torch.optim.Optimizer, optimizer_D: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None, jiaohu=False, lambdaGP=10.,
                    args=None, avg_gen_net=None, scheduler_G=None, scheduler_D=None,
                    gen_z_test=None, test_data=None, test_y=None):
    """Train the generator / discriminator (and the classification head) for one epoch.

    Each iteration selects a ``method``: ``args.method``, or a random choice of
    ``'generate'`` (p=0.6) / ``'classify'`` (p=0.4) when ``args.method == 'all'``.

    * ``'generate'``: ``args.Dsteps`` discriminator updates with BCE losses, optional
      squared-gradient-norm penalties (``args.setGP`` in ``'real'`` / ``'all'``), and
      optional model-to-model "communicate" terms (``args.communicate``), followed by
      one generator update.
    * ``'classify'``: cross-entropy on ``gen_net.head`` applied to the class token,
      with gradients accumulated over ``args.accum_iter`` iterations.

    Args:
        gen_net / dis_net: generator and discriminator, trained in place.
        data_loader: yields ``(samples, y)`` batches.
        optimizer_G / optimizer_D: optimizers for the two networks.
        device: device batches are moved to.
        epoch: epoch index (used for the global step and LR schedules).
        loss_scaler: unused (kept for interface compatibility).
        log_writer: optional summary writer; all logging is skipped when None.
        jiaohu: unused.
        lambdaGP: weight of the gradient penalties.
        args: run configuration namespace (many fields are read).
        avg_gen_net: optional moving-average copy of ``gen_net``, updated every iteration.
        scheduler_G / scheduler_D: schedulers stepped when ``args.set_scheduler == 'linear'``.
        gen_z_test, test_data, test_y: fixed inputs for periodic image dumps.

    Returns:
        ``(avg_gen_net, gen_net, dis_net)``.
    """
    gen_net.train(True)
    dis_net.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('globalIT', misc.SmoothedValue(window_size=1, fmt='{value:6d}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = args.print_freq

    accum_iter = args.accum_iter

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    # define labels (only the BCE / least-squares objectives use explicit targets)
    if args.loss in ['standard', 'vanilla']:
        real_label = torch.full((args.batch_size,), 1., dtype=torch.float, device=device)
        fake_label = torch.full((args.batch_size,), 0., dtype=torch.float, device=device)
    elif args.loss == 'lsgan':
        real_label = torch.full((args.batch_size,), 1., dtype=torch.float, device=device)
        fake_label = torch.full((args.batch_size,), -1., dtype=torch.float, device=device)
    else:
        real_label = None
        fake_label = None
    # define mask type
    if args.method == 'all':
        type_list = ['generate', 'classify']  # 'conditional_generate'
        type_prob = [0.6, 0.4]

    # begin training
    h_bs = args.batch_size // 2
    for data_iter_step, (samples, y) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # do not use label conditional generation: integer division maps every
        # label below 9999999 to 0
        y = y // 9999999

        global_train_step = data_iter_step + len(data_loader) * epoch
        metric_logger.update(globalIT=global_train_step)

        # we use a per iteration (instead of per epoch) lr scheduler.
        # Compare strings with ==: `is` tests object identity and is False for
        # equal but non-interned strings (e.g. values parsed from the command line).
        if (args.set_scheduler == 'mae') and (data_iter_step % accum_iter == 0):
            # method 1, warm up
            lr_sched.adjust_learning_rate(optimizer_G, data_iter_step / len(data_loader) + epoch, args)
            lr_sched.adjust_learning_rate(optimizer_D, data_iter_step / len(data_loader) + epoch, args)
        if args.set_scheduler == 'linear':
            # method 2, no warm up
            scheduler_G.step(global_train_step)
            scheduler_D.step(global_train_step)
        if (args.set_scheduler == 'warmupC') and (data_iter_step % accum_iter == 0):
            # method 3, warm up
            lr_sched.adjust_learning_rate_C(optimizer_G, data_iter_step / len(data_loader) + epoch, args)
            lr_sched.adjust_learning_rate_C(optimizer_D, data_iter_step / len(data_loader) + epoch, args)

        for param_group in optimizer_G.param_groups:
            _log_scalar(log_writer, 'optimizer/lr', param_group["lr"], global_train_step)

        real_imgs = samples.to(device, non_blocking=True)
        if args.setGP in ['real', 'all']:
            # the real-side penalty differentiates D's output w.r.t. the input images
            real_imgs = real_imgs.requires_grad_(True)
        y = y.to(device)
        gen_z = torch.randn(args.batch_size, 1, args.z_dim).to(device)

        # ------- define a method name -------
        if args.method == 'all':  # choose each case randomly
            method = type_list[Categorical(probs=torch.tensor(type_prob)).sample()]
        else:
            method = args.method

        # ------- ------- ------- ---- -------
        if method == 'generate':
            # ---------------------
            #  Train Discriminator

            for _ in range(args.Dsteps):
                # ---------- D0: zerograd ----------
                set_requires_grad_D(dis_net, requires_grad=True)  # embed
                optimizer_D.zero_grad()

                # ---------- D1: match model to data ----------
                if not args.communicate:
                    Sratio = args.SampleSratio.sample() if args.Sratio < 0. else args.Sratio
                    Tratio = args.SampleTratio.sample() if args.Tratio < 0. else args.Tratio
                    fake_imgs_g, _, mask_s, mask_t, _ = gen_net(real_imgs, gen_z, label=y,
                                                                Sratio=Sratio, Tratio=Tratio, method=method)
                else:
                    # Must have  S \cup T = CommuS \cup CommuT
                    Sratio = args.SampleSratio.sample() if args.Sratio < 0. else args.Sratio
                    Tratio = args.SampleTratio.sample() if args.Tratio < 0. else args.Tratio
                    CommuSratio = args.SampleCommuSratio.sample() if args.CommuSratio < 0. else args.CommuSratio

                    # The second half reuses r_ind from the first half, and Lst is the
                    # summed mask size (S + T) of the first half's first sample.
                    fake_imgs1, _, mask_s1, mask_t1, r_ind = gen_net(real_imgs[:h_bs, ...], gen_z[:h_bs, ...],
                                                                     label=y[:h_bs, ...],
                                                                     Sratio=Sratio, Tratio=Tratio, method=method)
                    fake_imgs2, _, mask_s2, mask_t2, _ = gen_net(real_imgs[h_bs:, ...], gen_z[h_bs:, ...],
                                                                 label=y[h_bs:, ...],
                                                                 Sratio=CommuSratio, Tratio=-10., method=method,
                                                                 r_ind=r_ind,
                                                                 Lst=int(mask_s1[0].sum() + mask_t1[0].sum()))

                    fake_imgs_g = torch.cat((fake_imgs1, fake_imgs2), dim=0)
                    mask_s = torch.cat((mask_s1, mask_s2), dim=0)
                    mask_t = torch.cat((mask_t1, mask_t2), dim=0)

                logi_real = dis_net(real_imgs, mask_s, mask_t, label=y, method=method,
                                    D_no_padST=args.D_no_padST)  # simplification: how to do this?
                d_real_loss = nn.BCEWithLogitsLoss()(logi_real, real_label)
                _log_scalar(log_writer, 'dloss/d_loss_real', d_real_loss.item(), global_train_step)

                if args.setGP in ['real', 'all']:
                    d_real_loss.backward(retain_graph=True)

                    # ones_like matches the logits' shape/dtype/device (CPU included)
                    gradients_real = torch.autograd.grad(
                        outputs=logi_real,
                        inputs=real_imgs,
                        grad_outputs=torch.ones_like(logi_real),
                        create_graph=True,
                        retain_graph=True,
                        only_inputs=True,
                    )[0]
                    gradient_penalty_real = lambdaGP * 0.5 * (gradients_real.pow(2)).reshape(
                        gradients_real.size(0), -1).sum(-1).mean()
                    gradient_penalty_real.backward(retain_graph=False)
                    torch.cuda.empty_cache()
                    _log_scalar(log_writer, 'dloss/GPreal', gradient_penalty_real.item(), global_train_step)

                else:
                    d_real_loss.backward(retain_graph=False)
                    torch.cuda.empty_cache()

                # D sees a detached copy of the fakes, so the D update does not flow
                # into G; requires_grad enables the optional fake-side penalty.
                fake_imgs = fake_imgs_g.detach().data
                fake_imgs.requires_grad_(True)
                logi_fake = dis_net(fake_imgs, mask_s, mask_t, label=y, method=method, D_no_padST=args.D_no_padST)
                d_fake_loss = nn.BCEWithLogitsLoss()(logi_fake, fake_label)

                d_loss = d_real_loss.data + d_fake_loss.data
                _log_scalar(log_writer, 'dloss/d_loss_fake', d_fake_loss.item(), global_train_step)
                _log_scalar(log_writer, 'dloss/d_loss', d_loss.item(), global_train_step)

                if args.setGP in ['all']:
                    d_fake_loss.backward(retain_graph=True)

                    gradients_fake = torch.autograd.grad(
                        outputs=logi_fake,
                        inputs=fake_imgs,
                        grad_outputs=torch.ones_like(logi_fake),
                        create_graph=True,
                        retain_graph=True,
                        only_inputs=True,
                    )[0]
                    gradient_penalty_fake = lambdaGP * 0.5 * (gradients_fake.pow(2)).reshape(
                        gradients_fake.size(0), -1).sum(-1).mean()
                    gradient_penalty_fake.backward(retain_graph=False)
                    torch.cuda.empty_cache()

                    _log_scalar(log_writer, 'dloss/GPfake', gradient_penalty_fake.item(), global_train_step)

                else:
                    d_fake_loss.backward(retain_graph=False)
                    torch.cuda.empty_cache()

                # ---------- D2: match model to model ----------
                if args.communicate:

                    # The first-half fakes are scored twice: once with the first half's
                    # masks and once with the second half's (mask_s / mask_t are the
                    # concatenation of both halves).
                    commu_real_imgs = torch.cat((fake_imgs1.data, fake_imgs1.data), dim=0)
                    if args.setGP in ['real', 'all']:
                        commu_real_imgs.requires_grad_(True)
                    commu_logi_real = dis_net(commu_real_imgs, mask_s, mask_t, label=y, method=method,
                                              D_no_padST=args.D_no_padST)

                    logi_s1t1 = commu_logi_real[h_bs:, ...] - commu_logi_real[:h_bs, ...]
                    d_real_loss1 = nn.BCEWithLogitsLoss()(logi_s1t1, real_label[:h_bs, ...])
                    _log_scalar(log_writer, 'dloss/commu_d_loss_real', d_real_loss1.item(), global_train_step)

                    if args.setGP in ['real', 'all']:
                        d_real_loss1.backward(retain_graph=True)

                        # Note from development: this call has raised "One of the
                        # differentiated Tensors appears to not have been used in the
                        # graph" (allow_unused=True would suppress it).
                        gradients_real1 = torch.autograd.grad(
                            outputs=logi_s1t1,
                            inputs=commu_real_imgs,
                            grad_outputs=torch.ones_like(logi_s1t1),
                            create_graph=True,
                            retain_graph=True,
                            only_inputs=True,
                        )[0]
                        gradient_penalty_real1 = lambdaGP * 0.5 * (gradients_real1[:h_bs, ...].pow(2)).reshape(
                            gradients_real1[:h_bs, ...].size(0), -1).sum(-1).mean()
                        gradient_penalty_real1.backward(retain_graph=False)
                        torch.cuda.empty_cache()

                        _log_scalar(log_writer, 'dloss/commu_GPreal', gradient_penalty_real1.item(),
                                    global_train_step)

                    else:
                        d_real_loss1.backward(retain_graph=False)
                        torch.cuda.empty_cache()

                    commu_fake_imgs = torch.cat((fake_imgs2.data, fake_imgs2.data), dim=0)
                    if args.setGP in ['all']:
                        commu_fake_imgs.requires_grad_(True)
                    commu_logi_fake = dis_net(commu_fake_imgs, mask_s, mask_t, label=y, method=method,
                                              D_no_padST=args.D_no_padST)

                    logi_s2t2 = commu_logi_fake[h_bs:, ...] - commu_logi_fake[:h_bs, ...]
                    d_fake_loss1 = nn.BCEWithLogitsLoss()(logi_s2t2, fake_label[h_bs:, ...])

                    # logging only (detached values)
                    d_comu_loss = d_real_loss1.data + d_fake_loss1.data
                    d_comu_loss /= accum_iter

                    _log_scalar(log_writer, 'dloss/commu_d_loss_fake', d_fake_loss1.item(), global_train_step)
                    _log_scalar(log_writer, 'dloss/commu_d_loss', d_comu_loss.item(), global_train_step)

                    if args.setGP in ['all']:
                        d_fake_loss1.backward(retain_graph=True)

                        gradients_fake1 = torch.autograd.grad(
                            outputs=logi_s2t2,
                            inputs=commu_fake_imgs,
                            grad_outputs=torch.ones_like(logi_s2t2),
                            create_graph=True,
                            retain_graph=True,
                            only_inputs=True,
                        )[0]
                        gradient_penalty_fake1 = lambdaGP * 0.5 * (gradients_fake1[h_bs:, ...].pow(2)).reshape(
                            gradients_fake1[h_bs:, ...].size(0), -1).sum(-1).mean()
                        gradient_penalty_fake1.backward(retain_graph=False)
                        torch.cuda.empty_cache()

                        _log_scalar(log_writer, 'dloss/commu_GPfake', gradient_penalty_fake1.item(),
                                    global_train_step)

                    else:
                        d_fake_loss1.backward(retain_graph=False)
                        torch.cuda.empty_cache()

                # ---------- D3: D step ----------
                if args.clip_gradD > 0.:
                    torch.nn.utils.clip_grad_norm_(dis_net.parameters(), args.clip_gradD)
                optimizer_D.step()

                d_loss /= accum_iter
                metric_logger.update(d_loss=d_loss.item())

            # -----------------
            #  Train Generator

            # ---------- G0: zerograd ----------
            # freeze D (set_requires_grad_D leaves pos_embed parameters untouched)
            set_requires_grad_D(dis_net, requires_grad=False)  # embed
            optimizer_G.zero_grad()

            # ---------- G1: model to data ----------
            # fake_imgs_g is the non-detached output of the last D step, so gradients reach G
            logi_fake = dis_net(fake_imgs_g, mask_s, mask_t, label=y, method=method, D_no_padST=args.D_no_padST)
            g_loss = GAN_G_loss(logi_fake, real_label, lossType=args.loss)
            if args.communicate:
                # keep G's graph: fake_imgs1 / fake_imgs2 are backpropagated again below
                g_loss.backward(retain_graph=True)
            else:
                g_loss.backward(retain_graph=False)
                torch.cuda.empty_cache()

            metric_logger.update(g_loss=g_loss.item())
            _log_scalar(log_writer, 'gloss/g_loss', g_loss.item(), global_train_step)

            # ---------- G2: model to model ----------
            if args.communicate:
                if not args.commu_only_fake:
                    commu_real_imgs = torch.cat((fake_imgs1, fake_imgs1), dim=0)
                    commu_logi_real = dis_net(commu_real_imgs, mask_s, mask_t, label=y, method=method,
                                              D_no_padST=args.D_no_padST)
                    logi_s1t1 = commu_logi_real[h_bs:, ...] - commu_logi_real[:h_bs, ...]

                    g_comu_loss1 = GAN_G_loss(logi_s1t1, fake_label[:h_bs], lossType=args.loss, is_fake=True)
                    g_comu_loss1.backward(retain_graph=True)
                else:
                    g_comu_loss1 = torch.zeros(1, device=g_loss.device)

                commu_fake_imgs = torch.cat((fake_imgs2, fake_imgs2), dim=0)
                commu_logi_fake = dis_net(commu_fake_imgs, mask_s, mask_t, label=y, method=method,
                                          D_no_padST=args.D_no_padST)
                logi_s2t2 = commu_logi_fake[h_bs:, ...] - commu_logi_fake[:h_bs, ...]

                g_comu_loss2 = GAN_G_loss(logi_s2t2, real_label[:h_bs], lossType=args.loss, is_fake=True)
                g_comu_loss2.backward(retain_graph=False)
                torch.cuda.empty_cache()

                g_comu_loss = g_comu_loss1.data + g_comu_loss2.data
                metric_logger.update(g_comu_loss=g_comu_loss.item())
                _log_scalar(log_writer, 'gloss/g_loss_commu1', g_comu_loss1.item(), global_train_step)
                _log_scalar(log_writer, 'gloss/g_loss_commu2', g_comu_loss2.item(), global_train_step)

            # ---------- G3: G step ----------
            if args.clip_gradG > 0.:
                torch.nn.utils.clip_grad_norm_(gen_net.parameters(), args.clip_gradG)
            optimizer_G.step()

            if not math.isfinite(g_loss.item()):
                print("Loss is {}, stopping training".format(g_loss.item()))
                sys.exit(1)

        # ---------------------
        #  Train classification
        # ---------------------
        if method in ['classify']:

            Sratio = args.SampleSratio.sample() + 0.1 if args.Sratio < 0. else args.Sratio
            Tratio = args.SampleTratio.sample() if args.Tratio < 0. else args.Tratio

            _, cls_token_out, _, _, _ = gen_net(real_imgs, Sratio=Sratio, Tratio=Tratio, method=method)
            logi_fake = gen_net.head(cls_token_out.squeeze(1))
            # NOTE(review): y was integer-divided by 9999999 above, so any label
            # below that is 0 here — confirm this is intended for classification.
            c_loss = CE_loss(logi_fake, y)

            # gradient accumulation: scale the loss, step every accum_iter iterations
            c_loss /= accum_iter
            c_loss.backward(retain_graph=False)
            metric_logger.update(c_loss=c_loss.item())

            if (data_iter_step + 1) % accum_iter == 0:
                if args.clip_gradG > 0.:
                    torch.nn.utils.clip_grad_norm_(gen_net.parameters(), args.clip_gradG)
                optimizer_G.step()
                optimizer_G.zero_grad()

        metric_logger.update(lr=optimizer_G.param_groups[0]["lr"])

        # moving average weight
        if avg_gen_net is not None:
            avg_gen_net = update_average(avg_gen_net, gen_net, beta=0.9999)

        # 1 ------- image generation -------
        if (global_train_step % args.showFreq == 0) and (method in ['generate']):
            with torch.no_grad():
                fake_imgs_0, _, _, _, _ = gen_net(test_data, gen_z_test, label=test_y, Sratio=0., Tratio=1.,
                                                  method='generate')
                fake_imgs_1, _, _, _, _ = gen_net(test_data, gen_z_test, label=test_y, Sratio=0.1, Tratio=1.,
                                                  method='generate')
                fake_imgs = torch.cat((fake_imgs_0[:10, ...], fake_imgs_1[:10, ...]), dim=0)
                add_imgs(fake_imgs.data, args.dir_img + '-1.jpg', nrow=5)
                # the averaged generator is optional; skip its dumps when absent
                if avg_gen_net is not None:
                    fake_imgs_0, _, _, _, _ = avg_gen_net(test_data, gen_z_test, label=test_y, Sratio=0.,
                                                          Tratio=1., method='generate')
                    fake_imgs_1, _, _, _, _ = avg_gen_net(test_data, gen_z_test, label=test_y, Sratio=0.1,
                                                          Tratio=1., method='generate')
                    fake_imgs = torch.cat((fake_imgs_0[:10, ...], fake_imgs_1[:10, ...]), dim=0)
                    add_imgs(fake_imgs.data, args.dir_img + '-2.jpg', nrow=5)
                    add_imgs(fake_imgs.data, args.dir_img + str(global_train_step) + '.jpg', nrow=5)
            if global_train_step % 500 == 0 and log_writer is not None:
                log_writer.add_image('out_img/gen_train', imgs_grid(fake_imgs.data, 5), global_train_step)

            torch.cuda.empty_cache()
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return avg_gen_net, gen_net, dis_net
