import argparse
import torch
import os
import json
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision
from datetime import datetime
from models import get_model
from utils import *
from data.dataset import build_boundary_distribution


def _setup_network(net, lr, args, device, use_ema):
    """Move `net` to `device` and build its Adam optimizer (optionally EMA-wrapped) and LR scheduler.

    Returns ``(DataParallel(net), optimizer, scheduler)``; ``scheduler`` is None unless
    ``args.lr_scheduler`` is set, in which case it is a CosineAnnealingLR over
    ``args.num_iterations`` steps.
    """
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(args.beta1, args.beta2))
    if use_ema:
        optimizer = EMA(optimizer, ema_decay=args.ema_decay)
    scheduler = None
    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_iterations, eta_min=args.eta_min)
    return nn.DataParallel(net), optimizer, scheduler


def _set_requires_grad(net, flag):
    """Enable/disable gradient tracking for every parameter of `net`."""
    for p in net.parameters():
        p.requires_grad = flag


def _sq_dist(a, b):
    """Per-sample squared Euclidean distance ||a - b||^2 (all non-leading dims flattened)."""
    return (a - b).flatten(start_dim=1).pow(2).sum(dim=1)


def _hjb_residual(V, t, z, tau, create_graph):
    """Mean absolute residual |2*tau*dV/dt + 0.5*||grad_z V||^2| of the potential values `V`.

    Both partial derivatives come from a single backward pass. The graph of `V` is always
    retained because the caller back-propagates its own loss through `V` afterwards.
    With ``create_graph=False`` the returned value is detached (usable for logging only).
    """
    V_t, V_z = torch.autograd.grad(V.sum(), (t, z), create_graph=create_graph, retain_graph=True)
    # Sum of squares instead of norm()**2: same value, no sqrt in the (double-)backward.
    grad_sq = V_z.flatten(start_dim=1).pow(2).sum(dim=1)
    return (2 * tau * V_t + 0.5 * grad_sq).abs().mean()


def train(args):
    """Train the two transport generators netG1 / netG2 against a shared potential netD.

    Each iteration performs ``args.K_v`` potential (netD) updates followed by ``args.K_T``
    generator updates. netG1 maps source samples x0 to x1 and netG2 maps target samples
    y1 to y0; netD(t, z) is evaluated on interpolants produced by ``Sampler``. When
    ``args.lmbda != 0`` an HJB residual penalty (weighted by ``lmbda``) is added to the
    potential loss; otherwise the residual is only computed for logging.

    Raises:
        ValueError: if ``args.K_v`` or ``args.K_T`` is smaller than 1 (the losses logged
            at the end of each iteration would otherwise never be defined).
    """
    if args.K_v < 1 or args.K_T < 1:
        raise ValueError(f'K_v and K_T must be >= 1 (got K_v={args.K_v}, K_T={args.K_T})')

    # Seed torch RNGs (CPU and CUDA) for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    batch_size = args.batch_size
    nz = args.nz

    device = 'cuda:0'
    args.device = device

    # Get Networks/Optimizers (creation order kept: it determines RNG consumption)
    netD, netG1 = get_model(args)
    _, netG2 = get_model(args)

    netG1, optimizerG1, schedulerG1 = _setup_network(netG1, args.lr_g, args, device, use_ema=args.use_ema)
    netG2, optimizerG2, schedulerG2 = _setup_network(netG2, args.lr_g, args, device, use_ema=args.use_ema)
    netD, optimizerD, schedulerD = _setup_network(netD, args.lr_d, args, device, use_ema=False)
    schedulers = [s for s in (schedulerG1, schedulerG2, schedulerD) if s is not None]

    # Get Data
    args.train = True
    source_dataset, target_dataset = build_boundary_distribution(args)

    # Sampler
    sampler = Sampler(args)
    phi1, phi2 = select_phi(args.phi1), select_phi(args.phi2)

    # Get Evaluation tools
    evaltool = EvalAdapted(args)
    logger = Logger(args, evaltool)

    # Start Initializing T
    netG1 = initialize(netG1, optimizerG1, source_dataset, target_dataset, args, logger.exp_path)
    netG2 = initialize(netG2, optimizerG2, target_dataset, source_dataset, args, logger.exp_path)

    # The HJB graph is only needed for backprop when the penalty actually has weight;
    # with lmbda == 0 it would cost a full double-backward per step for a zero gradient.
    hjb_needs_grad = args.lmbda != 0

    # Start training
    start = datetime.now()
    for step in range(args.num_iterations):
        #### Update potential ####
        _set_requires_grad(netD, True)

        for _ in range(args.K_v):
            netD.zero_grad()

            # x-side: source samples pushed forward by G1 (no generator gradients needed)
            with torch.no_grad():
                x0 = source_dataset.sample().to(device)
                t = sampler.sample_t(batch_size).to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                x1 = netG1(x0, latent_z)
                xt = sampler(t, x0, x1)

            # Leaf tensors need grad so the HJB residual can differentiate V w.r.t. t and x
            t.requires_grad = True
            xt.requires_grad = True

            Vx = netD(t, xt)
            cost = args.tau * t * _sq_dist(x1, x0)
            HJBx = _hjb_residual(Vx, t, xt, args.tau, create_graph=hjb_needs_grad)
            errDx = phi1(Vx - cost).mean() + args.lmbda * HJBx
            errDx.backward()

            # y-side: target samples pulled back by G2
            with torch.no_grad():
                y1 = target_dataset.sample().to(device)
                t = sampler.sample_t(batch_size).to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                y0 = netG2(y1, latent_z)
                yt = sampler(t, y0, y1)

            t.requires_grad = True
            yt.requires_grad = True

            cost = args.tau * (1 - t) * _sq_dist(y1, y0)
            Vy = netD(t, yt)
            HJBy = _hjb_residual(Vy, t, yt, args.tau, create_graph=hjb_needs_grad)
            errDy = phi2(- Vy - cost).mean() + args.lmbda * HJBy
            errDy.backward()

            # Gradients of both halves have accumulated; clip (optional) and step once
            if args.Dclip > 0:
                nn.utils.clip_grad_norm_(netD.parameters(), args.Dclip)
            optimizerD.step()

        #### Update Generators ####
        _set_requires_grad(netD, False)

        for _ in range(args.K_T):
            netG1.zero_grad()

            t = sampler.sample_t(batch_size).to(device)
            x0 = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            x1 = netG1(x0, latent_z)
            xt = sampler(t, x0, x1)

            Vx = netD(t, xt)
            cost = args.tau * t * _sq_dist(x1, x0)
            errG1 = cost.mean() - Vx.mean()

            errG1.backward()
            if args.Gclip > 0:
                nn.utils.clip_grad_norm_(netG1.parameters(), args.Gclip)
            optimizerG1.step()

            netG2.zero_grad()

            t = sampler.sample_t(batch_size).to(device)
            y1 = target_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            y0 = netG2(y1, latent_z)
            yt = sampler(t, y0, y1)

            Vy = netD(t, yt)
            cost = args.tau * (1 - t) * _sq_dist(y1, y0)
            errG2 = cost.mean() + Vy.mean()

            errG2.backward()
            if args.Gclip > 0:
                nn.utils.clip_grad_norm_(netG2.parameters(), args.Gclip)
            optimizerG2.step()

        #### Update Schedulers (G1, G2, D order) ####
        for scheduler in schedulers:
            scheduler.step()

        #### Visualizations and Save ####
        log = f'Iteration {step + 1:07d} : G1 Loss {errG1.item():.4f}, G2 Loss {errG2.item():.4f}, Dx Loss {errDx.item():.4f}, Dy Loss {errDy.item():.4f}, HJBx Loss {HJBx.item():.4f}, HJBy Loss {HJBy.item():.4f}, Elapsed {datetime.now() - start}'
        logger(log)
        info = {'netG1': netG1, 'netG2': netG2, 'netD': netD, 'optimizerG1': optimizerG1, 'optimizerG2': optimizerG2, 'optimizerD': optimizerD}
        logger.save_image(info)
        logger.save_ckpt(info)
        logger.calculate_fid(info)
        logger.step()


def build_parser():
    """Return the command-line parser for DIOTM training.

    Note: options registered with ``action='store_false'`` default to True and are
    switched OFF by passing the flag (e.g. ``--fid`` disables FID monitoring).
    """
    parser = argparse.ArgumentParser('DIOTM Parameters')

    # Experiment description
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--exp', default='temp', help='name of the experiment')
    parser.add_argument('--resume', action='store_true', default=False, help='Resume training or not')
    parser.add_argument('--problem_name', default='uniform_to_cifar10', choices=[
                                                                 'uniform_to_cifar10',
                                                                 'uniform_to_celeba64',
                                                                 'uniform_to_celeba_256',
                                                                 'male_to_female',
                                                                 'handbags_to_shoes',
                                                                 'cat_to_dog',
                                                                 'wild_to_cat'
                                                                 ], help='name of dataset')
    parser.add_argument('--image_size', type=int, default=32, help='size of image (or data)')
    parser.add_argument('--num_channels', type=int, default=3, help='channel of image')

    # Network configurations
    parser.add_argument('--model_name', default='ncsnpp', choices=['ncsnpp', 'ddpm'], help='Choose default model')
    parser.add_argument('--centered', action='store_false', default=True, help='-1,1 scale (enabled by default; pass this flag to disable)')
    parser.add_argument('--num_channels_dae', type=int, default=128, help='number of initial channels in denoising model')
    parser.add_argument('--n_mlp', type=int, default=4, help='number of mlp layers for z')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1, 2, 2, 2], help='channel multiplier')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    # nargs/type so CLI values become ints (previously a single unparsed string)
    parser.add_argument('--attn_resolutions', nargs='+', type=int, default=(16,), help='resolution(s) at which attention is applied')
    parser.add_argument('--dropout', type=float, default=0., help='drop-out rate')
    parser.add_argument('--resamp_with_conv', action='store_false', default=True, help='always up/down sampling with conv (enabled by default; pass this flag to disable)')
    parser.add_argument('--fir', action='store_false', default=True, help='FIR (enabled by default; pass this flag to disable)')
    parser.add_argument('--fir_kernel', nargs='+', type=int, default=[1, 3, 3, 1], help='FIR kernel')
    parser.add_argument('--skip_rescale', action='store_false', default=True, help='skip rescale (enabled by default; pass this flag to disable)')
    parser.add_argument('--resblock_type', default='biggan', help='type of resnet block, choice in biggan and ddpm')
    parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'], help='progressive type for output')
    parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'], help='progressive type for input')
    parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'], help='progressive combine method.')
    parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'], help='type of time embedding')
    parser.add_argument('--fourier_scale', type=float, default=16., help='scale of fourier transform')
    parser.add_argument('--not_use_tanh', action='store_true', default=False, help='do not use tanh for last layer')
    parser.add_argument('--z_emb_dim', type=int, default=256, help='embedding dimension of z')
    parser.add_argument('--t_emb_dim', type=int, default=256, help='embedding dimension of t')
    parser.add_argument('--nz', type=int, default=100, help='latent dimension')
    parser.add_argument('--ngf', type=int, default=64, help='The default number of channels of model')

    # Training/Optimizer configurations
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--num_iterations', type=int, default=120000, help='the number of iterations')
    parser.add_argument('--lr_g', type=float, default=1.0e-4, help='learning rate g')
    parser.add_argument('--lr_d', type=float, default=1.0e-4, help='learning rate d')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    parser.add_argument('--use_ema', action='store_true', default=False, help='use EMA or not')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--lr_scheduler', action='store_true', default=False, help='Use lr scheduler or not. We use cosine scheduler if the argument is activated.')
    parser.add_argument('--eta_min', type=float, default=1e-5, help='eta_min of lr_scheduler')
    parser.add_argument('--K_v', type=int, default=1, help='potential updates per iteration (>= 1)')
    parser.add_argument('--K_T', type=int, default=1, help='generator updates per iteration (>= 1)')
    parser.add_argument('--Dclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')
    parser.add_argument('--Gclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')

    # (ADD) Important Hyperparameters
    parser.add_argument('--time_sample', type=str, default='uniform', choices=['uniform', 'beta51', 'beta22', 't0.5', 't0.8', 't0.9', 't1', 'discrete_2', 'discrete_4', 'discrete_8'], help='sampling t method')
    parser.add_argument('--lmbda', type=float, default=0, help='regularization parameter (HJB)')
    parser.add_argument('--tau', type=float, default=0.001, help='cost coefficient')
    parser.add_argument('--phi1', type=str, default='linear', choices=['linear', 'kl', 'chi', 'softplus'])
    parser.add_argument('--phi2', type=str, default='linear', choices=['linear', 'kl', 'chi', 'softplus'])
    parser.add_argument('--init_num_iterations', type=int, default=0)

    # Visualize/Save configurations
    parser.add_argument('--print_every', type=int, default=10, help='print current loss for every x iterations')
    parser.add_argument('--save_content_every', type=int, default=10000, help='save content for resuming every x iterations')
    parser.add_argument('--save_ckpt_every', type=int, default=10000, help='save ckpt every x iterations')
    parser.add_argument('--save_image_every', type=int, default=1000, help='save images every x iterations')
    parser.add_argument('--fid_every', type=int, default=10000, help='monitor FID every x iterations')
    parser.add_argument('--fid', action='store_false', default=True, help='Calculate FID (enabled by default; pass this flag to disable)')
    return parser


def parse_args(argv=None):
    """Parse `argv` (``sys.argv[1:]`` when None) and derive ``args.size``.

    ``args.size`` is ``[num_channels, image_size, image_size]`` so that ``--num_channels``
    is honoured instead of assuming 3 channels.
    """
    args = build_parser().parse_args(argv)
    args.size = [args.num_channels, args.image_size, args.image_size]
    return args


if __name__ == '__main__':
    train(parse_args())