import torch
from torch import optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from torchvision.datasets import ImageFolder

from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
import random, os

from models.model_pose_identity import DiscriminatorAE, GeneratorPI, StyleEncoder
from models.model_pose_identity import DiscriminatorHR, GeneratorHR

from utilstrain.options import parser
from utilstrain.distributed import get_rank, synchronize
from utilstrain.non_leaking import augment, AdaptiveAugment
from utilstrain.utils import d_logistic_loss, d_r1_loss, g_nonsaturating_loss, g_path_regularize
from utilstrain.utils import CustomDDP, ddp_runner
from utilstrain.utils import resize, get_dir, idx2image, data_sampler, sample_data, make_noise, mixing_noise, requires_grad, accumulate
from utilstrain.utils import scale_flip_augment, random_cutout_augment

from utilseval import lpips
from utilseval.generate import generate_matrix




def train(args, loaders, percept, generator, discriminator, style_encoder, g_optim, d_optim, e_optim, g_ema, se_ema, device):
    """Run the joint GAN / style-encoder training loop.

    Every iteration performs, in order:
      1. a discriminator step on real vs. generated images, including the
         discriminator's auto-encoder reconstruction loss, plus a lazy R1
         step every ``args.d_reg_every`` iterations;
      2. a generator step with the non-saturating loss, plus lazy
         path-length regularization every ``args.g_reg_every`` iterations
         (only when ``args.size <= 256``);
      3. joint generator + style-encoder steps: latent regression from random
         noise, reconstruction of real images, and reconstruction of
         scale/flip-augmented real images from mixed styles;
      4. an extra discriminator step against the reconstructed images.
    Rank 0 accumulates and prints averaged losses, writes sample grids and
    saves checkpoints under ``experiments/{args.name}/``.

    Args:
        args: parsed options. Besides the usual training fields this reads
            ``dislow``, ``dishigh``, ``rec_factor`` and ``vq_emb`` (an int, or
            one codebook size per layer in ``[dislow, dishigh)``).
        loaders: list of one or two DataLoaders; when a second loader is
            given it is used as the crop-augmented source.
        percept: perceptual (LPIPS) loss module; ``.sum()`` is applied to its
            output.
        generator, discriminator, style_encoder: the trained models (wrapped
            in CustomDDP when ``args.distributed``).
        g_optim, d_optim, e_optim: optimizers for the three models.
        g_ema, se_ema: moving-average copies of the generator and encoder,
            updated via ``accumulate``.
        device: device the input batches are moved to.
    """
    print("start training ...")
    # NOTE(review): uses the global rank as the CUDA index; on multi-node
    # runs this may differ from the local rank — confirm.
    percept.model.net.to('cuda:%d' % (get_rank()))

    # Layer-split indices and reconstruction settings are taken from ``args``
    # (the same values the script entry point derives) instead of relying on
    # module-level globals that only exist when this file runs as a script.
    dislow_idx = args.dislow
    dishigh_idx = args.dishigh
    rec_factor = args.rec_factor
    n_embed = args.vq_emb
    if isinstance(n_embed, int):
        n_embed = [n_embed] * (dishigh_idx - dislow_idx)

    loader_org = sample_data(loaders[0])
    # The crop-augmented loader is optional; fall back to the plain one.
    if len(loaders) == 2:
        loader_crop = sample_data(loaders[1])
    else:
        loader_crop = loader_org

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0
    # Used for logging before the first R1 step has run.
    r1_loss = torch.tensor(0.0, device=device)
    loss_dict = defaultdict(lambda: 0)
    # Number of iterations accumulated into ``loss_dict`` since the last log
    # line; the logged values are sums divided by this count.
    log_steps = 0

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        e_module = style_encoder.module
    else:
        g_module = generator
        d_module = discriminator
        e_module = style_encoder

    # EMA decay: the averaged weights' memory halves every 10*1000/32 updates.
    accum = 0.5 ** (32 / (10 * 1000))
    # A fixed augment probability disables adaptive tuning (see below).
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0

    ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8, device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        # Crop-augmented batches are drawn with probability 2/3.
        if random.randint(0, 2) in [0, 1]:
            real_img = next(loader_crop)[0].to(device)
        else:
            real_img = next(loader_org)[0].to(device)
        if idx == 1 and get_rank() == 0:
            utils.save_image(real_img, f"experiments/{args.name}/sample/real.jpg",
                             normalize=True, value_range=(-1, 1))

        #****** 1. Train the Discriminator ******#
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img = generator(noise)[0]

        if args.augment:
            real_img_aug, _ = augment(scale_flip_augment(real_img)[0], ada_aug_p)
            fake_img_aug, _ = augment(fake_img, ada_aug_p)
        else:
            real_img_aug = real_img
            fake_img_aug = fake_img

        fake_pred = discriminator(fake_img_aug)

        # With ae=True the discriminator also returns a reconstruction of its
        # input, which is compared against the real batch resized to 128.
        real_pred, real_rec_img = discriminator(real_img_aug, ae=True)
        d_loss = d_logistic_loss(real_pred, fake_pred)
        real_img_128 = F.interpolate(real_img_aug, 128)
        d_ae_loss = percept(real_rec_img, real_img_128).sum() + \
            F.mse_loss(real_rec_img, real_img_128)
        if get_rank() == 0:
            loss_dict['d_ae'] += d_ae_loss.item()

        discriminator.zero_grad()

        d_loss += d_ae_loss
        d_loss.backward()
        d_optim.step()

        if idx % 500 == 0 and get_rank() == 0:
            # NOTE(review): the file index is i // 10000 (not an iteration
            # count like the other sample names), so this file is overwritten
            # throughout each 10k-iteration window.
            utils.save_image(torch.cat([real_img_128, real_rec_img]),
                             f"experiments/{args.name}/sample/rec_{str(i//10000).zfill(6)}.jpg",
                             nrow=8, normalize=True, value_range=(-1, 1))
        if get_rank() == 0:
            # Accumulate (not overwrite) so the logged value is an average.
            loss_dict["d"] += d_loss.item()
            loss_dict["real_score"] += real_pred.mean().item()
            loss_dict["fake_score"] += fake_pred.mean().item()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)

        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(scale_flip_augment(real_img)[0], ada_aug_p)
            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Lazy R1: scaled by d_reg_every because it runs only that often.
            # The ``0 * real_pred[0]`` term does not change the value; it
            # presumably keeps the output in the graph for DDP — confirm.
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        if get_rank() == 0:
            loss_dict["r1"] += r1_loss.item()

        #****** 2. Train the Generator ******#
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _, diff, _ = generator(noise, return_latents=True)

        if args.augment:
            fake_img_aug, _ = augment(scale_flip_augment(fake_img)[0], ada_aug_p)
        else:
            fake_img_aug = fake_img

        fake_pred = discriminator(fake_img_aug)
        # ``diff`` is the generator's auxiliary (VQ) loss term.
        g_loss = g_nonsaturating_loss(fake_pred) + diff

        if get_rank() == 0:
            loss_dict["g"] += g_loss.item()

        generator.zero_grad()

        g_loss.backward()
        g_optim.step()

        # Path-length regularization is disabled for the high-resolution model.
        g_regularize = i % args.g_reg_every == 0 and args.size <= 256
        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing, device)
            fake_img, latents = generator(noise, return_latents=True)[:2]

            path_loss, mean_path_length, _ = g_path_regularize(
                fake_img, latents, mean_path_length
            )

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

        accumulate(g_ema, g_module, accum)

        #****** 3. Train the Style_Encoder together with Generator ******#
        ### 3.1 From random noise: build a three-way mixed latent (layers
        ### [0, mix_1) from the first code, [mix_1, mix_2) from the second,
        ### the rest from the third) and train the encoder to regress it from
        ### the discriminator's features of the decoded image.
        noises = make_noise(args.batch, args.latent, 3, device)
        g_latents = [generator.style(n) for n in noises]
        mix_1 = random.randint(1, dislow_idx)
        mix_2 = random.randint(dishigh_idx, generator.n_latent - 2)
        gl_low = g_latents[0].unsqueeze(1).repeat(1, mix_1, 1)
        gl_mid = g_latents[1].unsqueeze(1).repeat(1, mix_2 - mix_1, 1)
        gl_high = g_latents[2].unsqueeze(1).repeat(1, generator.n_latent - mix_2, 1)
        g_latents = torch.cat([gl_low, gl_mid, gl_high], dim=1).detach()
        fake_img, diff_1, _ = ddp_runner(generator, "decode", g_latents, random_noise=False)

        pred_fake, fake_features = discriminator(fake_img, extract=True)
        pred_g_styles = style_encoder(fake_features)
        g_fake_loss = g_nonsaturating_loss(pred_fake)
        style_rec_loss = F.mse_loss(pred_g_styles, g_latents)
        g1_loss = g_fake_loss + style_rec_loss + diff_1

        ### 3.2 From real images: encode, optionally take the layers outside
        ### [dislow, dishigh) from a cutout version, decode, and require the
        ### result to reconstruct the real image while fooling D.
        real_img = scale_flip_augment(real_img)[0]
        with torch.no_grad():
            real_features = ddp_runner(discriminator, "extract", real_img)
        pred_r_styles = style_encoder(real_features)

        if random.randint(0, 1) == 0:
            real_crop_img = random_cutout_augment(real_img)
            with torch.no_grad():
                real_crop_features = ddp_runner(discriminator, "extract", real_crop_img)
            pred_r_crop_styles = style_encoder(real_crop_features)
            pred_r_styles = torch.cat([pred_r_crop_styles[:, :dislow_idx],
                                       pred_r_styles[:, dislow_idx:dishigh_idx],
                                       pred_r_crop_styles[:, dishigh_idx:]], dim=1)

        rec_real_img, diff_2, _ = ddp_runner(generator, "decode", pred_r_styles, random_noise=True)

        rec_real_img_aug = augment(rec_real_img, ada_aug_p)[0] if args.augment else rec_real_img
        pred_rec = discriminator(rec_real_img_aug)
        style_real_loss = g_nonsaturating_loss(pred_rec)

        # Reconstruction losses are computed at args.size // rec_factor.
        rec_size = args.size // rec_factor
        real_rec_loss = percept(resize(rec_real_img, rec_size), resize(real_img, rec_size)).sum() + \
            F.mse_loss(resize(rec_real_img, rec_size), resize(real_img, rec_size))

        g_losses = g1_loss + style_real_loss + real_rec_loss + diff_2

        generator.zero_grad()
        style_encoder.zero_grad()
        g_losses.backward()
        g_optim.step()
        e_optim.step()
        accumulate(g_ema, g_module, accum)
        accumulate(se_ema, e_module, accum)

        if get_rank() == 0:
            loss_dict["style_rec"] += style_rec_loss.item()
            loss_dict["real_rec"] += real_rec_loss.item()
            loss_dict["vq_diff"] += diff_2.item() + diff_1.item()

        ### 3.3 Scale/flip augmentation on a fresh batch: decode styles whose
        ### layers [dislow, dishigh) come from ``real_flip`` and the rest from
        ### the (possibly cutout) ``real_img``, and reconstruct ``real_flip``.
        ### This forces those layers to carry the augmentation-related features.
        if random.randint(0, 2) in [0, 1]:
            real_img = next(loader_crop)[0].to(device)
        else:
            real_img = next(loader_org)[0].to(device)

        real_img, real_flip = scale_flip_augment(real_img)
        if random.randint(0, 1) == 0:
            real_img = random_cutout_augment(real_img)

        with torch.no_grad():
            rf = ddp_runner(discriminator, "extract", real_img)
            rf_flip = ddp_runner(discriminator, "extract", real_flip)

        r_styles = style_encoder(rf)
        r_flip_styles = style_encoder(rf_flip)

        mix_styles = torch.cat([r_styles[:, :dislow_idx],
                                r_flip_styles[:, dislow_idx:dishigh_idx],
                                r_styles[:, dishigh_idx:]], dim=1)
        rec_real_flip_img, diff_3, embed_idxs = ddp_runner(generator, "decode", mix_styles, random_noise=True)

        rec_real_flip_img_aug = augment(rec_real_flip_img, ada_aug_p)[0] if args.augment else rec_real_flip_img
        pred_rec = discriminator(rec_real_flip_img_aug)
        style_real_flip_loss = g_nonsaturating_loss(pred_rec)

        real_flip_rec_loss = percept(resize(rec_real_flip_img, rec_size), resize(real_flip, rec_size)).sum() + \
            F.mse_loss(resize(rec_real_flip_img, rec_size), resize(real_flip, rec_size))
        g_losses = style_real_flip_loss + real_flip_rec_loss + diff_3

        generator.zero_grad()
        style_encoder.zero_grad()
        g_losses.backward()
        g_optim.step()
        e_optim.step()
        accumulate(g_ema, g_module, accum)
        accumulate(se_ema, e_module, accum)

        #****** 4. Train the Discriminator on rec images, to make sure D is strong enough ******#
        requires_grad(discriminator, True)
        requires_grad(generator, False)

        fake_img = rec_real_flip_img_aug.detach() if random.randint(0, 1) == 0 else rec_real_img_aug.detach()
        real_img_d = augment(real_flip, ada_aug_p)[0] if args.augment else real_flip
        fake_pred = discriminator(fake_img.detach())
        real_pred = discriminator(real_img_d.detach())

        d_loss = d_logistic_loss(real_pred, fake_pred)

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        ### Print logging and save image and checkpoints (rank 0 only)
        if get_rank() == 0:
            log_steps += 1
            if idx % 50 == 0:
                # Average over the iterations actually accumulated since the
                # last log (1 on the very first log line, otherwise 50).
                d_loss_val = loss_dict["d"] / log_steps
                g_loss_val = loss_dict["g"] / log_steps
                r1_val = loss_dict["r1"] / log_steps
                g_style_rec_val = loss_dict["style_rec"] / log_steps
                real_img_rec_val = loss_dict["real_rec"] / log_steps
                d_ae_val = loss_dict["d_ae"] / log_steps
                vq_diff_val = loss_dict["vq_diff"] / log_steps
                loss_dict = defaultdict(lambda: 0)
                log_steps = 0

                print(
                    f"d: {d_loss_val:.4f};  g: {g_loss_val:.4f};  r1: {r1_val:.4f};  augment: {ada_aug_p:.4f};  vq_diff: {vq_diff_val:.4f} \n"
                    f"g_style_rec: {g_style_rec_val: .4f};  real_img_rec: {real_img_rec_val: .4f};  d_ae:   {d_ae_val: .4f} "
                )

            if idx % 100 == 0:
                with torch.no_grad():
                    # Map images from [-1, 1] to [0, 1] for saving.
                    idx_images = [real_img.add(1).mul(0.5).cpu(),
                                  real_flip.add(1).mul(0.5).cpu(),
                                  rec_real_flip_img.add(1).mul(0.5).cpu()]
                    for eidx, embed in enumerate(embed_idxs):
                        idx_images.append(F.interpolate(idx2image(embed, num_colors=n_embed[eidx]), args.size))

                    utils.save_image(torch.cat(idx_images), f"experiments/{args.name}/sample/flip_{str(i//1000 * 1000).zfill(6)}.jpg", nrow=args.batch,)

                    generate_matrix(discriminator, g_ema, se_ema, real_flip,
                                    f"experiments/{args.name}/sample/g_{str(i//1000 * 1000).zfill(6)}.jpg",
                                    mix_idx=dislow_idx, cut_idx=dishigh_idx, n_embed=n_embed)

            if idx % 1000 == 0 or idx == args.iter - 1:
                # The checkpoint name is rounded to 10k iterations, so saves
                # within one 10k window overwrite the same file.
                torch.save({
                        "g": g_module.state_dict(), "d": d_module.state_dict(), 'e': e_module.state_dict(),
                        "g_ema": g_ema.state_dict(), "se_ema": se_ema.state_dict(),
                        "g_optim": g_optim.state_dict(), "d_optim": d_optim.state_dict(), 'e_optim': e_optim.state_dict(),
                        "args": args, "ada_aug_p": ada_aug_p, },
                        f"experiments/{args.name}/checkpoint/{str((i+1)//10000 * 10000).zfill(6)}.pt",)

if __name__ == "__main__":
    # Script entry point: parse options, build the models and optimizers,
    # optionally resume from checkpoints, wrap the models for distributed
    # training, build the data loaders and run train().

    device = "cuda"
    args = parser.parse_args()

    # WORLD_SIZE is set by the distributed launcher; absent means one process.
    n_gpu = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = n_gpu > 1
    args.n_gpu = n_gpu

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.latent = 512
    args.n_mlp = 8
    # NOTE(review): kept as module-level names in case other code in this
    # module reads them as globals.
    rec_factor = args.rec_factor
    dislow_idx = args.dislow
    dishigh_idx = args.dishigh
    n_embed = args.vq_emb
    # A single codebook size is repeated for every layer in [dislow, dishigh).
    if isinstance(n_embed, int):
        n_embed = [n_embed] * (dishigh_idx - dislow_idx)
    args.vq_emb = n_embed

    args.start_iter = 0

    if args.size == 256:
        generator = GeneratorPI(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier,
                                dislow=dislow_idx, dishigh=dishigh_idx, n_embed=n_embed, vq=args.vq_type).to(device)
        discriminator = DiscriminatorAE(args.size, channel_multiplier=args.channel_multiplier).to(device)
        style_encoder = StyleEncoder(args.size, args.latent, channel_multiplier=args.channel_multiplier).to(device)
    elif args.size > 256:
        # High-resolution models wrap a 256 generator/discriminator; the
        # style encoder still operates at 256.
        generator = GeneratorHR(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier,
                                dislow=dislow_idx, dishigh=dishigh_idx, n_embed=n_embed, vq=args.vq_type).to(device)
        discriminator = DiscriminatorHR(args.size, channel_multiplier=args.channel_multiplier).to(device)
        style_encoder = StyleEncoder(256, args.latent, channel_multiplier=args.channel_multiplier).to(device)
    else:
        # Without this, the script would fail later with a NameError.
        raise ValueError(f"unsupported image size {args.size}; expected 256 or larger")

    if args.size > 256 and args.ckpt_256 is not None:
        print("load pre-trained 256 model:", args.ckpt_256)
        ckpt = torch.load(args.ckpt_256, map_location=lambda storage, loc: storage)
        generator.gen_256.load_state_dict(ckpt["g_ema"])
        discriminator.dis_256.load_state_dict(ckpt["d"])
        style_encoder.load_state_dict(ckpt['se_ema'])

    # EMA copies start as exact copies (decay 0) of the trained models.
    g_ema = deepcopy(generator)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    se_ema = deepcopy(style_encoder)
    se_ema.eval()
    accumulate(se_ema, style_encoder, 0)

    # Lazy-regularization correction applied to learning rates and betas.
    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    if args.size == 256:
        g_optim = optim.Adam(generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))
        d_optim = optim.Adam(discriminator.parameters(), lr=args.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))
        e_optim = optim.Adam(style_encoder.parameters(), lr=args.lr, betas=(0, 0.99))
    else:
        ### Setting up which parts of each model are trained: the pre-trained
        ### 256 sub-models at 1/5 of the learning rate, the new layers at the
        ### full rate.
        g_optim = optim.Adam(generator.gen_256.parameters(), lr=args.lr * g_reg_ratio / 5, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))
        g_optim.add_param_group({"params": generator.convs.parameters(), 'lr': args.lr * g_reg_ratio})
        g_optim.add_param_group({"params": generator.to_rgbs.parameters(), 'lr': args.lr * g_reg_ratio})
        generator.gen_256.set_vq_decay(0.9997)  # set vq modules to be trained lower

        d_optim = optim.Adam(discriminator.dis_256.parameters(), lr=args.lr * d_reg_ratio / 5, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))
        # Discriminator parameters use the discriminator's reg ratio (the
        # original used g_reg_ratio here, a copy-paste slip).
        d_optim.add_param_group({"params": discriminator.extra_convs.parameters(), 'lr': args.lr * d_reg_ratio})

        e_optim = optim.Adam(style_encoder.parameters(), lr=args.lr / 5, betas=(0, 0.99))

    percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)

    if args.ckpt is not None:
        print("load model:", args.ckpt)
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        # Resume the iteration counter from a numeric checkpoint file name
        # (e.g. "010000.pt"); other names leave start_iter at 0.
        try:
            ckpt_name = os.path.basename(args.ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name)[0])
        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        style_encoder.load_state_dict(ckpt['e'])
        g_ema.load_state_dict(ckpt["g_ema"])
        se_ema.load_state_dict(ckpt["se_ema"])

        g_optim.load_state_dict(ckpt["g_optim"])
        d_optim.load_state_dict(ckpt["d_optim"])
        e_optim.load_state_dict(ckpt['e_optim'])

    if args.distributed:
        generator = CustomDDP(generator,
            device_ids=[args.local_rank], output_device=args.local_rank,
            broadcast_buffers=False, find_unused_parameters=True
        )

        discriminator = CustomDDP(discriminator,
            device_ids=[args.local_rank], output_device=args.local_rank,
            broadcast_buffers=False, find_unused_parameters=True
        )

        style_encoder = CustomDDP(style_encoder,
            device_ids=[args.local_rank], output_device=args.local_rank,
            broadcast_buffers=False, find_unused_parameters=False
        )

    # Images are normalized to [-1, 1].
    transform_org = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.size, args.size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ])

    dataset_org = ImageFolder(args.path, transform=transform_org)
    loader_org = data.DataLoader(dataset_org, batch_size=args.batch,
        sampler=data_sampler(dataset_org, shuffle=True, distributed=args.distributed),
        drop_last=True, num_workers=4)

    loaders = [loader_org]

    if args.crop_aug:
        # Optional second loader: upscale by 1.2x, then random-crop back.
        transform_crop = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize((int(1.2 * args.size), int(1.2 * args.size))),
                transforms.RandomCrop((args.size, args.size)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ])

        dataset_crop = ImageFolder(args.path, transform=transform_crop)
        loader_crop = data.DataLoader(dataset_crop, batch_size=args.batch,
            sampler=data_sampler(dataset_crop, shuffle=True, distributed=args.distributed),
            drop_last=True, num_workers=4)
        loaders.append(loader_crop)

    get_dir(args)
    train(args, loaders, percept, generator, discriminator, style_encoder, g_optim, d_optim, e_optim, g_ema, se_ema, device)
