import argparse
import torch
import os
import json
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision
from datetime import datetime
from models import get_model
from utils import *
from data.dataset import build_boundary_distribution


def _setup_network(net, lr, args, device, with_ema):
    """Place ``net`` on ``device`` and create its optimizer and optional LR scheduler.

    Args:
        net: Module returned by ``get_model``.
        lr: Learning rate passed to Adam.
        args: Parsed options; reads ``beta1``, ``beta2``, ``use_ema``,
            ``ema_decay``, ``lr_scheduler``, ``num_iterations`` and ``eta_min``.
        device: Device string the module is moved to.
        with_ema: When true, ``args.use_ema`` wraps the optimizer in ``EMA``.

    Returns:
        ``(net, optimizer, scheduler)``. ``net`` is wrapped in
        ``nn.DataParallel``; ``scheduler`` is ``None`` unless
        ``args.lr_scheduler`` is set.
    """
    net = net.to(device)
    # The optimizer is built on the bare module's parameters, before the
    # DataParallel wrapper is applied.
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(args.beta1, args.beta2))
    if with_ema and args.use_ema:
        optimizer = EMA(optimizer, ema_decay=args.ema_decay)
    scheduler = None
    if args.lr_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.num_iterations, eta_min=args.eta_min)
    return nn.DataParallel(net), optimizer, scheduler


def _transport_cost(weight, start, end, tau, batch_size):
    """Return ``tau * weight * ||end - start||^2``, squared norm taken per sample."""
    sq_dist = torch.sum((end - start).reshape(batch_size, -1) ** 2, dim=1)
    return tau * weight * sq_dist


def _hjb_residual(V, t, x, tau, batch_size):
    """Return ``mean(|2 * tau * dV/dt + 0.5 * ||dV/dx||^2|)``.

    ``t`` and ``x`` must require grad and ``V`` must have been computed from
    them. ``create_graph=True`` keeps the derivative graph so the residual can
    itself be back-propagated.
    """
    V_t = torch.autograd.grad(V.sum(), t, create_graph=True)[0]
    V_x = torch.autograd.grad(V.sum(), x, create_graph=True)[0].reshape(batch_size, -1)
    residual = 2 * tau * V_t + 0.5 * torch.norm(V_x, dim=1) ** 2
    return residual.abs().mean()


def _save_samples(path, netG1, source_dataset, target_dataset,
                  batch_size, nz, device, num_samples=5000):
    """Write target samples, ``netG1`` outputs and their source inputs to ``path``.

    Roughly ``num_samples`` points are drawn in whole batches of ``batch_size``.
    """
    # Imported here so the dump does not depend on ``np`` leaking in through
    # ``from utils import *``.
    import numpy as np

    # At least one batch, otherwise np.concatenate fails on an empty list
    # whenever batch_size exceeds num_samples.
    num_batches = max(1, num_samples // batch_size)

    real_samples = np.concatenate([
        target_dataset.sample().detach().cpu().numpy() for _ in range(num_batches)
    ])

    fake_samples = []
    noises = []
    with torch.no_grad():
        for _ in range(num_batches):
            noise = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            fake_samples.append(netG1(noise, latent_z).cpu().numpy())
            noises.append(noise.cpu().numpy())
    fake_samples = np.concatenate(fake_samples)
    noises = np.concatenate(noises)

    # NOTE(review): the dict is passed positionally, so NumPy stores it under
    # the default key ``arr_0`` as a pickled object array and readers need
    # ``allow_pickle=True``. Kept as-is to preserve the on-disk format; confirm
    # against the consumers before switching to keyword arrays.
    np.savez(path, {'real': real_samples, 'fake': fake_samples, 'noise': noises})


def train(args):
    """Run the alternating potential/generator training loop.

    Builds one potential network (``netD``) and two generators from
    ``get_model`` (``netG1`` is applied to source samples, ``netG2`` to target
    samples). Each iteration performs ``args.K_v`` potential updates followed
    by ``args.K_T`` generator updates, steps the LR schedulers when enabled,
    and logs/checkpoints through ``Logger``. Every ``args.save_image_every``
    iterations a sample dump ``iter_<n>.npz`` is written to ``logger.exp_path``.

    Args:
        args: Parsed options. This function also writes ``args.nz``
            (set to ``args.image_size``), ``args.device`` and ``args.train``.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    batch_size = args.batch_size
    nz = args.nz = args.image_size

    # Fall back to CPU instead of crashing on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    args.device = device

    # Get Networks/Optimizer
    netD, netG1 = get_model(args)
    _, netG2 = get_model(args)

    netG1, optimizerG1, schedulerG1 = _setup_network(netG1, args.lr_g, args, device, with_ema=True)
    netG2, optimizerG2, schedulerG2 = _setup_network(netG2, args.lr_g, args, device, with_ema=True)
    netD, optimizerD, schedulerD = _setup_network(netD, args.lr_d, args, device, with_ema=False)

    # Get Data
    args.train = True
    source_dataset, target_dataset = build_boundary_distribution(args)

    # Sampler
    sampler = Sampler(args)
    phi1, phi2 = select_phi(args.phi1), select_phi(args.phi2)

    # Get Evaluation tools
    evaltool = EvalAdapted(args)
    logger = Logger(args, evaltool)

    # Start training
    start = datetime.now()
    for iteration in range(args.num_iterations):
        #### Update potential ####
        for p in netD.parameters():
            p.requires_grad = True

        for _ in range(args.K_v):
            netD.zero_grad()

            # Source side: the generator output is treated as fixed data here.
            with torch.no_grad():
                x0 = source_dataset.sample().to(device)
                t = sampler.sample_t(batch_size).to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                x1 = netG1(x0, latent_z)
                xt = sampler(t, x0, x1)

            # Needed to differentiate the potential w.r.t. time and position.
            t.requires_grad = True
            xt.requires_grad = True

            Vx = netD(t, xt)
            cost = _transport_cost(t, x0, x1, args.tau, batch_size)
            errDx = phi1(Vx - cost).mean()

            HJBx = _hjb_residual(Vx, t, xt, args.tau, batch_size)
            errDx = errDx + args.lmbda * HJBx

            errDx.backward()

            # Target side: gradients accumulate on top of the source-side ones
            # before the single optimizer step below.
            with torch.no_grad():
                y1 = target_dataset.sample().to(device)
                t = sampler.sample_t(batch_size).to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                y0 = netG2(y1, latent_z)
                yt = sampler(t, y0, y1)

            t.requires_grad = True
            yt.requires_grad = True

            cost = _transport_cost(1 - t, y0, y1, args.tau, batch_size)
            Vy = netD(t, yt)
            errDy = phi2(- Vy - cost).mean()

            HJBy = _hjb_residual(Vy, t, yt, args.tau, batch_size)
            errDy = errDy + args.lmbda * HJBy

            errDy.backward()

            optimizerD.step()

        #### Update Generator ####
        # Freeze the potential so only generator parameters receive gradients.
        for p in netD.parameters():
            p.requires_grad = False

        for _ in range(args.K_T):
            netG1.zero_grad()

            t = sampler.sample_t(batch_size).to(device)
            x0 = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            x1 = netG1(x0, latent_z)
            xt = sampler(t, x0, x1)

            Vx = netD(t, xt)
            cost = _transport_cost(t, x0, x1, args.tau, batch_size)
            errG1 = cost.mean() - Vx.mean()

            errG1.backward()
            optimizerG1.step()

            netG2.zero_grad()

            t = sampler.sample_t(batch_size).to(device)
            y1 = target_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            y0 = netG2(y1, latent_z)
            yt = sampler(t, y0, y1)

            Vy = netD(t, yt)
            cost = _transport_cost(1 - t, y0, y1, args.tau, batch_size)
            errG2 = cost.mean() + Vy.mean()

            errG2.backward()
            optimizerG2.step()

        #### Update Schedulers
        if args.lr_scheduler:
            schedulerG1.step()
            schedulerG2.step()
            schedulerD.step()

        #### Visualizations and Save ####
        # The logged losses are those of the last inner update of each loop.
        log = f'Iteration {iteration + 1:07d} : G1 Loss {errG1.item():.4f}, G2 Loss {errG2.item():.4f}, Dx Loss {errDx.item():.4f}, Dy Loss {errDy.item():.4f}, HJBx Loss {HJBx.item():.4f}, HJBy Loss {HJBy.item():.4f}, Elapsed {datetime.now() - start}'
        logger(log)
        info = {'netG1': netG1, 'netG2': netG2, 'netD': netD, 'optimizerG1': optimizerG1, 'optimizerG2': optimizerG2, 'optimizerD': optimizerD}
        logger.save_ckpt(info)
        logger.step()

        # save generated samples
        if (iteration + 1) % args.save_image_every == 0:
            _save_samples(
                os.path.join(logger.exp_path, 'iter_{}.npz'.format(iteration + 1)),
                netG1, source_dataset, target_dataset, batch_size, nz, device)


def build_arg_parser():
    """Build the command-line parser for the DIOTM training script.

    Returns:
        An ``argparse.ArgumentParser`` with every option and default of the
        script registered.
    """
    parser = argparse.ArgumentParser('DIOTM Parameters')

    # Experiment description
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--exp', default='temp', help='name of the experiment')
    parser.add_argument('--problem_name', default='moon_to_spiral', choices=[
                                                                 'moon_to_spiral',
                                                                 'spiral_to_moon',
                                                                 'gaussian_to_8gaussian',
                                                                 '8gaussian_to_gaussian',
                                                                 'gaussian_to_25gaussian',
                                                                 '25gaussian_to_gaussian',
                                                                 'gaussian_to_twocircles',
                                                                 'twocircles_to_gaussian'
                                                                 ], help='name of dataset')

    # Hyperparameters
    parser.add_argument('--num_iterations', type=int, default=120000, help='the number of iterations')
    parser.add_argument('--K_v', type=int, default=1)
    parser.add_argument('--K_T', type=int, default=1)
    parser.add_argument('--image_size', type=int, default=2, help='size of image (or data)')
    parser.add_argument('--time_sample', type=str, default='uniform', choices=['uniform', 'beta51', 'beta22', 't0.5', 't0.8', 't0.9', 't1', 'discrete_2', 'discrete_4', 'discrete_8'], help='sampling t method')
    parser.add_argument('--lmbda', type=float, default=0, help='regularization parameter (HJB)')
    parser.add_argument('--tau', type=float, default=0.1, help='cost coefficient')

    # Visualize/Save configurations
    # NOTE(review): the "epochs" wording on the options consumed outside this
    # file probably means iterations as well (the training loop only counts
    # iterations) -- confirm against Logger before rewording them.
    parser.add_argument('--print_every', type=int, default=10, help='print current loss for every x iterations')
    parser.add_argument('--save_content_every', type=int, default=10000, help='save content for resuming every x epochs')
    parser.add_argument('--save_ckpt_every', type=int, default=10000, help='save ckpt every x epochs')
    # Compared against the iteration counter in train(), hence "iterations".
    parser.add_argument('--save_image_every', type=int, default=5000, help='save images every x iterations')

    # fixed
    parser.add_argument('--batch_size', type=int, default=200, help='input batch size')
    parser.add_argument('--model_name', type=str, default='toy')
    parser.add_argument('--lr_g', type=float, default=1.0e-4, help='learning rate g')
    parser.add_argument('--lr_d', type=float, default=1.0e-4, help='learning rate d')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    parser.add_argument('--use_ema', action='store_true', default=False, help='use EMA or not')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--lr_scheduler', action='store_true', default=False, help='Use lr scheduler or not. We use cosine scheduler if the argument is activated.')
    parser.add_argument('--eta_min', type=float, default=1e-5, help='eta_min of lr_scheduler')
    parser.add_argument('--fid_every', type=int, default=10000, help='monitor FID every x epochs')
    # store_false with default=True: FID is on unless the flag is given.
    parser.add_argument('--fid', action='store_false', default=True, help='Pass this flag to disable FID calculation (enabled by default)')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # Settings that are fixed for this script rather than exposed as options.
    args.size = [args.image_size]
    args.phi1 = args.phi2 = 'linear'

    train(args)

