import numpy as np
import torch
import torch.nn as nn
import json
from torch.autograd import grad
import matplotlib.pyplot as plt
import argparse
import network
import pickle

# Module-wide default tensor dtype and device (GPU when available).
dtype = torch.float32  # identical object to torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def parse_args(argv=None):
    """Parse command-line options for a training run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which matches the previous behaviour.
            Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes ``config``, ``save_var``,
        ``random_seed``, ``verbose`` and ``net``.
    """
    parser = argparse.ArgumentParser()
    # Path to the JSON file holding the training hyper-parameters.
    parser.add_argument("--config", default="./configs/WPONG10d.json", type=str)
    # --save_var / --no-save_var: whether main() writes results to disk.
    parser.add_argument("--save_var", default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument("--random_seed", default=0, type=int)
    # --verbose / --no-verbose toggles the periodic loss printout.
    parser.add_argument('--verbose', default=True, action=argparse.BooleanOptionalAction)
    # Overrides config["net_type"]: name of a class looked up in the `network` module.
    parser.add_argument("--net", default=None, type=str)
    return parser.parse_args(argv)

def divergence(y, x, d):
    """Return sum_k d y[:, k] / d x[:, k] for k < d, shaped (N, 1).

    Each partial derivative is obtained with autograd using ``create_graph=True``,
    so the result can itself be differentiated. When ``d == 0`` the result is the
    integer 0.
    """
    def _diag_partial(k):
        # Gradient of the column sum of y[:, k] w.r.t. x; keep only its k-th column.
        ones = torch.ones_like(y[:, 0])
        dyk_dx = grad(y[:, k], x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]
        return dyk_dx[:, k:k + 1]

    return sum(_diag_partial(k) for k in range(d))

def Jacobian(y, x, dy, dx):
    """Compute the Jacobian of y w.r.t. x, stacked per sample.

    Args:
        y: Tensor of shape (N, dy) computed from ``x`` with autograd tracking.
        x: Tensor of shape (N, dx) that ``y`` depends on.
        dy: Number of output components (columns of ``y``) to differentiate.
        dx: Number of input components (columns of ``x``).

    Returns:
        Tensor of shape (N, dy, dx). Row ``i`` holds the gradient of
        ``y[:, i].sum()`` w.r.t. ``x``. That equals the per-sample Jacobian
        d y[n, i] / d x[n, j] assuming each row of ``y`` depends only on the
        same row of ``x``. The graph is kept (``create_graph=True``), so the
        result can be differentiated again. dtype and device follow the
        gradients (i.e. ``x``) rather than module-level globals.
    """
    if dy == 0:
        # Nothing to differentiate; mirror the (N, 0, dx) shape.
        return x.new_zeros(y.size(0), 0, dx)
    ones = torch.ones_like(y[:, 0])
    rows = [
        grad(y[:, i], x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]  # N x dx
        for i in range(dy)
    ]
    return torch.stack(rows, dim=1)

def sample_x0(N, d):
    """Draw ``N`` samples of the d-dimensional standard normal N(0, I).

    Returns a tensor of shape (N, d) on the module-level ``device`` with the
    module-level ``dtype``.
    """
    shape = (N, d)
    return torch.randn(*shape, dtype=dtype, device=device)

def sample_x0_unif(N, d):
    """Draw ``N`` points uniformly from the unit ball in R^d (a disk when d == 2).

    Directions are normalised Gaussian vectors. Radii are ``U ** (1/d)`` with U
    uniform on [0, 1), which is the radial law of the uniform ball distribution.
    Returns shape (N, d).
    """
    gauss = torch.randn(N, d, dtype=dtype, device=device)
    direction = gauss / torch.sqrt(torch.sum(gauss**2, dim=-1, keepdim=True))
    radius = torch.rand(N, 1, dtype=dtype, device=device) ** (1 / d)
    return radius * direction

def rho0(x, d):
    """Evaluate the initial density N(0, I_d) at each row of ``x``.

    The coordinates lie on the last axis of ``x``. The result keeps that axis
    with size 1.
    """
    sq_norm = torch.sum(x**2, dim=-1, keepdim=True)
    normaliser = (2 * np.pi) ** (d / 2)
    return torch.exp(-sq_norm / 2) / normaliser

def score0(x):
    """Return the initial score, the gradient of log rho0.

    For the standard normal this is -x.
    """
    score = -x
    return score

def main():
    """Train the velocity network ``w_net`` and optionally save the result.

    Steps:
      * Read hyper-parameters from the JSON file given by ``--config``.
      * At each training step, sample x0 ~ N(0, I) and start z at the
        initial score -x0.
      * Advance both with explicit Euler:
            x <- x + w dt
            z <- z - (J_w^T z + grad(div w)) dt
      * Minimise
            sum_t 0.5 * mean|w + gamma z|^2 dt + mean g(x_T).

    With ``--save_var``, the state dict, config and last-step terminal samples
    are written under ./results/WPONG/.
    """
    import os  # local import: only needed for creating the output directory

    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config) as f:
        config = json.load(f)

    d = config['d']
    gamma = config['gamma']
    Nx = config['Nx']
    T = config['T']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    net_width = config['net_width']
    logging_freq = config['logging_freq']
    a = config['a']
    dt = T / Nt

    # Centres of the two wells of the terminal cost: (a, a, 0, ..., 0) and its
    # negation. Built from d rather than hard-coding 10 entries. The hard-coded
    # version crashed for most d != 10 and silently broadcast to a wrong cost
    # for d == 1.
    c1 = torch.zeros(1, d, dtype=dtype, device=device)
    c1[:, :2] = a
    c2 = -c1
    coef = 4

    def g(x):  # terminal cost
        # input x is N x d, output is N x 1: product of the squared distances
        # to c1 and c2, divided by coef (zero at either centre).
        part1 = torch.sum((x - c1)**2, dim=-1, keepdim=True)
        part2 = torch.sum((x - c2)**2, dim=-1, keepdim=True)
        return part1 * part2 / coef

    infix = ''
    if args.net is not None:
        # Command-line override of the network class; also tags the output file name.
        config["net_type"] = args.net
        infix = infix + args.net
    w_net = getattr(network, config["net_type"])(d, d, net_width)
    w_net.type(dtype).to(device)

    optimizer = torch.optim.Adam(w_net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # ==================== training ====================
    for step in range(n_step + 1):
        optimizer.zero_grad()
        x = sample_x0(Nx, d)
        x.requires_grad_(True)
        z = score0(x)
        loss_cost = 0
        for t_idx in range(Nt):
            t = t_idx * dt
            t_tensor = torch.ones_like(x[:, 0:1]) * t
            w = w_net(t_tensor, x)
            Jcbi = Jacobian(w, x, d, d)  # Nx x d x d, row i = gradient of w_i
            # div w is the trace of the Jacobian. Reusing Jcbi avoids the d extra
            # backward passes that a separate divergence(w, x, d) call would repeat.
            div = torch.diagonal(Jcbi, dim1=1, dim2=2).sum(dim=-1, keepdim=True)  # Nx x 1
            # NOTE(review): grad_div is computed without create_graph=True, so
            # loss_cost.backward() treats it as a constant w.r.t. the network
            # parameters. Confirm this truncation is intended.
            grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
            Jwz = (Jcbi.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)  # Nx x d, J^T z
            loss_cost = loss_cost + 0.5 * torch.mean(torch.sum((w + gamma * z)**2, dim=-1)) * dt
            # forward Euler scheme
            x = x + w * dt
            z = z - (Jwz + grad_div) * dt
        loss_cost = loss_cost + torch.mean(g(x))
        loss_cost.backward()
        optimizer.step()
        scheduler.step()
        if args.verbose and step % logging_freq == 0:
            print("step %d, loss %f" % (step, loss_cost.item()))
    # Terminal positions from the final training step.
    x_scatter = x.detach().cpu().numpy()

    if args.save_var:
        saved_results = {
            'network': w_net.state_dict(),
            'config': config,
            'xT': x_scatter
        }
        save_path = './results/WPONG/WPONG' + str(d) + 'd' + infix + str(args.random_seed) + '.pt'
        # Make sure the target directory exists so a long run doesn't fail at the final save.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(saved_results, save_path)
    
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()