import argparse
import torch
import os
import json
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision
from datetime import datetime
from models import get_model
from utils import *
from dataset import build_boundary_distribution
from pytorch_fid.fid_score import calculate_fid_given_paths

def _hjb_terms(netD, y_t, y_tph, t, tph, sigma, backward):
    """Evaluate the potential-derived terms shared by both training steps.

    Args:
        netD: Potential network, called as ``netD(y, t)``.
        y_t: Samples paired with time ``t``.
        y_tph: Samples paired with time ``tph``.
        t: Time values for ``y_t``.
        tph: Time values for ``y_tph``.
        sigma: Diffusion coefficient; the second-order term is only
            estimated when it is positive.
        backward: If true, spatial derivatives are taken at ``(y_tph, tph)``
            instead of ``(y_t, t)``.

    Returns:
        A tuple ``(dVdt, grad_sq, V_xx)`` where ``dVdt`` is the finite
        difference ``(V(y_tph, tph) - V(y_t, t)) / (tph - t)``, ``grad_sq`` is
        the squared spatial gradient summed over every non-batch dimension,
        and ``V_xx`` is the Hutchinson-style second-order estimate (``None``
        when ``sigma <= 0``).

    The tensor that is differentiated (``y_tph`` if ``backward`` else ``y_t``)
    must already be part of an autograd graph; the caller is responsible for
    that.
    """
    V_t = netD(y_t, t)
    V_tph = netD(y_tph, tph)
    dVdt = (V_tph - V_t) / (tph - t)

    # Choose the point at which spatial derivatives are evaluated.
    if backward:
        V_at, y_at = V_tph, y_tph
    else:
        V_at, y_at = V_t, y_t

    # create_graph=True keeps the derivative differentiable so the loss that
    # uses it can itself be back-propagated.
    dVdx = torch.autograd.grad(V_at.sum(), y_at, create_graph=True)[0]
    non_batch_dims = tuple(range(1, len(y_t.shape)))

    V_xx = None
    if sigma > 0:
        # Hutchinson skilling with Rademacher sampling: epsilon takes values in {-1, +1}.
        epsilon = torch.randint_like(dVdx, low=0, high=2, device=dVdx.device, requires_grad=False).float() * 2 - 1.
        dVdx_eps = torch.sum(dVdx * epsilon)
        grad_dVdx_eps = torch.autograd.grad(dVdx_eps, y_at, create_graph=True)[0]
        V_xx = torch.sum(grad_dVdx_eps * epsilon, dim=non_batch_dims)

    grad_sq = (dVdx**2).sum(non_batch_dims)
    return dVdt, grad_sq, V_xx


def train(args):
    """Alternately train the potential network (netD) and the generator (netG).

    Each iteration runs ``args.K_v`` potential updates followed by
    ``args.K_T`` generator updates, then optionally steps the cosine
    learning-rate schedulers. Periodically it appends losses to ``log.txt``,
    saves a resumable ``content.pth``, saves ``netG_{n}.pth`` /
    ``netD_{n}.pth`` checkpoints, writes sample image grids and, when enabled,
    generates images and logs an FID score.

    Args:
        args: Parsed command-line namespace. This function also mutates it:
            ``args.device`` is set, ``args.train`` is toggled while building
            the datasets, and ``args.fid`` may be switched off when no
            reference statistics are known for ``args.problem_name``.

    Returns:
        None. All results are written under
        ``./train_logs/<problem_name>/<exp>``.
    """
    # Seed the CPU and CUDA random number generators.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    batch_size = args.batch_size
    nz = args.nz

    # Create log path
    exp = args.exp
    parent_dir = "./train_logs/{}".format(args.problem_name)
    exp_path = os.path.join(parent_dir, exp)
    os.makedirs(exp_path, exist_ok=True)

    # Written before args.device is attached below; a torch.device would not
    # be JSON-serializable.
    jsonstr = json.dumps(args.__dict__, indent=4)
    with open(os.path.join(exp_path, 'config.json'), 'w') as f:
        f.write(jsonstr)

    # Make log file. When resuming, append so the log of the earlier run is
    # not truncated.
    log_path = os.path.join(exp_path, 'log.txt')
    with open(log_path, 'a' if args.resume else 'w') as f:
        f.write("Start Training")
        f.write('\n')

    device = torch.device('cuda:0')
    args.device = device

    # Get Networks/Optimizer
    netD, netG = get_model(args)

    netG = netG.to(device)
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas=(args.beta1, args.beta2))

    if args.use_ema:
        # EMA comes from utils; it wraps the optimizer.
        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)

    if args.lr_scheduler:
        schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_iterations, eta_min=args.eta_min)

    netG = nn.DataParallel(netG)

    netD = netD.to(device)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas=(args.beta1, args.beta2))

    if args.lr_scheduler:
        schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_iterations, eta_min=args.eta_min)

    netD = nn.DataParallel(netD)

    # Resume
    if args.resume:
        checkpoint_file = os.path.join(exp_path, 'content.pth')
        # NOTE(review): the checkpoint stores the argparse Namespace under
        # 'args'; on PyTorch versions where torch.load defaults to
        # weights_only=True this call may need weights_only=False - confirm
        # against the installed version.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        init_iteration = checkpoint['iteration']
        netG.load_state_dict(checkpoint['netG_dict'])
        # load G
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        if args.lr_scheduler:
            schedulerG.load_state_dict(checkpoint['schedulerG'])
        # load D
        netD.load_state_dict(checkpoint['netD_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        if args.lr_scheduler:
            schedulerD.load_state_dict(checkpoint['schedulerD'])
        print("=> loaded checkpoint (iteration {})".format(checkpoint['iteration']))
    else:
        init_iteration = 0

    # Get Data: args.train selects which split build_boundary_distribution builds.
    args.train = True
    source_dataset, target_dataset = build_boundary_distribution(args)
    args.train = False
    source_test_dataset, target_test_dataset = build_boundary_distribution(args)

    # If image, get FID path & statistics
    if args.fid:
        FID_img_path = os.path.join(exp_path, 'generated_samples')
        os.makedirs(FID_img_path, exist_ok=True)
        if args.problem_name == 'cifar10':
            real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
        elif args.problem_name == 'celeba_256':
            real_img_dir = 'pytorch_fid/celeba_256_stat.npy'
        elif args.problem_name == 'male_to_female':
            real_img_dir = 'pytorch_fid/celeba64_female.npy'
        elif args.problem_name == 'handbags_to_shoes':
            # A statistics path is set, but FID is still switched off for this problem.
            real_img_dir = 'pytorch_fid/shoes_64.npy'
            args.fid = False
        else:
            # No reference statistics known for this problem: disable FID.
            args.fid = False

    # Get f-divergence
    phi = select_phi(args.phi2)

    # Get sampler
    sampler = Sampler(args)

    # Start training
    start = datetime.now()
    for iteration in range(init_iteration, args.num_iterations):
        #### Update potential ####
        for p in netD.parameters():
            p.requires_grad = True

        for _ in range(args.K_v):

            netD.zero_grad()

            # get x, y, y_t
            source_data = source_dataset.sample().to(device)
            target_data = target_dataset.sample().to(device)

            # The generator is held fixed during the potential update.
            with torch.no_grad():
                latent_z = torch.randn(batch_size, nz, device=device)
                generated_data = netG(source_data, latent_z)
                y_t, y_tph, t, tph = sampler.sample_pair(source_data, generated_data)
                t1 = torch.ones_like(t)

            # bdy loss
            Dloss_bdy = phi(netD(target_data, t1)).mean() - netD(generated_data, t1).mean()
            Dloss_bdy.backward()

            # running loss: the differentiated point was created under
            # no_grad, so enable gradients on it explicitly.
            if args.backward:
                y_tph.requires_grad = True
            else:
                y_t.requires_grad = True

            dVdt, grad_sq, V_xx = _hjb_terms(netD, y_t, y_tph, t, tph, args.sigma, args.backward)
            norm_sq = (0.5 * args.alpha - args.reg) * grad_sq

            if args.sigma > 0:
                Dloss_running = args.Dlmbda * (dVdt + 0.5 * args.sigma**2 * V_xx - norm_sq)**2 - norm_sq
            else:
                Dloss_running = args.Dlmbda * (dVdt - norm_sq)**2 - norm_sq

            # Unlike the generator loss, this is always a plain mean with no
            # time-sampling reweighting.
            Dloss_running = Dloss_running.mean()

            # Gradients from the boundary and running losses accumulate
            # before the single optimizer step.
            Dloss_running.backward()

            if args.Dclip > 0:
                nn.utils.clip_grad_norm_(netD.parameters(), args.Dclip)

            optimizerD.step()

        #### Update Generator ####
        for p in netD.parameters():
            p.requires_grad = False

        for _ in range(args.K_T):
            netG.zero_grad()

            source_data = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            generated_data = netG(source_data, latent_z)
            y_t, y_tph, t, tph = sampler.sample_pair(source_data, generated_data)

            dVdt, grad_sq, V_xx = _hjb_terms(netD, y_t, y_tph, t, tph, args.sigma, args.backward)
            norm_sq = 0.5 * args.alpha * grad_sq

            if args.sigma > 0:
                errG = args.Glmbda * (dVdt - norm_sq + 0.5 * args.sigma**2 * V_xx)
            else:
                errG = args.Glmbda * (dVdt - norm_sq)

            if args.use_mean:
                errG = errG.mean()
            else:
                # Reweight each sample by the inverse of
                # num_timesteps * sampler.retrieve_prob_t(t) before averaging.
                errG = 1 / (args.num_timesteps * sampler.retrieve_prob_t(t)) * errG
                errG = errG.mean()

            errG.backward()

            if args.Gclip > 0:
                nn.utils.clip_grad_norm_(netG.parameters(), args.Gclip)

            optimizerG.step()

        #### Update Schedulers
        if args.lr_scheduler:
            schedulerG.step()
            schedulerD.step()

        #### Visualizations and Save ####
        ## save losses
        if (iteration + 1) % args.print_every == 0:
            with open(log_path, 'a') as f:
                f.write(f'Iteration {iteration + 1:07d} : G Loss {errG.item():.4f}, D Loss bdy {Dloss_bdy.item():.4f}, D Loss running {Dloss_running.item():.4f}  Elapsed {datetime.now() - start}')
                f.write('\n')

        # save content (everything needed by the resume branch above)
        if (iteration + 1) % args.save_content_every == 0:
            print('Saving content.')
            content = {'iteration': iteration + 1, 'args': args,
                       'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                       'netD_dict': netD.state_dict(), 'optimizerD': optimizerD.state_dict()}
            if args.lr_scheduler:
                content['schedulerG'] = schedulerG.state_dict()
                content['schedulerD'] = schedulerD.state_dict()

            torch.save(content, os.path.join(exp_path, 'content.pth'))

        # save checkpoint
        if (iteration + 1) % args.save_ckpt_every == 0:
            # Presumably swaps the live generator weights with their EMA copy
            # (EMA is defined in utils - verify); the second call swaps back.
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
            torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(iteration + 1)))
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)

            torch.save(netD.state_dict(), os.path.join(exp_path, 'netD_{}.pth'.format(iteration + 1)))

        # save generated images
        if (iteration + 1) % args.save_image_every == 0:
            with torch.no_grad():
                source_data = source_test_dataset.sample().to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                images = netG(source_data, latent_z)
                # Affine map x -> 0.5 * (x + 1) before saving.
                images = (0.5*(images+1)).detach().cpu()
                source_data = (0.5*(source_data+1)).detach().cpu()
                torchvision.utils.save_image(images, os.path.join(exp_path, 'iter_{}.png'.format(iteration+1)))
                torchvision.utils.save_image(source_data, os.path.join(exp_path, 'iter_{}_source.png'.format(iteration+1)))

        # calculate fid
        if (iteration + 1) % args.fid_every == 0 and args.fid:
            # use ema model
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)

            iters_needed = 50000 // batch_size

            for i in range(iters_needed):
                with torch.no_grad():
                    source_data = source_test_dataset.sample().to(device)
                    latent_z = torch.randn(batch_size, nz, device=device)
                    fake_sample = netG(source_data, latent_z)
                    fake_sample = (fake_sample + 1.) / 2.

                    # One file per generated sample, numbered consecutively
                    # across batches.
                    for j, x in enumerate(fake_sample):
                        index = i * batch_size + j
                        torchvision.utils.save_image(x, os.path.join(FID_img_path, '{}.jpg'.format(index)))

                    print('generating batch ', i, end='\r')

            paths = [FID_img_path, real_img_dir]

            kwargs = {'batch_size': 100, 'device': device, 'dims': 2048}
            fid = calculate_fid_given_paths(paths=paths, **kwargs)
            print(fid)

            with open(log_path, 'a') as f:
                f.write(f'Iteration {iteration + 1:04d} FID : {fid}')
                f.write('\n')

            # Swap back after evaluation.
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)


if __name__ == '__main__':
    # Command-line entry point: build the argument parser, derive args.size
    # and start training.
    parser = argparse.ArgumentParser('Simulation-free EUOT parameters')

    # Important Hyperparameters
    parser.add_argument('--Gloss', default='value', choices=['rl', 'value', 'adversarial'])
    parser.add_argument('--Dloss', default='hjb', choices=['mf', 'hjb', 'value'])
    parser.add_argument('--Glmbda', type=float, default=1, help='regularization for transport function')
    parser.add_argument('--Dlmbda', type=float, default=1, help='regularization for value function')
    parser.add_argument('--reg', type=float, default=0, help='R1 regularization')
    parser.add_argument('--alpha', type=float, default=1, help='proportion between bdy loss and running loss')

    parser.add_argument('--Dclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')
    parser.add_argument('--Gclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')

    parser.add_argument('--num_timesteps', type=int, default=20, help='The number of timesteps')
    parser.add_argument('--time_sample', type=str, default='uniform', choices=['uniform', 'linear'], help='sampling t method')

    # Experiment description
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--exp', default='temp', help='name of the experiment')
    parser.add_argument('--resume', action='store_true', default=False, help='Resume training or not')
    parser.add_argument('--problem_name', default='cifar10', choices=[
                                                                 'cifar10',
                                                                 'male_to_female',
                                                                 'wild_to_cat'
                                                                 ], help='name of dataset')
    parser.add_argument('--image_size', type=int, default=32, help='size of image (or data)')
    parser.add_argument('--num_channels', type=int, default=3, help='channel of image')

    # Network configurations
    # NOTE: the store_false flags below default to True; passing the flag
    # turns the feature OFF.
    parser.add_argument('--model_name', default='ncsnpp', choices=['ncsnpp'], help='Choose default model')
    parser.add_argument('--centered', action='store_false', default=True, help='-1,1 scale (on by default; pass this flag to disable)')
    parser.add_argument('--num_channels_dae', type=int, default=128, help='number of initial channels in denoising model')
    parser.add_argument('--n_mlp', type=int, default=4, help='number of mlp layers for z')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1,2,2,2], help='channel multiplier')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    # Parsed as a list of ints so a command-line value has the same element
    # type as the default instead of arriving as a single string.
    parser.add_argument('--attn_resolutions', nargs='+', type=int, default=(16,), help='resolution of applying attention')
    parser.add_argument('--dropout', type=float, default=0., help='drop-out rate')
    parser.add_argument('--resamp_with_conv', action='store_false', default=True, help='always up/down sampling with conv (on by default; pass this flag to disable)')
    parser.add_argument('--fir', action='store_false', default=True, help='FIR (on by default; pass this flag to disable)')
    parser.add_argument('--fir_kernel', nargs='+', type=int, default=[1, 3, 3, 1], help='FIR kernel')
    parser.add_argument('--skip_rescale', action='store_false', default=True, help='skip rescale (on by default; pass this flag to disable)')
    parser.add_argument('--resblock_type', default='biggan', help='type of resnet block, choice in biggan and ddpm')
    parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'], help='progressive type for output')
    parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'], help='progressive type for input')
    parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'], help='progressive combine method.')
    parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'], help='type of time embedding')
    parser.add_argument('--fourier_scale', type=float, default=16., help='scale of fourier transform')
    parser.add_argument('--not_use_tanh', action='store_true', default=False, help='do not use tanh for last layer')
    parser.add_argument('--z_emb_dim', type=int, default=256, help='embedding dimension of z')
    parser.add_argument('--t_emb_dim', type=int, default=256, help='embedding dimension of t')
    parser.add_argument('--nz', type=int, default=100, help='latent dimension')
    parser.add_argument('--ngf', type=int, default=64, help='The default number of channels of model')

    # Training/Optimizer configurations
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--num_iterations', type=int, default=120000, help='the number of iterations')
    parser.add_argument('--lr_g', type=float, default=1.0e-4, help='learning rate g')
    parser.add_argument('--lr_d', type=float, default=1.0e-4, help='learning rate d')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    parser.add_argument('--use_ema', action='store_true', default=False, help='use EMA or not')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--lr_scheduler', action='store_true', default=False, help='Use lr scheduler or not. We use cosine scheduler if the argument is activated.')
    parser.add_argument('--eta_min', type=float, default=1e-5, help='eta_min of lr_scheduler')
    parser.add_argument('--K_v', type=int, default=1, help='number of potential (netD) updates per iteration')
    parser.add_argument('--K_T', type=int, default=1, help='number of generator (netG) updates per iteration')

    # Loss configurations
    parser.add_argument('--phi2', type=str, default='softplus', choices=['linear', 'kl', 'softplus'], help='Set g2')
    parser.add_argument('--sigma', type=float, default=0.1, help='diffusion term')
    parser.add_argument('--use_mean', action='store_true', default=False, help='If activated, just use mean without considering sample_t')
    parser.add_argument('--backward', action='store_true', default=False, help='take spatial derivatives of the potential at the later time point (y_tph, tph) instead of (y_t, t)')

    # Visualize/Save configurations (all periods are counted in training iterations)
    parser.add_argument('--print_every', type=int, default=10, help='print current loss for every x iterations')
    parser.add_argument('--save_content_every', type=int, default=10000, help='save content for resuming every x iterations')
    parser.add_argument('--save_ckpt_every', type=int, default=10000, help='save ckpt every x iterations')
    parser.add_argument('--save_image_every', type=int, default=1000, help='save images every x iterations')
    parser.add_argument('--fid_every', type=int, default=10000, help='monitor FID every x iterations')
    # store_false: FID is computed by default and this flag switches it off.
    parser.add_argument('--fid', action='store_false', default=True, help="Disable FID calculation (FID is calculated by default)")
    args = parser.parse_args()

    # Data shape as [channels, height, width]; follows --num_channels
    # (default 3) instead of a hard-coded 3.
    args.size = [args.num_channels, args.image_size, args.image_size]

    train(args)