import argparse
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad

import network

# Default tensor dtype (torch.float == float32) and device (CUDA when available,
# otherwise CPU) for tensors created in this script.
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Final time of the horizon [0, T]; main() integrates over it with Nt Euler
# steps of size dt = T / Nt.
T = 1.0

def parse_args():
    """Parse the command-line options of the LQ training script.

    Options:
        --config PATH               JSON config file (default ./configs/LQ1d.json).
        --save_var / --no-save_var  After training, save network weights, the
                                    error history and the config (default off).
        --random_seed INT           Seed for numpy and torch (default 0).
        --verbose / --no-verbose    Print a progress line at every logging step
                                    (default on; pass --no-verbose to silence).
        --net NAME                  Override ``net_type`` from the config
                                    (default None: keep the config's value).

    Returns:
        argparse.Namespace with attributes ``config``, ``save_var``,
        ``random_seed``, ``verbose`` and ``net``.
    """
    toggle = argparse.BooleanOptionalAction
    # Registration order is kept fixed so the generated --help text is unchanged.
    option_table = (
        ("--config", {"default": "./configs/LQ1d.json", "type": str}),
        ("--save_var", {"default": False, "action": toggle}),
        ("--random_seed", {"default": 0, "type": int}),
        ("--verbose", {"default": True, "action": toggle}),
        ("--net", {"default": None, "type": str}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def divergence(y, x, d):
    """Return sum_{i<d} d y[:, i] / d x[:, i] as an (N, 1) tensor.

    Each term differentiates ``y[:, i]`` summed over the batch. It therefore
    equals the per-sample divergence only when row n of ``y`` depends on row
    n of ``x`` alone (assumed true for the network here; confirm for other
    callers). The graph is kept (``create_graph=True``) so the result can be
    differentiated again. For ``d == 0`` the integer ``0`` is returned.

    Args:
        y: (N, >=d) tensor computed from ``x``.
        x: tensor with ``requires_grad=True`` that ``y`` was built from.
        d: number of leading diagonal Jacobian entries to sum.
    """
    def _diagonal_term(i):
        column = y[:, i]
        dcol_dx = grad(column, x, grad_outputs=torch.ones_like(column),
                       retain_graph=True, create_graph=True)[0]
        return dcol_dx[:, i:i + 1]

    return sum(_diagonal_term(i) for i in range(d))

def Jacobian(y, x, dy, dx):
    """Return the batched Jacobian of ``y`` w.r.t. ``x`` as an (N, dy, dx) tensor.

    Row ``i`` is the gradient of ``y[:, i]`` summed over the batch. It is the
    per-sample Jacobian only when row n of ``y`` depends on row n of ``x``
    alone (assumed for the network used here; confirm for other callers).
    ``create_graph=True`` keeps the result differentiable, e.g. w.r.t. network
    parameters.

    The output takes its dtype and device from ``y``, not from module-level
    globals. This avoids a silent cast when ``y`` is, e.g., float64.

    Args:
        y: (N, >=dy) tensor computed from ``x``.
        x: (N, dx) tensor with ``requires_grad=True``.
        dy: number of output components to differentiate.
        dx: dimension of ``x``; only used to shape the empty ``dy == 0`` result.
    """
    n = y.size(0)
    if dy == 0:
        return y.new_zeros((n, 0, dx))
    seed = torch.ones_like(y[:, 0])
    rows = [
        grad(y[:, i], x, grad_outputs=seed, retain_graph=True, create_graph=True)[0]  # N x dx
        for i in range(dy)
    ]
    # Stacking on dim=1 puts the gradient of output i at J[:, i, :].
    return torch.stack(rows, dim=1)

def sample_x0(N, d, alpha):
    """Draw N samples from the isotropic Gaussian N(0, I / alpha) in R^d.

    The standard deviation 1/sqrt(alpha) matches the density
    exp(-alpha |x|^2 / 2) used as the reference ``rho`` in ``main``.
    Tensors are created with the module-level ``dtype`` and ``device``.
    """
    standard_normal = torch.randn(N, d, dtype=dtype, device=device)
    return standard_normal / np.sqrt(alpha)

def sample_x0_valid(N, d):
    """Draw N points uniformly from the unit ball in R^d (a disk only when d == 2).

    A Gaussian vector normalised to unit length gives a uniform direction.
    A radius U**(1/d), with U ~ Uniform[0, 1), makes the volume uniform.
    Tensors are created with the module-level ``dtype`` and ``device``.
    """
    direction = torch.randn(N, d, dtype=dtype, device=device)
    direction = direction / direction.pow(2).sum(dim=-1, keepdim=True).sqrt()
    radius = torch.rand(N, 1, dtype=dtype, device=device).pow(1 / d)
    return radius * direction

def main():
    """Train ``w_net`` on the LQ problem and log its errors against the analytic solution.

    Steps:
      1. Parse CLI options, seed numpy/torch and load the JSON config.
      2. Build the network class named by ``config["net_type"]`` (overridable
         with ``--net``), an Adam optimizer and a MultiStepLR schedule.
      3. For ``n_step + 1`` iterations:
         - sample ``Nx`` particles from N(0, I/alpha);
         - roll particles ``x``, log-density ``y`` and score ``z`` forward with
           ``Nt`` explicit Euler steps of the learned velocity ``w``;
         - accumulate the running plus terminal cost and take one optimizer step.
      4. Every ``logging_freq`` iterations, compare the network against the
         reference velocity (``w_true``), density (``rho``) and score
         (``score``). The comparison uses a fixed validation set drawn
         uniformly from the unit ball.
      5. With ``--save_var``, save weights, error history and config under
         ``./results/LQ/``. The directory is created if it is missing.
    """
    start_time = time.time()
    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config) as f:
        config = json.load(f)

    d = config['d']                        # state dimension
    Nx = config['Nx']                      # training particles per step
    N_valid = config['N_valid']            # validation points
    Nt = config['Nt']                      # Euler steps over [0, T]
    gamma = config['gamma']                # weight of the log-density in the running cost
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    net_width = config['net_width']
    logging_freq = config['logging_freq']
    # Positive root of alpha**2 + gamma*alpha - 1 = 0. It is the precision of
    # the Gaussian reference density rho below.
    alpha = (np.sqrt(gamma**2 + 4) - gamma) / 2
    infix = ''
    if args.net is not None:
        config["net_type"] = args.net
        infix = args.net
    if args.save_var:
        logging_freq = 1

    dt = T / Nt

    def w_true(t, x):
        # Reference velocity used for the w-errors: identically zero.
        return 0

    def score(t, x):
        # Gradient of log_rho: -alpha * x (time-independent).
        return -alpha * x

    def rho(t, x):
        # Density of N(0, I/alpha); returns an (N, 1) tensor.
        return torch.exp(-torch.sum(x**2, dim=-1, keepdim=True)*(alpha/2)) * (alpha/2/np.pi)**(d/2)

    def log_rho(t, x):
        # log of rho above; returns an (N, 1) tensor.
        return -torch.sum(x**2, dim=-1, keepdim=True)*(alpha/2) + d/2 * np.log(alpha/2/np.pi)

    w_net = getattr(network, config["net_type"])(d, d, net_width)
    w_net.type(dtype).to(device)

    optimizer = torch.optim.Adam(w_net.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    x0_valid = sample_x0_valid(N_valid, d)
    w0_ref = w_true(0, x0_valid)
    wT_ref = w_true(T, x0_valid)
    one = torch.ones_like(x0_valid[:, 0:1])

    # One row per logged step: [step, loss, err_score, err_w0, err_wT, err_densityT, time].
    error_rows = []

    # ==================== training ====================
    for step in range(n_step + 1):
        optimizer.zero_grad()
        x = sample_x0(Nx, d, alpha)
        x.requires_grad_(True)
        y = log_rho(0, x)   # (Nx, 1) log-density along the trajectory
        z = score(0, x)     # (Nx, d) score along the trajectory
        loss_cost = 0
        for t_idx in range(Nt):
            t = t_idx * dt
            t_tensor = torch.ones_like(x[:, 0:1]) * t
            w = w_net(t_tensor, x)
            div = divergence(w, x, d)
            # NOTE(review): unlike divergence/Jacobian, this gradient is taken
            # without create_graph=True, so the grad_div term of the z-update
            # does not contribute to loss_cost.backward(). Confirm the
            # truncation is intentional (e.g. to avoid third-order derivatives).
            grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
            # y is (Nx, 1) and the summed quadratic cost is (Nx,). Squeeze y so
            # the sum stays (Nx,) instead of broadcasting to an (Nx, Nx)
            # matrix. The mean is unchanged; memory drops from O(Nx^2) to O(Nx).
            running = 0.5*torch.sum((w + z)**2 + x**2, dim=-1) + gamma*y.squeeze(-1)
            loss_cost = loss_cost + torch.mean(running) * dt
            Jcbi = Jacobian(w, x, d, d)  # Nx x d x d
            Jwz = (Jcbi.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)  # Nx x d

            # forward Euler scheme
            x = x + w * dt
            y = y - div * dt
            z = z - (Jwz + grad_div) * dt

        # Terminal cost.
        loss_cost = loss_cost + torch.mean(torch.sum(x**2, dim=-1))*(alpha/2)
        loss_cost.backward()
        optimizer.step()
        scheduler.step()

        if step % logging_freq == 0:
            # Plain evaluations; no graph is needed for these.
            with torch.no_grad():
                w0_NN = w_net(0*one, x0_valid)
                wT_NN = w_net(T*one, x0_valid)
            error_w0 = torch.mean(torch.abs(w0_NN - w0_ref)).cpu().numpy()
            error_wT = torch.mean(torch.abs(wT_NN - wT_ref)).cpu().numpy()
            # Validation rollout: autograd is still required for div/Jacobian.
            x = x0_valid.requires_grad_(True)
            y = rho(0, x)  # this is the density not the log density
            z = score(0, x)
            for t_idx in range(Nt):
                t = t_idx * dt
                t_tensor = torch.ones_like(x[:, 0:1]) * t
                w = w_net(t_tensor, x)
                div = divergence(w, x, d)
                grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
                Jcbi = Jacobian(w, x, d, d)
                Jwz = (Jcbi.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)
                x = x + w * dt
                y = y - div * y * dt  # density (not log-density) dynamics: dy/dt = -div(w) * y
                z = z - (Jwz + grad_div) * dt
            # References are only needed at the final time T, so compute them once.
            y_ref = rho(T, x)
            z_ref = score(T, x)
            error_densityT = torch.mean(torch.abs(y - y_ref)).detach().cpu().numpy()
            error_score = torch.mean(torch.abs(z - z_ref)).detach().cpu().numpy()
            time_elapsed = time.time() - start_time
            loss_value = loss_cost.item()
            error_rows.append(np.array([step, loss_value, error_score, error_w0, error_wT,
                                        error_densityT, time_elapsed]))
            if args.verbose:
                print("step %d, loss %.3f, errors:  %.3f,  %.3f, %.3f,  %.3f, time: %d" %
                      (step, loss_value, error_score, error_w0, error_wT, error_densityT, time_elapsed))

    if args.save_var:
        # Stack once instead of growing the array with np.concatenate at every log.
        errors = np.stack(error_rows, axis=0)
        save_dir = './results/LQ'
        os.makedirs(save_dir, exist_ok=True)
        saved_results = {
            'network': w_net.state_dict(),
            'errors': errors,
            'config': config
        }
        torch.save(saved_results, save_dir + '/LQ' + str(d) + 'd' + infix + str(args.random_seed) + '.pt')

if __name__ == '__main__':
    main()