import argparse
import json
import os
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from models import get_model
from utils import *
from dataset import build_boundary_distribution

def _hjb_terms(netD, y_t, y_tph, t, tph, args):
    """Evaluate the potential on a time pair and return its HJB ingredients.

    Args:
        netD: potential network, called as ``netD(y, time)``.
        y_t, y_tph: samples at times ``t`` and ``tph`` (from
            ``sampler.sample_pair``). Spatial derivatives are taken w.r.t.
            ``y_tph`` when ``args.backward`` is set and w.r.t. ``y_t``
            otherwise, so that tensor must be part of an autograd graph.
        t, tph: the two time inputs passed to ``netD``.
        args: parsed options; reads ``backward``, ``sigma`` and ``alpha``.

    Returns:
        ``(dVdt, norm_sq, V_xx)``:
        - ``dVdt`` is the finite difference ``(V(y_tph, tph) - V(y_t, t)) / (tph - t)``.
        - ``norm_sq`` is ``0.5 * alpha * |dV/dx|^2`` summed over all non-batch dims.
        - ``V_xx`` is a Hutchinson estimate of the Laplacian, or ``None`` when
          ``args.sigma <= 0``.
        All derivatives are built with ``create_graph=True`` so they can be
        backpropagated.
    """
    V_t = netD(y_t, t)
    V_tph = netD(y_tph, tph)
    dVdt = (V_tph - V_t) / (tph - t)

    # Differentiate w.r.t. the endpoint selected by --backward.
    if args.backward:
        x, V_x = y_tph, V_tph
    else:
        x, V_x = y_t, V_t
    dVdx = torch.autograd.grad(V_x.sum(), x, create_graph=True)[0]

    # Reduce over every dimension except the leading (batch) one.
    non_batch_dims = tuple(range(1, len(y_t.shape)))

    V_xx = None
    if args.sigma > 0:
        # Hutchinson-Skilling trace estimator with Rademacher (+/-1) probes:
        # eps^T (d^2V/dx^2) eps is an unbiased estimate of the Laplacian.
        epsilon = torch.randint_like(dVdx, low=0, high=2, device=dVdx.device, requires_grad=False).float() * 2 - 1.
        dVdx_eps = torch.sum(dVdx * epsilon)
        grad_dVdx_eps = torch.autograd.grad(dVdx_eps, x, create_graph=True)[0]
        V_xx = torch.sum(grad_dVdx_eps * epsilon, dim=non_batch_dims)

    norm_sq = 0.5 * args.alpha * (dVdx**2).sum(non_batch_dims)
    # NOTE(review): norm_sq / V_xx are reduced to the batch dim only, while dVdt
    # has netD's output shape; if netD returns (batch, 1) the later sums
    # broadcast to (batch, batch) -- confirm netD's output shape.
    return dVdt, norm_sq, V_xx


def _save_samples(path, source_dataset, target_dataset, netG, device, batch_size, nz, num_samples=5000):
    """Dump target, generated and source samples to ``path`` as an ``.npz``.

    Draws ``num_samples // batch_size`` batches of each. The file holds the
    named arrays ``real`` (target samples), ``fake`` (``netG`` outputs) and
    ``noise`` (the source samples fed to ``netG``).
    """
    num_batches = num_samples // batch_size
    real = np.concatenate([target_dataset.sample().cpu().numpy() for _ in range(num_batches)])

    fake, noise = [], []
    with torch.no_grad():
        for _ in range(num_batches):
            x0 = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            fake.append(netG(x0, latent_z).cpu().numpy())
            noise.append(x0.cpu().numpy())

    # Keyword arguments store named arrays. Passing a dict positionally would
    # pickle it as a single 0-d object array named 'arr_0'.
    np.savez(path, real=real, fake=np.concatenate(fake), noise=np.concatenate(noise))


def train(args):
    """Alternately train the potential ``netD`` and the generator ``netG``.

    Each outer iteration performs ``args.K_v`` potential updates (boundary
    loss at t = 1 plus an HJB-residual running loss) followed by
    ``args.K_T`` generator updates.

    Everything is written under ``./train_logs/<problem_name>/<exp>``:
    - ``config.json`` and ``log.txt``;
    - the resumable ``content.pth``;
    - periodic ``netG_<it>.pth`` / ``netD_<it>.pth`` checkpoints;
    - ``iter_<it>.npz`` sample dumps.

    Args:
        args: parsed command-line namespace. It is mutated: ``nz``, ``device``
            and ``train`` are set on it.
    """
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    batch_size = args.batch_size
    nz = args.nz = args.image_size

    # Create log path and save configurations
    exp_path = os.path.join("./train_logs/{}".format(args.problem_name), args.exp)
    os.makedirs(exp_path, exist_ok=True)

    with open(os.path.join(exp_path, 'config.json'), 'w') as f:
        f.write(json.dumps(args.__dict__, indent=4))

    # Make log file; append when resuming so the earlier log is not truncated.
    with open(os.path.join(exp_path, 'log.txt'), 'a' if args.resume else 'w') as f:
        f.write("Start Training")
        f.write('\n')

    device = torch.device('cuda:0')
    args.device = device

    # Get Networks/Optimizers. Optimizers and schedulers are built on the bare
    # modules before they are wrapped in DataParallel.
    netD, netG = get_model(args)

    netG = netG.to(device)
    optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas=(args.beta1, args.beta2))
    if args.use_ema:
        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
    if args.lr_scheduler:
        schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_iterations, eta_min=args.eta_min)
    netG = nn.DataParallel(netG)

    netD = netD.to(device)
    optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas=(args.beta1, args.beta2))
    if args.lr_scheduler:
        schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_iterations, eta_min=args.eta_min)
    netD = nn.DataParallel(netD)

    # Resume
    init_iteration = 0
    if args.resume:
        checkpoint_file = os.path.join(exp_path, 'content.pth')
        # NOTE(review): the checkpoint pickles `args` (an argparse.Namespace).
        # Newer torch versions default torch.load to weights_only=True, which
        # may reject it; pass weights_only=False if loading fails.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        init_iteration = checkpoint['iteration']
        # load G
        netG.load_state_dict(checkpoint['netG_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        if args.lr_scheduler:
            schedulerG.load_state_dict(checkpoint['schedulerG'])
        # load D
        netD.load_state_dict(checkpoint['netD_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        if args.lr_scheduler:
            schedulerD.load_state_dict(checkpoint['schedulerD'])
        print("=> loaded checkpoint (iteration {})".format(checkpoint['iteration']))

    # Get Data
    args.train = True
    source_dataset, target_dataset = build_boundary_distribution(args)

    # Get f-divergence
    phi = select_phi(args.phi2)

    # Get sampler
    sampler = Sampler(args)

    # Start training
    start = datetime.now()
    for iteration in range(init_iteration, args.num_iterations):
        #### Update potential ####
        for p in netD.parameters():
            p.requires_grad = True

        for _ in range(args.K_v):
            netD.zero_grad()

            source_data = source_dataset.sample().to(device)
            target_data = target_dataset.sample().to(device)

            with torch.no_grad():
                latent_z = torch.randn(batch_size, nz, device=device)
                generated_data = netG(source_data, latent_z)
                y_t, y_tph, t, tph = sampler.sample_pair(source_data, generated_data)
                t1 = torch.ones_like(t)

            # bdy loss, evaluated at time 1
            Dloss_bdy = phi(netD(target_data, t1)).mean() - netD(generated_data, t1).mean()
            Dloss_bdy.backward()

            # The pair was drawn under no_grad; enable grad on the endpoint
            # that _hjb_terms differentiates with respect to.
            if args.backward:
                y_tph.requires_grad = True
            else:
                y_t.requires_grad = True

            # running loss: |HJB residual|^p penalty minus the kinetic term
            dVdt, norm_sq, V_xx = _hjb_terms(netD, y_t, y_tph, t, tph, args)
            if args.sigma > 0:
                residual = dVdt + 0.5 * args.sigma**2 * V_xx - norm_sq
            else:
                residual = dVdt - norm_sq
            Dloss_running = (args.Dlmbda * residual.abs()**args.p - norm_sq).mean()
            Dloss_running.backward()

            if args.Dclip > 0:
                nn.utils.clip_grad_norm_(netD.parameters(), args.Dclip)

            optimizerD.step()

        #### Update Generator ####
        # Freeze the potential so only netG accumulates gradients.
        for p in netD.parameters():
            p.requires_grad = False

        for _ in range(args.K_T):
            netG.zero_grad()

            source_data = source_dataset.sample().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            generated_data = netG(source_data, latent_z)
            y_t, y_tph, t, tph = sampler.sample_pair(source_data, generated_data)

            dVdt, norm_sq, V_xx = _hjb_terms(netD, y_t, y_tph, t, tph, args)
            if args.sigma > 0:
                errG = args.Glmbda * (dVdt - norm_sq + 0.5 * args.sigma**2 * V_xx)
            else:
                errG = args.Glmbda * (dVdt - norm_sq)

            if not args.use_mean:
                # Per-sample weight 1 / (num_timesteps * retrieve_prob_t(t));
                # presumably corrects for non-uniform time sampling.
                errG = 1 / (args.num_timesteps * sampler.retrieve_prob_t(t)) * errG
            errG = errG.mean()

            errG.backward()

            if args.Gclip > 0:
                nn.utils.clip_grad_norm_(netG.parameters(), args.Gclip)

            optimizerG.step()

        #### Update Schedulers
        if args.lr_scheduler:
            schedulerG.step()
            schedulerD.step()

        #### Visualizations and Save ####
        step = iteration + 1

        ## save losses
        if step % args.print_every == 0:
            with open(os.path.join(exp_path, 'log.txt'), 'a') as f:
                f.write(f'Iteration {step:07d} : G Loss {errG.item():.4f}, D Loss bdy {Dloss_bdy.item():.4f}, D Loss running {Dloss_running.item():.4f}  Elapsed {datetime.now() - start}')
                f.write('\n')

        # save content (full resumable state)
        if step % args.save_content_every == 0:
            print('Saving content.')
            content = {'iteration': step, 'args': args,
                       'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                       'netD_dict': netD.state_dict(), 'optimizerD': optimizerD.state_dict()}
            if args.lr_scheduler:
                content['schedulerG'] = schedulerG.state_dict()
                content['schedulerD'] = schedulerD.state_dict()
            torch.save(content, os.path.join(exp_path, 'content.pth'))

        # save checkpoint
        if step % args.save_ckpt_every == 0:
            # swap_parameters_with_ema brackets the save; presumably netG_*.pth
            # holds the EMA weights and the second call restores the live ones.
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
            torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(step)))
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)

            torch.save(netD.state_dict(), os.path.join(exp_path, 'netD_{}.pth'.format(step)))

        # save generated samples
        if step % args.save_image_every == 0:
            _save_samples(os.path.join(exp_path, 'iter_{}.npz'.format(step)),
                          source_dataset, target_dataset, netG, device, batch_size, nz)



def build_parser():
    """Build the command-line parser for simulation-free EUOT training.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by ``train``.
    """
    parser = argparse.ArgumentParser('Simulation-free EUOT parameters')

    # Important Hyperparameters
    parser.add_argument('--Gloss', default='value', choices=['rl', 'value', 'adversarial'])
    parser.add_argument('--Dloss', default='hjb', choices=['mf', 'hjb', 'value'])
    parser.add_argument('--Glmbda', type=float, default=1, help='proportion of the cost c')
    parser.add_argument('--Dlmbda', type=float, default=1, help='regularization for value function')
    parser.add_argument('--p', type=float, default=1, help='power of hjb')
    parser.add_argument('--alpha', type=float, default=1, help='weight of the squared-gradient term 0.5 * alpha * |dV/dx|^2')

    parser.add_argument('--Dclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')
    parser.add_argument('--Gclip', type=float, default=0, help='Clip the gradient if the clip value is positive (>0)')

    parser.add_argument('--num_timesteps', type=int, default=20, help='The number of timesteps')
    parser.add_argument('--time_sample', type=str, default='uniform', choices=['uniform', 'linear'], help='sampling t method')

    # Experiment description
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--exp', default='temp', help='name of the experiment')
    parser.add_argument('--resume', action='store_true', default=False, help='Resume training or not')
    parser.add_argument('--problem_name', default='8gaussian', choices=['1d_2gaussian',
                                                                 'checkerboard',
                                                                 '8gaussian',
                                                                 '25gaussian',
                                                                 'moon_to_spiral',
                                                                 'gaussian0.1_to_gaussian-0.1'], help='name of dataset')
    parser.add_argument('--image_size', type=int, default=2, help='size of image (or data)')

    # Network configurations
    parser.add_argument('--model_name', default='toy', choices=['ncsnpp', 'ddpm', 'drunet', 'otm', 'toy'], help='Choose default model')

    # Training/Optimizer configurations
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--num_iterations', type=int, default=120000, help='the number of iterations')
    parser.add_argument('--lr_g', type=float, default=2.0e-4, help='learning rate g')
    parser.add_argument('--lr_d', type=float, default=1.0e-4, help='learning rate d')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    parser.add_argument('--use_ema', action='store_true', default=False, help='use EMA or not')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--lr_scheduler', action='store_true', default=False, help='Use lr scheduler or not. We use cosine scheduler if the argument is activated.')
    parser.add_argument('--eta_min', type=float, default=5e-5, help='eta_min of lr_scheduler')
    parser.add_argument('--K_v', type=int, default=1)
    parser.add_argument('--K_T', type=int, default=1)

    # Loss configurations
    parser.add_argument('--phi2', type=str, default='softplus', choices=['linear', 'kl', 'softplus'], help='Set g2')
    parser.add_argument('--sigma', type=float, default=0.1, help='diffusion term')
    parser.add_argument('--use_mean', action='store_true', default=False, help='If activated, just use mean without considering sample_t')
    parser.add_argument('--backward', action='store_true', default=False)

    # Visualize/Save configurations (all periods are counted in training iterations)
    parser.add_argument('--print_every', type=int, default=2000, help='print current loss for every x iterations')
    parser.add_argument('--save_content_every', type=int, default=2000, help='save content for resuming every x iterations')
    parser.add_argument('--save_ckpt_every', type=int, default=10000, help='save ckpt every x iterations')
    parser.add_argument('--save_image_every', type=int, default=2000, help='save images every x iterations')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    args.size = [args.image_size]

    train(args)