import numpy as np
import torch
import torch.nn as nn
import json
from torch.autograd import grad
import matplotlib.pyplot as plt
import argparse
import network
import pickle
import time

# Floating-point precision and compute device used for every tensor this
# script allocates: CUDA when PyTorch reports it available, otherwise CPU.
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def parse_args(argv=None):
    """Parse the command-line options of a WPONG training run.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the normal script invocation.

    Returns:
        argparse.Namespace with attributes ``config``, ``save_var``,
        ``random_seed``, ``verbose`` and ``net``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/WPONG2d.json", type=str,
                        help="path to the JSON experiment config")
    parser.add_argument("--save_var", default=False, action=argparse.BooleanOptionalAction,
                        help="save network, config, samples and error history")
    parser.add_argument("--random_seed", default=0, type=int,
                        help="seed for NumPy and PyTorch")
    # BooleanOptionalAction also creates --no-verbose to silence progress output.
    parser.add_argument('--verbose', default=True, action=argparse.BooleanOptionalAction,
                        help="print errors at every logging step")
    parser.add_argument("--net", default=None, type=str,
                        help="override config['net_type'] (class name in the network module)")
    return parser.parse_args(argv)

def divergence(y, x, d):
    """Compute the per-sample divergence of ``y`` with respect to ``x``.

    Args:
        y: N x d tensor computed from ``x`` with autograd history.
        x: N x d tensor with ``requires_grad=True``.
        d: number of columns summed over (expected >= 1).

    Returns:
        N x 1 tensor holding ``sum_i dy[:, i] / dx[:, i]``. The graph is kept
        (``create_graph=True``) so the result can be differentiated again.
    """
    # Back-propagating a vector of ones through column i yields, in row n,
    # sum_m d y[m, i] / d x[n, :]. That equals the per-sample derivative only
    # when row n of y depends on row n of x alone (no cross-sample coupling).
    one = torch.ones_like(y[:, 0])  # loop-invariant: allocate once
    div = 0
    for i in range(d):
        grad_y = grad(y[:, i], x, grad_outputs=one, retain_graph=True, create_graph=True)[0]
        div = div + grad_y[:, i:i + 1]  # keep the trailing dim -> N x 1
    return div

def Jacobian(y, x, dy, dx):
    """Compute the per-sample Jacobian of ``y`` with respect to ``x``.

    Args:
        y: N x dy tensor computed from ``x`` with autograd history.
        x: N x dx tensor with ``requires_grad=True``.
        dy: number of output columns of ``y`` to differentiate.
        dx: number of columns of ``x``.

    Returns:
        N x dy x dx tensor with ``J[n, i, j] = d y[n, i] / d x[n, j]``, built
        with ``create_graph=True`` so it stays differentiable. Its dtype and
        device follow the gradients of ``x`` instead of module-level globals.
    """
    if dy == 0:
        return x.new_zeros(y.size(0), 0, dx)
    # A ones vector as grad_outputs gives, in row n, sum_m d y[m, i] / d x[n, :],
    # which is the per-sample row only when samples do not interact.
    one = torch.ones_like(y[:, 0])
    rows = [
        grad(y[:, i], x, grad_outputs=one, retain_graph=True, create_graph=True)[0]  # N x dx
        for i in range(dy)
    ]
    return torch.stack(rows, dim=1)

def sample_x0(N, d):
    """Draw N i.i.d. samples from the d-dimensional standard normal N(0, I).

    Returns an (N, d) tensor created with the module-level ``dtype``/``device``.
    """
    samples = torch.randn((N, d), dtype=dtype, device=device)
    return samples

def sample_x0_unif(N, d):
    """Draw N points uniformly from the unit ball in R^d (the unit disk when d == 2).

    A Gaussian vector normalised to unit length gives a uniform direction.
    Scaling it by U**(1/d), with U ~ Uniform(0, 1), makes P(radius <= s) = s**d,
    which is the radial law of the uniform ball. Returns an (N, d) tensor.
    """
    direction = torch.randn(N, d, dtype=dtype, device=device)
    direction = direction / direction.pow(2).sum(dim=-1, keepdim=True).sqrt()
    radius = torch.rand(N, 1, dtype=dtype, device=device).pow(1 / d)
    return radius * direction

def rho0(x, d):
    """Evaluate the standard normal density N(0, I_d) at every row of ``x``.

    Args:
        x: N x d tensor of points.
        d: dimension used in the normalising constant (2*pi)**(d/2).

    Returns:
        N x 1 tensor of density values.
    """
    sq_norm = (x ** 2).sum(dim=-1, keepdim=True)
    normalizer = (2 * np.pi) ** (d / 2)
    return torch.exp(-sq_norm / 2) / normalizer

def score0(x):
    """Return the score grad(log rho0)(x) of the standard normal initial density.

    For N(0, I) the score is simply -x; the result keeps x's shape and autograd history.
    """
    score = -x
    return score

def _stack_interp(interp_1, interp_2):
    """Combine two scalar interpolators into one vector-valued callable.

    The returned function maps a NumPy array ``x`` to
    ``np.stack([interp_1(x), interp_2(x)], axis=-1)``.
    """
    def field(x):
        return np.stack([interp_1(x), interp_2(x)], axis=-1)
    return field

def _drift_terms(w_net, t, x, z, d):
    """Evaluate the drift at time ``t`` and the terms that advance density and score.

    Returns:
        w: drift ``w_net(t, x)``, same shape as ``x`` (N x d).
        div: divergence of ``w`` w.r.t. ``x`` (N x 1).
        grad_div: gradient of ``div`` w.r.t. ``x`` (N x d).
        Jwz: ``J_w^T z``, with ``J_w`` the Jacobian of ``w`` w.r.t. ``x`` (N x d).
    """
    t_tensor = torch.ones_like(x[:, 0:1]) * t
    w = w_net(t_tensor, x)
    div = divergence(w, x, d)
    # NOTE(review): without create_graph=True, grad_div is detached from the
    # network parameters, so the training loss does not back-propagate through
    # this term of the score update. Confirm this truncation is intentional.
    grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
    jac = Jacobian(w, x, d, d)  # N x d x d
    Jwz = (jac.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)  # N x d
    return w, div, grad_div, Jwz

def _evaluate(w_net, d, Nt, dt, T, w0, wT, scoreT, rhoT_interp, n_valid=10000):
    """Compare the learned drift, score and density with reference solutions.

    Samples ``n_valid`` points uniformly in the unit ball. It first compares
    the drift at t = 0 and t = T with ``w0``/``wT``. It then pushes the points,
    their density (from ``rho0``) and their score (from ``score0``) forward
    with the same explicit Euler scheme used in training, and compares the
    results at the final time with ``rhoT_interp``/``scoreT``.

    Returns:
        (error_score, error_w0, error_wT, error_densityT): mean absolute errors.
    """
    x_unif = sample_x0_unif(n_valid, d)
    x_unif_np = x_unif.detach().cpu().numpy()
    t0 = torch.zeros_like(x_unif[:, 0:1])
    w0_NN = w_net(t0, x_unif).detach().cpu().numpy()
    # Terminal drift is evaluated at the configured horizon T, so this stays
    # correct when a config uses T != 1.
    wT_NN = w_net(torch.full_like(t0, T), x_unif).detach().cpu().numpy()
    error_w0 = np.mean(np.abs(w0(x_unif_np) - w0_NN))
    error_wT = np.mean(np.abs(wT(x_unif_np) - wT_NN))

    x_valid = x_unif.requires_grad_(True)
    y_valid = rho0(x_valid, d)
    z_valid = score0(x_valid)
    for t_idx in range(Nt):
        w, div, grad_div, Jwz = _drift_terms(w_net, t_idx * dt, x_valid, z_valid, d)
        # density along characteristics: d(rho)/dt = -div(w) * rho
        y_valid = y_valid - div * y_valid * dt
        z_valid = z_valid - (Jwz + grad_div) * dt
        x_valid = x_valid + w * dt

    x_np = x_valid.detach().cpu().numpy()
    z_np = z_valid.detach().cpu().numpy()
    # Flatten both densities: y_valid is (N, 1), while a scalar interpolator
    # typically returns (N,). Subtracting those directly would broadcast to (N, N).
    rhoT_ref = np.asarray(rhoT_interp(x_np)).reshape(-1)
    rho_NN = y_valid.detach().cpu().numpy().reshape(-1)
    error_densityT = np.mean(np.abs(rhoT_ref - rho_NN))
    error_score = np.mean(np.abs(scoreT(x_np) - z_np))
    return error_score, error_w0, error_wT, error_densityT

def main():
    """Train the WPONG drift network and periodically compare it with references.

    Parses CLI options, seeds NumPy/PyTorch, reads the JSON config and loads
    the reference interpolators from a pickle file. It then builds
    ``network.<net_type>(d, d, net_width)`` and, for ``n_step + 1`` iterations,
    minimises the Euler-discretised control cost with Adam + MultiStepLR.
    Errors are recorded every ``logging_freq`` steps (every step with
    ``--save_var``). With ``--save_var``, the weights, config, final training
    positions and error history are written under ``./results/WPONG/``.
    """
    import os  # only needed to create the output directory

    start_time = time.time()
    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config) as f:
        config = json.load(f)

    d = config['d']
    gamma = config['gamma']
    Nx = config['Nx']
    T = config['T']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    net_width = config['net_width']
    logging_freq = config['logging_freq']
    a = config['a']
    dt = T / Nt
    if args.save_var:
        logging_freq = 1  # keep a full error history when results are saved

    # Terminal cost g(x) = |x - c1|^2 * |x - c2|^2 / coef, which vanishes at
    # c1 = (a, ..., a) and c2 = (-a, ..., -a). The centres have d entries so
    # they match x in any dimension (identical to [[a, a]] when d == 2).
    c1 = torch.full((1, d), a, dtype=dtype, device=device)
    c2 = torch.full((1, d), -a, dtype=dtype, device=device)
    coef = 4

    def g(x):  # terminal cost: N x d -> N x 1
        part1 = torch.sum((x - c1) ** 2, dim=-1, keepdim=True)
        part2 = torch.sum((x - c2) ** 2, dim=-1, keepdim=True)
        return part1 * part2 / coef

    # Reference solutions. 'ref_file' is optional in the config and defaults
    # to the previously hard-coded file name.
    # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
    with open(config.get('ref_file', 'fcn4_interp.pkl'), 'rb') as f:
        interpolators = pickle.load(f)

    # Suffix added to the output file name when the network type is overridden.
    infix = ''
    if args.net is not None:
        config["net_type"] = args.net
        infix = args.net
    w_net = getattr(network, config["net_type"])(d, d, net_width)
    w_net.type(dtype).to(device)

    rhoT_interp = interpolators['rhoT_interp']
    scoreT = _stack_interp(interpolators['scoreT_1_interp'], interpolators['scoreT_2_interp'])  # terminal score
    w0 = _stack_interp(interpolators['w0_1_interp'], interpolators['w0_2_interp'])  # initial drift
    wT = _stack_interp(interpolators['wT_1_interp'], interpolators['wT_2_interp'])  # terminal drift
    del interpolators

    optimizer = torch.optim.Adam(w_net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # ==================== training ====================
    error_log = []  # one row per logged step; converted to an array on save
    for step in range(n_step + 1):
        optimizer.zero_grad()
        x = sample_x0(Nx, d)
        x.requires_grad_(True)
        z = score0(x)
        loss_cost = 0
        for t_idx in range(Nt):
            w, _, grad_div, Jwz = _drift_terms(w_net, t_idx * dt, x, z, d)
            # running cost 0.5 * |w + gamma * z|^2, integrated with step dt
            loss_cost = loss_cost + 0.5 * torch.mean(torch.sum((w + gamma * z) ** 2, dim=-1)) * dt
            # forward Euler scheme for positions and scores
            x = x + w * dt
            z = z - (Jwz + grad_div) * dt
        loss_cost = loss_cost + torch.mean(g(x))
        loss_cost.backward()
        optimizer.step()
        scheduler.step()
        if step % logging_freq == 0:
            error_score, error_w0, error_wT, error_densityT = _evaluate(
                w_net, d, Nt, dt, T, w0, wT, scoreT, rhoT_interp)
            time_elapsed = time.time() - start_time
            error_log.append([step, loss_cost.item(), error_score, error_w0,
                              error_wT, error_densityT, time_elapsed])
            if args.verbose:
                print("step %d, loss %.3f, errors:  %.3f,  %.3f, %.3f,  %.3f, time: %d" %
                      (step, loss_cost.item(), error_score, error_w0, error_wT, error_densityT, time_elapsed))
    x_scatter = x.detach().cpu().numpy()

    if args.save_var:
        save_dir = './results/WPONG/'
        os.makedirs(save_dir, exist_ok=True)
        saved_results = {
            'network': w_net.state_dict(),
            'config': config,
            'xT': x_scatter,
            'errors': np.array(error_log),
        }
        torch.save(saved_results, save_dir + 'WPONG' + str(d) + 'd' + str(args.random_seed) + infix + '.pt')
  
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()