# PyTorch GALIP: https://github.com/tobran/GALIP
# The MIT License (MIT)
# See license file or visit https://github.com/tobran/GALIP for details

# replaced with code/lib/modules.py for SONA training

import os
import sys
from pyexpat import features
import os.path as osp
import time
import random
import datetime
import argparse
from scipy import linalg
import numpy as np
from PIL import Image
from tqdm import tqdm, trange
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import make_grid
from lib.utils import transf_to_CLIP_input, dummy_context_mgr
from lib.utils import mkdir_p, get_rank, save_all
from lib.datasets import prepare_data, prepare_data_with_text_aug

from models.inception import InceptionV3
from torch.nn.functional import adaptive_avg_pool2d
import torch.distributed as dist


############   GAN   ############
def train(dataloader, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args, writer, accelerator):
    """Run one training epoch with one discriminator and one generator update per batch.

    For every batch the discriminator side (``netD`` features scored through
    ``netC``) is updated with ``predict_loss_dis`` plus the matching-aware
    gradient penalty, then the generator is updated with ``predict_loss_gen``
    minus ``args.sim_w`` times a cosine-similarity term.

    Args:
        dataloader: Sized iterable of batches accepted by ``prepare_data``.
        netG: Generator, called as ``netG(noise, sent_emb)``.
        netD: Network applied to the first output of ``image_encoder``.
        netC: Predictor forwarded to ``predict_loss_dis`` / ``predict_loss_gen``.
        text_encoder: Forwarded to ``prepare_data``.
        image_encoder: Callable returning ``(features, embedding)``; also
            called with ``eval=True`` for the similarity term.
        optimizerG: Optimizer stepped after the generator backward pass.
        optimizerD: Optimizer stepped after the discriminator backward pass.
        scaler_G: Gradient scaler for the generator step; only used when
            ``args.mixed_precision`` is truthy.
        scaler_D: Gradient scaler for the discriminator step; only used when
            ``args.mixed_precision`` is truthy.
        args: Configuration object. Read here: ``local_batch_size`` or
            ``batch_size``, ``device``, ``current_epoch``, ``z_dim``,
            ``mixed_precision``, ``sim_w``, ``fid_warmup``, ``c_fid``,
            ``accelerator``, ``scaler_min``, ``img_img_sim``, ``multi_gpus``.
            Written here: ``current_sim_w`` and ``c_fid_now``.
        writer: Object exposing ``add_scalar(tag, value, global_step)``.
        accelerator: Used when ``args.accelerator == "hug"`` for autocast,
            backward, and main-process detection.

    Returns:
        None. Networks and optimizers are updated in place and scalars are
        logged through ``writer`` every 10 steps.
    """
    batch_size = args.local_batch_size if "local_batch_size" in args else args.batch_size
    device = args.device
    epoch = args.current_epoch
    z_dim = args.z_dim
    netG, netD, netC, image_encoder = netG.train(), netD.train(), netC.train(), image_encoder.train()
    if not args.mixed_precision:
        netG, netD, netC, image_encoder = netG.float(), netD.float(), netC.float(), image_encoder.float()

    sim_w = args.sim_w
    args.current_sim_w = sim_w

    if epoch < args.fid_warmup:
        # Linear ramp of the coefficient from 0 to args.c_fid over the warm-up epochs.
        args.c_fid_now = torch.linspace(0.0, args.c_fid, args.fid_warmup)[epoch]
        print(f"c_fid: {args.c_fid_now}")
    else:
        args.c_fid_now = args.c_fid

    def _forward_ctx():
        # Context used around forward passes: native AMP, Accelerate autocast, or a no-op.
        if args.mixed_precision:
            return torch.cuda.amp.autocast()
        if args.accelerator == "hug":
            return accelerator.autocast()
        return dummy_context_mgr()

    def _sync_ranks():
        # barrier() raises when no process group exists (e.g. a single-process run),
        # so only synchronise when torch.distributed is actually initialised.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    steps_per_epoch = len(dataloader)
    for step, data in enumerate(dataloader, 0):
        # Stop if the loader yields more batches than it reports.
        if step+1 > steps_per_epoch:
            break

        ##############
        # Train D
        ##############
        optimizerD.zero_grad()
        with _forward_ctx():
            real, _, _, sent_emb, words_embs, _ = prepare_data(
                data, text_encoder, device)
            # The gradient penalty differentiates w.r.t. the image features and the
            # sentence embedding, so the inputs must track gradients.
            sent_emb = sent_emb.requires_grad_()
            real = real.requires_grad_()
            words_embs = words_embs.requires_grad_()

            # predict real
            # CLIP_real: extracted from several layers in CLIP; second output: CLIP embedding of the image
            CLIP_real, _ = image_encoder(real)
            real_feats = netD(CLIP_real)

            if not args.mixed_precision:    # CLIP is operated in half precision
                real_feats = real_feats.float()
                sent_emb = sent_emb.float()

            # synthesize fake images
            noise = torch.randn(batch_size, z_dim).to(device)
            fake = netG(noise, sent_emb)
            CLIP_fake, _ = image_encoder(fake)
            # Detached so the discriminator loss does not back-propagate into the generator.
            fake_feats = netD(CLIP_fake.detach())

            pred_real, errD_real_fake = predict_loss_dis(
                netC, real_feats, fake_feats, sent_emb)

        # MA-GP
        if args.mixed_precision:
            errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, scaler_D, args)
        else:
            with _forward_ctx():
                errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real, args)

        # whole D loss
        with _forward_ctx():
            errD = errD_real_fake + errD_MAGP

        # update D
        if args.mixed_precision:
            if args.accelerator == "hug":
                accelerator.backward(scaler_D.scale(errD))
            else:
                scaler_D.scale(errD).backward()
            scaler_D.step(optimizerD)
            scaler_D.update()
            if scaler_D.get_scale() < args.scaler_min:
                # Floor the discriminator's own scaler (previously this reset scaler_G).
                scaler_D.update(args.scaler_min)
        else:
            if args.accelerator == "hug":
                if accelerator.scaler is not None:
                    # NOTE(review): args.mixed_precision is falsy in this branch, so this
                    # condition can never hold — confirm whether the floor was meant to apply.
                    if (accelerator.scaler.get_scale() < args.scaler_min) and args.mixed_precision:
                        accelerator.scaler.update(torch.tensor(args.scaler_min, dtype=torch.float))
                accelerator.backward(errD)
            else:
                errD.backward()
            optimizerD.step()

        _sync_ranks()

        ##############
        # Train G
        ##############
        with dummy_context_mgr():    # without detect_anomaly
            optimizerG.zero_grad()
            with _forward_ctx():
                real_feats = netD(CLIP_real.detach())
                fake_feats = netD(CLIP_fake)
                errD_main = predict_loss_gen(netC, real_feats, fake_feats, sent_emb)

                _, fake_emb_eval = image_encoder(fake, eval=True)
                if args.img_img_sim:
                    # similarity between real img emb and fake img emb
                    _, real_emb_eval = image_encoder(real, eval=True)
                    text_img_sim = torch.cosine_similarity(fake_emb_eval, real_emb_eval.detach()).mean()
                else:
                    text_img_sim = torch.cosine_similarity(fake_emb_eval, sent_emb).mean()

                errG = errD_main - sim_w*text_img_sim

            if args.mixed_precision:
                if args.accelerator == "hug":
                    accelerator.backward(scaler_G.scale(errG))
                else:
                    scaler_G.scale(errG).backward()
                scaler_G.step(optimizerG)
                scaler_G.update()
                if scaler_G.get_scale() < args.scaler_min:
                    scaler_G.update(args.scaler_min)
            else:
                if args.accelerator == "hug":
                    accelerator.wait_for_everyone()
                    if accelerator.scaler is not None:
                        # NOTE(review): unreachable for the same reason as in the D update.
                        if (accelerator.scaler.get_scale() < args.scaler_min) and args.mixed_precision:
                            accelerator.scaler.update(torch.tensor(args.scaler_min, dtype=torch.float))
                    accelerator.backward(errG)
                else:
                    errG.backward()
                replace_nan_grad(netG)    # replace nan, inf and -inf to normal values
                optimizerG.step()

        _sync_ranks()

        # update loop information
        with torch.no_grad():
            # Log only from rank 0 (multi-GPU) / the Accelerate main process.
            should_log = not ((args.multi_gpus == True) and (get_rank() != 0))
            if should_log and args.accelerator == "hug":
                should_log = bool(accelerator.is_main_process)

            if should_log and step%10 == 0:
                global_step = (epoch-1)*steps_per_epoch+step
                writer.add_scalar("Loss/D", errD.item(), global_step)
                writer.add_scalar("Loss/G", errG.item(), global_step)
                writer.add_scalar("Loss/D_main", errD_real_fake.item(), global_step)
                writer.add_scalar("Loss/D_MAGP", errD_MAGP.item(), global_step)
                writer.add_scalar("Loss/text_img_sim", text_img_sim.item(), global_step)


def test(dataloader, text_encoder, netG, PTM, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size, args):
    """Evaluate the generator and return ``(FID, text-image similarity)``.

    Thin wrapper: every argument is forwarded unchanged, in order, to
    ``calculate_FID_CLIP_sim`` (``PTM`` is passed as its ``CLIP`` argument).
    """
    fid_score, clip_similarity = calculate_FID_CLIP_sim(
        dataloader, text_encoder, netG, PTM, device,
        m1, s1, epoch, max_epoch, times, z_dim, batch_size, args,
    )
    return fid_score, clip_similarity



#########   MAGP   ########
def MA_GP_MP(img, sent, out, scaler, args, acc=None):
    """Gradient penalty on ``out`` w.r.t. ``(img, sent)`` for mixed-precision training.

    The gradients of ``scaler.scale(out)`` with respect to ``img`` and ``sent``
    are flattened per sample, concatenated, and their L2 norm raised to
    ``args.p_gp``; the batch mean is multiplied by ``args.k_gp``.

    Args:
        img: Tensor in the graph of ``out`` (first differentiation target).
        sent: Tensor in the graph of ``out`` (second differentiation target).
        out: Output tensor being differentiated.
        scaler: Object exposing ``scale(tensor)`` and ``get_scale()``.
        args: Object exposing ``k_gp`` and ``p_gp``.
        acc: Optional object exposing ``autocast()``. When given, the norm is
            computed under ``acc.autocast()`` and the gradients are not
            divided by the scale factor.
            NOTE(review): confirm that skipping the unscale here is intended.

    Returns:
        Scalar tensor holding the penalty (differentiable: ``create_graph=True``).
    """
    scaled_grads = torch.autograd.grad(
        outputs=scaler.scale(out),
        inputs=(img, sent),
        grad_outputs=torch.ones_like(out),
        retain_graph=True,
        create_graph=True,
        only_inputs=True,
    )

    if acc is None:
        # Undo the loss scaling; the small constant protects against a zero scale.
        unscale = 1. / (scaler.get_scale() + 1e-8)
        scaled_grads = [g * unscale for g in scaled_grads]
        norm_ctx = torch.cuda.amp.autocast()
    else:
        norm_ctx = acc.autocast()

    with norm_ctx:
        img_grad = scaled_grads[0].view(scaled_grads[0].size(0), -1)
        sent_grad = scaled_grads[1].view(scaled_grads[1].size(0), -1)
        joint_grad = torch.cat((img_grad, sent_grad), dim=1)
        per_sample_norm = torch.sqrt(torch.sum(joint_grad ** 2, dim=1))
        penalty = args.k_gp * torch.mean((per_sample_norm) ** args.p_gp)
    return penalty


def MA_GP_FP32(img, sent, out, args):
    """Gradient penalty on ``out`` w.r.t. ``(img, sent)`` without loss scaling.

    The gradients of ``out`` with respect to ``img`` and ``sent`` are
    flattened per sample, concatenated, and their L2 norm raised to
    ``args.p_gp``; the batch mean is multiplied by ``args.k_gp``.

    Args:
        img: Tensor in the graph of ``out`` (first differentiation target).
        sent: Tensor in the graph of ``out`` (second differentiation target).
        out: Output tensor being differentiated.
        args: Object exposing ``k_gp`` and ``p_gp``.

    Returns:
        Scalar tensor holding the penalty (differentiable: ``create_graph=True``).
    """
    grads = torch.autograd.grad(outputs=out,
                                inputs=(img, sent),
                                # ones_like keeps the seed gradient on out's device and
                                # dtype instead of forcing a float32 CUDA tensor.
                                grad_outputs=torch.ones_like(out),
                                retain_graph=True,
                                create_graph=True,
                                only_inputs=True)
    grad0 = grads[0].view(grads[0].size(0), -1)
    grad1 = grads[1].view(grads[1].size(0), -1)
    grad = torch.cat((grad0, grad1), dim=1)
    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
    d_loss_gp = args.k_gp * torch.mean((grad_l2norm) ** args.p_gp)
    return d_loss_gp



def sample(dataloader, netG, text_encoder, save_dir, device, multi_gpus, z_dim, stamp):
    """Generate images for every batch of ``dataloader`` and write them to disk.

    For each step (at most the first 64 captions of the batch are used):
      * a grid image is written to ``<save_dir>/batch[/gpuN]/imgs/step_XXXX.png``;
      * the captions are written to ``<save_dir>/batch[/gpuN]/txts/step_XXXX.txt``;
      * every image is written to
        ``<save_dir>/single[/gpuN]/stepXXXX/batch_YYYY.png``.
    The ``gpuN`` component is only present when ``multi_gpus`` is ``True``.

    Args:
        dataloader: Iterable of batches accepted by ``prepare_data``.
        netG: Generator, called as ``netG(noise, sent_emb, eval=True)``.
        text_encoder: Forwarded to ``prepare_data``.
        save_dir: Root output directory.
        device: Device for the noise tensor and for ``prepare_data``.
        multi_gpus: When ``True``, outputs are namespaced by ``get_rank()`` and
            only rank 0 prints progress.
        z_dim: Noise dimensionality.
        stamp: Unused.
    """
    netG.eval()
    for step, data in enumerate(dataloader, 0):
        ######################################################
        # (1) Prepare_data
        ######################################################
        real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(
            data, text_encoder, device)
        ######################################################
        # (2) Generate fake images
        ######################################################
        batch_size = min(sent_emb.size(0), 64)    # restrict for large batch size
        sent_emb = sent_emb[:batch_size]
        captions = captions[:batch_size]
        with torch.no_grad():
            noise = torch.randn(batch_size, z_dim).to(device)
            fake_imgs = netG(noise, sent_emb, eval=True).float()
            fake_imgs = torch.clamp(fake_imgs, -1., 1.)

        # Output locations; multi-GPU runs get one sub-directory per rank.
        step_dir = 'step%04d' % (step)
        if multi_gpus == True:
            rank_dir = str('gpu%d' % (get_rank()))
            batch_img_save_dir = osp.join(save_dir, 'batch', rank_dir, 'imgs')
            batch_txt_save_dir = osp.join(save_dir, 'batch', rank_dir, 'txts')
            single_img_save_dir = osp.join(save_dir, 'single', rank_dir, step_dir)
        else:
            batch_img_save_dir = osp.join(save_dir, 'batch', 'imgs')
            batch_txt_save_dir = osp.join(save_dir, 'batch', 'txts')
            single_img_save_dir = osp.join(save_dir, 'single', step_dir)
        batch_img_save_name = osp.join(batch_img_save_dir, 'step_%04d.png' % (step))
        batch_txt_save_name = osp.join(batch_txt_save_dir, 'step_%04d.txt' % (step))

        mkdir_p(batch_img_save_dir)
        vutils.save_image(fake_imgs.data, batch_img_save_name,
                          nrow=8, value_range=(-1, 1), normalize=True)

        mkdir_p(batch_txt_save_dir)
        # Context manager closes the file even if a write fails; UTF-8 keeps
        # non-ASCII captions independent of the process locale.
        with open(batch_txt_save_name, 'w', encoding='utf-8') as txt:
            txt.writelines(cap + '\n' for cap in captions)

        ######################################################
        # (3) Save fake images
        ######################################################
        mkdir_p(single_img_save_dir)
        # [-1, 1] --> [0, 255], then (B, C, H, W) --> (B, H, W, C) for PIL.
        imgs_uint8 = ((fake_imgs.detach().cpu().numpy() + 1.0) * 127.5).astype(np.uint8)
        imgs_uint8 = np.transpose(imgs_uint8, (0, 2, 3, 1))
        for j in range(batch_size):
            # Index-based name in both modes: the former single-GPU name depended
            # only on `step`, so each image overwrote the previous one.
            single_img_save_name = osp.join(
                single_img_save_dir, 'batch_%04d.png' % (j))
            Image.fromarray(imgs_uint8[j]).save(single_img_save_name)

        if not ((multi_gpus == True) and (get_rank() != 0)):
            print('Step: %d' % (step))


def calculate_FID_CLIP_sim(dataloader, text_encoder, netG, CLIP, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size, args):
    """Compute the FID of generated images and their mean CLIP text-image similarity.

    The dataloader is traversed ``times`` times. For every batch, images are
    generated from the sentence embeddings, scored with ``calc_clip_sim`` and
    passed through InceptionV3 (2048-d block). The Inception features of all
    batches (gathered across ranks when more than one process is used) give
    the mean/covariance compared against ``(m1, s1)``.

    Args:
        dataloader: Sized iterable of batches accepted by ``prepare_data``.
        text_encoder: Forwarded to ``prepare_data``.
        netG: Generator, called as ``netG(noise, sent_emb, eval=True)``.
        CLIP: Model forwarded to ``calc_clip_sim``.
        device: Device for the models and the noise tensor.
        m1: Passed to ``calculate_frechet_distance`` as the first mean.
        s1: Passed to ``calculate_frechet_distance`` as the first covariance.
        epoch: Shown in the progress bar; ``-1`` selects the generic label.
        max_epoch: Shown in the progress bar.
        times: Number of passes over ``dataloader``.
        z_dim: Noise dimensionality.
        batch_size: Nominal batch size; only used for the image-count message.
            The actual per-batch size is read from the sentence embeddings.
        args: Object exposing ``accelerator``; the values ``"hug"`` and
            ``"ddp"`` enable the multi-process gather via ``torch.distributed``.

    Returns:
        Tuple ``(fid_value, clip_score)``.
    """
    clip_cos = torch.FloatTensor([0.0]).to(device)
    # prepare Inception V3
    dims = 2048
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx])
    model.to(device)
    model.eval()
    netG.eval()
    norm = transforms.Compose([
        # (x - (-1)) / 2 maps [-1, 1] to [0, 1]
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.Resize((299, 299)),
    ])
    n_gpu = dist.get_world_size() if args.accelerator in ["hug", "ddp"] else 1
    dl_length = len(dataloader)
    imgs_num = dl_length * n_gpu * batch_size * times
    print(f"Evaluating scores with {imgs_num} imgs...")

    # Only one process drives the progress bar.
    show_progress = (n_gpu == 1) or (get_rank() == 0)
    loop = tqdm(total=int(dl_length*times)) if show_progress else None

    # Features are collected per batch rather than written into a pre-sized buffer:
    # a smaller final batch then neither breaks the slice assignment nor leaves
    # uninitialised rows in the statistics.
    pred_batches = []
    for _ in range(times):
        for data in dataloader:
            ######################################################
            # (1) Prepare_data
            ######################################################
            imgs, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(
                data, text_encoder, device)
            ######################################################
            # (2) Generate fake images
            ######################################################
            cur_batch_size = sent_emb.size(0)
            netG.eval()
            with torch.no_grad():
                noise = torch.randn(cur_batch_size, z_dim).to(device)
                fake_imgs = netG(noise, sent_emb, eval=True).float()
                fake_imgs = torch.clamp(fake_imgs, -1., 1.)
                fake_imgs = torch.nan_to_num(
                    fake_imgs, nan=-1.0, posinf=1.0, neginf=-1.0)
                clip_sim = calc_clip_sim(CLIP, fake_imgs, CLIP_tokens, device)
                clip_cos = clip_cos + clip_sim
                fake = norm(fake_imgs)
                pred = model(fake)[0]
                if pred.shape[2] != 1 or pred.shape[3] != 1:
                    pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
                # concat pred from multi GPUs
                if n_gpu > 1:
                    # all_gather assumes every rank holds a batch of the same shape.
                    output = [torch.empty_like(pred) for _ in range(n_gpu)]
                    dist.barrier()
                    dist.all_gather(output, pred)
                    pred_all = torch.cat(output, dim=0).squeeze(-1).squeeze(-1)
                else:
                    pred_all = pred.squeeze(-1).squeeze(-1)
                pred_batches.append(pred_all.detach().cpu().numpy())
            # update loop information
            if loop is not None:
                loop.update(1)
                if epoch == -1:
                    loop.set_description('Evaluating]')
                else:
                    loop.set_description(f'Eval Epoch [{epoch}/{max_epoch}]')
                loop.set_postfix()
    if loop is not None:
        loop.close()

    # CLIP-score: per-batch means were accumulated, so divide by the number of batches.
    if n_gpu > 1:
        CLIP_score_gather = [torch.empty_like(clip_cos) for _ in range(n_gpu)]
        dist.barrier()
        dist.all_gather(CLIP_score_gather, clip_cos)
        clip_score = torch.cat(
            CLIP_score_gather, dim=0).mean().item()/(dl_length*times)
    else:
        clip_score = clip_cos.mean().item()/(dl_length*times)

    # FID (statistics in float64, as with the former float64 buffer)
    if pred_batches:
        pred_arr = np.concatenate(pred_batches, axis=0).astype(np.float64)
    else:
        pred_arr = np.empty((0, dims))
    m2 = np.mean(pred_arr, axis=0)
    s2 = np.cov(pred_arr, rowvar=False)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
    return fid_value, clip_score


def calc_clip_sim(clip, fake, caps_clip, device):
    """Return the mean cosine similarity between image and text features.

    ``fake`` is first converted with ``transf_to_CLIP_input`` and encoded with
    ``clip.encode_image``; ``caps_clip`` is encoded with ``clip.encode_text``.
    ``device`` is accepted for interface compatibility and is not used.
    """
    image_features = clip.encode_image(transf_to_CLIP_input(fake))
    caption_features = clip.encode_text(caps_clip)
    return torch.cosine_similarity(image_features, caption_features).mean()


def sample_one_batch(noise, sent, netG, multi_gpus, epoch, img_save_dir, args):
    """Save two sample grids for the given fixed noise / sentence embeddings.

    Writes ``samples_epoch_XXX.png`` (all ``noise``/``sent`` pairs, generated in
    two halves) and ``samples_variation_epoch_XXX.png`` (the first ``B // 4``
    noise vectors combined with each of the first four sentence embeddings)
    into ``img_save_dir``. When ``args.accelerator`` is ``"ddp"`` or ``"hug"``,
    only the process with rank 0 does any work. ``multi_gpus`` is not used.
    """
    if (args.accelerator in ["ddp", "hug"]) and (get_rank() != 0):
        return

    netG.eval()
    with torch.no_grad():
        B = noise.size(0)
        half = B//2
        # Generated in two halves and moved to the CPU before concatenation.
        first_half = generate_samples(noise[:half], sent[:half], netG).cpu()
        second_half = generate_samples(noise[half:], sent[half:], netG).cpu()
        fixed_results = torch.cat((first_half, second_half), dim=0)

    grid_path = osp.join(img_save_dir, 'samples_epoch_%03d.png' % (epoch))
    vutils.save_image(fixed_results.data, grid_path,
                      nrow=8, value_range=(-1, 1), normalize=True)

    # Same noise vectors, one sentence embedding at a time.
    quarter = B//4
    per_prompt = []
    for i in range(4):
        repeated_sent = sent[i].expand(quarter, -1)
        per_prompt.append(generate_samples(noise[:quarter], repeated_sent, netG).cpu())
    from_one_prompt = torch.cat(per_prompt, dim=0)

    variation_path = osp.join(
        img_save_dir, 'samples_variation_epoch_%03d.png' % (epoch))
    vutils.save_image(from_one_prompt.data, variation_path,
                      nrow=8, value_range=(-1, 1), normalize=True)


def generate_samples(noise, caption, model):
    """Call ``model(noise, caption, eval=True)`` with gradients disabled and return the result."""
    with torch.no_grad():
        return model(noise, caption, eval=True)


def predict_loss_dis(predictor, img_feature_real, img_feature_fake, text_feature):
    """Score real and fake features and return ``(real_score, discriminator_loss)``.

    ``predictor`` is called with ``flg_train=True`` for both feature sets and
    must return a dict. The loss is ``sona_loss_dis`` of the two dicts; the
    returned score is ``f_disc + relu(f_algn)`` of the real dict.
    """
    real_out = predictor(img_feature_real, text_feature, flg_train=True)
    fake_out = predictor(img_feature_fake, text_feature, flg_train=True)
    dis_loss = sona_loss_dis(real_out, fake_out)
    real_score = real_out["f_disc"] + F.relu(real_out["f_algn"])
    return real_score, dis_loss

def predict_loss_gen(predictor, img_feature_real, img_feature_fake, text_feature):
    """Score real and fake features and return the generator loss.

    ``predictor`` is called with ``flg_train=True`` for both feature sets; the
    two resulting dicts are passed to ``sona_loss_gen``.
    """
    real_out = predictor(img_feature_real, text_feature, flg_train=True)
    fake_out = predictor(img_feature_fake, text_feature, flg_train=True)
    return sona_loss_gen(real_out, fake_out)


def sona_loss_dis(real, fake):
    """Discriminator loss built from the predictor outputs for real and fake inputs.

    Reads ``bias``, ``scales``, ``f_disc_dir``, ``f_disc``, ``f_algn``,
    ``f_algn_mis`` and ``f_disc_mis`` from ``real`` and ``f_disc_dir``,
    ``f_disc``, ``f_algn`` from ``fake``. The result is the sum of a direction
    term, an unconditional term, a conditional (real vs. fake) term and a
    mismatch term; the last three are softplus terms divided by their scale.
    """
    def _scaled_softplus(margin, scale):
        # softplus(scale * margin), averaged, then rescaled by 1 / scale
        return F.softplus(scale * margin).mean() / scale

    # Only the first bias entry is used when several are present.
    bias = real["bias"]
    b_0 = bias[:1] if bias.numel() > 1 else bias
    s_0, s_1, s_2 = real["scales"][:3]

    real_joint = real["f_algn"] + real["f_disc"]
    fake_joint = fake["f_algn"] + fake["f_disc"]
    mismatch_joint = real["f_algn_mis"] + real["f_disc_mis"]

    # Direction
    direction_term = torch.mean(fake["f_disc_dir"] - real["f_disc_dir"])
    # Unconditional
    unconditional_term = _scaled_softplus(fake["f_disc"] - b_0, s_0) + \
        _scaled_softplus(b_0 - real["f_disc"], s_0)
    # Conditional: fake pair scored against the real pair
    conditional_term = _scaled_softplus(fake_joint - real_joint, s_1)
    # Mismatch: mismatched pair scored against the real pair
    mismatch_term = _scaled_softplus(mismatch_joint - real_joint, s_2)
    return direction_term + unconditional_term + conditional_term + mismatch_term


def sona_loss_gen(real, fake):
    """Generator loss built from the predictor outputs for real and fake inputs.

    Sum of ``-mean(fake["f_disc"])`` and the mean softplus of
    ``(real f_algn + f_gen) - (fake f_algn + f_gen)``. ``real["bias"]`` and
    ``real["scales"]`` are added with weight 0.0, which leaves the value
    unchanged while keeping them in the computation graph.
    """
    real_score = real["f_algn"] + real["f_gen"]
    fake_score = fake["f_algn"] + fake["f_gen"]
    unconditional_term = -torch.mean(fake["f_disc"])
    conditional_term = F.softplus(real_score - fake_score).mean()
    graph_anchor = real["bias"].sum() + real["scales"].sum()
    return unconditional_term + conditional_term + 0.0 * graph_anchor

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two Gaussians given by mean and covariance.

    Computes ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

    Args:
        mu1: Mean of the first distribution.
        sigma1: Covariance of the first distribution.
        mu2: Mean of the second distribution.
        sigma2: Covariance of the second distribution.
        eps: Value added to both covariance diagonals when the matrix square
            root of their product contains non-finite entries.

    Returns:
        The Frechet distance.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the matrix square root has a diagonal imaginary part
            that is not within 1e-3 of zero.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_gap = mu1 - mu2

    # Product might be almost singular
    sqrt_prod, _ = linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    if not np.isfinite(sqrt_prod).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        jitter = np.eye(sigma1.shape[0]) * eps
        sqrt_prod = linalg.sqrtm(np.dot(sigma1 + jitter, sigma2 + jitter))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            m = np.max(np.abs(sqrt_prod.imag))
            raise ValueError('Imaginary component {}'.format(m))
        sqrt_prod = sqrt_prod.real

    trace_sqrt = np.trace(sqrt_prod)

    return (np.dot(mean_gap, mean_gap) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * trace_sqrt)

def replace_nan_grad(model):
    """Sanitise the gradients of ``model`` in place.

    For every parameter that has a gradient, NaN becomes 0, +inf becomes 1e5
    and -inf becomes -1e5. Parameters without a gradient are skipped.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        # Writing to `out=grad` overwrites the existing gradient tensor.
        torch.nan_to_num(grad, nan=0, posinf=1e5, neginf=-1e5, out=grad)
