import numpy as np
import torch
import torch.nn as nn
import json
from torch.autograd import grad
import matplotlib.pyplot as plt
import argparse
import network
import time

# Default floating-point type used when this file allocates tensors.
dtype = torch.float
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Options:
        --config: string, defaults to "./configs/WPONG1d.json".
        --save_var / --no-save_var: boolean switch, off by default.
        --random_seed: integer, defaults to 0.
        --verbose / --no-verbose: boolean switch, on by default.
        --net: optional string, defaults to None.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()
    on_off_switch = argparse.BooleanOptionalAction

    cli.add_argument("--config", type=str, default="./configs/WPONG1d.json")
    cli.add_argument("--save_var", action=on_off_switch, default=False)
    cli.add_argument("--random_seed", type=int, default=0)
    # pass --no-verbose on the command line to switch verbosity off
    cli.add_argument('--verbose', action=on_off_switch, default=True)
    cli.add_argument("--net", type=str, default=None)

    return cli.parse_args()

def divergence(y, x, d):
    """Divergence of ``y`` with respect to ``x``.

    Args:
        y: N x d tensor computed from ``x`` through autograd.
        x: N x d tensor that ``y`` was differentiated against.
        d: number of components to sum over.

    Returns:
        N x 1 tensor ``sum_k d y[:, k] / d x[:, k]``; the graph is kept
        (``create_graph=True``) so the result can be differentiated again.
        When ``d`` is 0 the plain integer 0 is returned.
    """
    def _diagonal_term(k):
        # Gradient of the k-th output w.r.t. all inputs; only column k is
        # the diagonal entry that contributes to the divergence.
        y_k = y[:, k]
        (dyk_dx,) = grad(y_k, x, grad_outputs=torch.ones_like(y_k),
                         retain_graph=True, create_graph=True)
        return dyk_dx[:, k:k + 1]

    return sum(_diagonal_term(k) for k in range(d))

def Jacobian(y, x, dy, dx):
    """Batched Jacobian of ``y`` with respect to ``x``.

    Args:
        y: N x dy tensor computed from ``x`` through autograd.
        x: N x dx tensor that ``y`` was differentiated against.
        dy: number of output components (rows of each Jacobian).
        dx: number of input components (columns of each Jacobian); only
            needed to shape the result when ``dy`` is 0, kept for interface
            compatibility.

    Returns:
        N x dy x dx tensor with ``J[n, i, j] = d y[n, i] / d x[n, j]``. The
        result takes its dtype and device from the gradients themselves
        rather than from module-level defaults, and keeps the autograd graph
        (``create_graph=True``) so it can be differentiated again.
    """
    if dy == 0:
        # torch.stack cannot stack an empty list; return the empty Jacobian.
        return y.new_zeros(y.size(0), 0, dx)
    one = torch.ones_like(y[:, 0])
    rows = [
        grad(y[:, i], x, grad_outputs=one, retain_graph=True, create_graph=True)[0]  # N x dx
        for i in range(dy)
    ]
    # Stack the per-output gradients along a new middle axis: N x dy x dx.
    return torch.stack(rows, dim=1)

def sample_x0(N, d):
    """Draw ``N`` samples from the ``d``-dimensional standard normal.

    Returns an N x d tensor created with the module-level ``dtype`` and
    ``device``.
    """
    shape = (N, d)
    return torch.randn(shape, dtype=dtype, device=device)

def score0(x):
    """Initial score evaluated at ``x``.

    Returns the element-wise negation of ``x`` (same shape as the input).
    """
    return -x

def sample_x0_valid(N, d):
    """Sample ``N`` points uniformly from the unit ball in ``d`` dimensions.

    A Gaussian draw is normalised to unit length to obtain a direction, then
    scaled by a radius ``U ** (1/d)`` with ``U`` uniform on [0, 1).

    Returns an N x d tensor created with the module-level ``dtype`` and
    ``device``.
    """
    gauss = torch.randn(N, d, dtype=dtype, device=device)
    lengths = gauss.pow(2).sum(dim=-1, keepdim=True).sqrt()  # N x 1
    directions = gauss / lengths
    radii = torch.rand(N, 1, dtype=dtype, device=device).pow(1 / d)  # N x 1
    return radii * directions

def g(x):
    """Double-well terminal cost.

    Computes ``sum_i (x_i**2 - 1)**2 / 4`` over the last dimension, which is
    kept with size 1 (e.g. num_sample x d -> num_sample x 1).
    """
    per_coordinate = (x.pow(2) - 1).pow(2)
    return per_coordinate.sum(dim=-1, keepdim=True) / 4

def _euler_step(w_net, t, x, z, d, dt):
    """Advance positions and scores by one forward-Euler step.

    Args:
        w_net: callable ``w_net(t_tensor, x)`` returning the velocity, N x d.
        t: scalar time at the start of the step.
        x: current positions, N x d, attached to an autograd graph.
        z: current scores carried along the particles, N x d.
        d: spatial dimension.
        dt: time-step size.

    Returns:
        Tuple ``(w, div, x_next, z_next)`` where ``w`` is the velocity at
        ``(t, x)``, ``div`` its divergence (N x 1), ``x_next = x + w * dt``
        and ``z_next = z - (J_w^T z + grad(div w)) * dt``.
    """
    t_tensor = torch.ones_like(x[:, 0:1]) * t
    w = w_net(t_tensor, x)
    div = divergence(w, x, d)
    # NOTE(review): create_graph is not set here, so grad_div is returned
    # without a graph and acts as a constant when the loss is back-propagated.
    # Confirm this truncation is intended before changing it.
    grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
    Jcbi = Jacobian(w, x, d, d)  # N x d x d
    Jwz = (Jcbi.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)  # N x d
    x_next = x + w * dt
    z_next = z - (Jwz + grad_div) * dt
    return w, div, x_next, z_next


def main():
    """Train the velocity network and save the best checkpoint and error log.

    Steps:
      1. Parse the CLI arguments, seed NumPy/PyTorch and load the JSON config.
      2. Build the trainable network ``w_net`` and ``saved_net``, which stores
         a copy of the best weights seen so far.
      3. Build reference quantities (terminal density, terminal score, and w
         at t=0 and t=T) by quadrature on fixed grids.
      4. Minimise the time-discretised running cost plus the terminal cost
         ``g`` with Adam, logging errors against the references every
         ``logging_freq`` steps.
      5. Restore the best checkpoint (if one was recorded) and write the
         weights, config and error log under ``./results/WPONG/``.
    """
    import os  # function-scope import: only needed to create the output folder

    start_time = time.time()
    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)

    d = config['d']
    gamma = config['gamma']
    Nx = config['Nx']
    N_valid = config['N_valid']
    T = config['T']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    net_width = config['net_width']
    logging_freq = config['logging_freq']
    if args.save_var:
        logging_freq = 1  # log the errors at every step
    dt = T / Nt
    infix = ''
    if args.net is not None:
        # the command line overrides the architecture named in the config
        config["net_type"] = args.net
        infix = args.net
    w_net = getattr(network, config["net_type"])(d, d, net_width)
    w_net.type(dtype).to(device)

    # second instance of the same architecture, used only to hold the best weights
    saved_net = getattr(network, config["net_type"])(d, d, net_width)
    saved_net.type(dtype).to(device)

    # weight decay is applied to every parameter of the network (weights and biases)
    optimizer = torch.optim.Adam(w_net.parameters(), lr=lr, weight_decay=config['weight_decay'])

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # ==================== compute errors for score function, density, and w ====================
    # NOTE(review): the quadrature grids and x0_valid below are N x 1 tensors,
    # so these reference quantities appear to assume d == 1 -- confirm before
    # running a config with another dimension.
    def V_x(x): # gradient of V
        return x**3 - x

    def rho0(x): # initial density N(0,1)
        return torch.exp(-torch.sum(x**2,dim=-1,keepdim=True)/2)/((2*np.pi)**(d/2))
    Ny = 1000
    Nz = 1001
    Nyt = 2002
    rangex = 2
    rangey = rangex + 2
    rangez = rangex + 2
    rangeyt = rangex + 4
    y_start = -rangey
    y_end = rangey
    z_start = -rangez
    z_end = rangez
    yt_start = -rangeyt
    yt_end = rangeyt
    y_grid = torch.linspace(y_start,y_end,Ny,dtype=dtype,device=device).view(-1,1) # Ny x 1
    z_grid = torch.linspace(z_start,z_end,Nz,dtype=dtype,device=device).view(-1,1) # Nz x 1
    yt = torch.linspace(yt_start,yt_end,Nyt,dtype=dtype,device=device).view(-1,1) # Nyt x 1
    dy = (y_end - y_start) / (Ny - 1)
    dz = (z_end - z_start) / (Nz - 1)
    dyt = (yt_end - yt_start) / (Nyt - 1)

    def fcn1(x,z,beta,tt):
        # x is Nx x d, z is Nz x d, output is Nx x Nz x 1
        z = z.unsqueeze(0) # 1 x Nz x d
        x = x.unsqueeze(1) # Nx x 1 x d
        x_z = x - z # Nx x Nz x d
        x_z2 = torch.sum(x_z**2,dim=-1,keepdim=True) # Nx x Nz x 1
        temp = g(z) + x_z2/(2*tt) # Nx x Nz x 1
        return torch.exp(-temp / (2*beta)) # Nx x Nz x 1

    def fcn2(x,y,beta,t):
        # x is Nx x d, y is Ny x d, output is Nx x Ny x 1
        x = x.unsqueeze(1) # Nx x 1 x d
        y = y.unsqueeze(0) # 1 x Ny x d
        x_y = x - y # Nx x Ny x d
        x_y2 = torch.sum(x_y**2,dim=-1,keepdim=True) # Nx x Ny x 1
        temp = torch.exp(-x_y2/(2*t) / (2*beta)) # Nx x Ny x 1
        return temp * rho0(y) # Nx x Ny x 1

    # Quadrature of fcn1 over the wide grid yt for every y_grid node. It only
    # depends on fixed grids and constants, so it is computed once here
    # instead of inside every rhoT / rho / scoreT call.
    normalizer = torch.sum(fcn1(y_grid,yt,gamma,T), dim=1) * dyt # Ny x 1

    def rhoT(x): # terminal density
        temp2 = fcn1(y_grid,x,gamma,T) # Ny x Nx x 1
        return torch.sum(temp2 * (rho0(y_grid) / normalizer).unsqueeze(1), dim=0) * dy # Nx x 1

    def rho(t,x): # density at time t (not called by the training loop below)
        if t == 0:
            return rho0(x)
        elif t==T:
            return rhoT(x)
        temp2 = fcn2(x,y_grid,gamma,t) # Nx x Ny x 1
        part1 = torch.sum(temp2 / normalizer.unsqueeze(0), dim=1) * dy # Nx x 1
        part2 = torch.sum(fcn1(x,z_grid,gamma,T-t), dim=1) * dz # Nx x 1
        coef = (4*np.pi*gamma*t*(T-t)/T)**(-d/2)
        return coef * part1 * part2 # Nx x 1

    def scoreT(x): # terminal score
        x_y = x.unsqueeze(0) - y_grid.unsqueeze(1) # Ny x Nx x d
        temp2 = fcn1(y_grid,x,gamma,T) * (rho0(y_grid).unsqueeze(1)) # Ny x Nx x 1
        temp3 = (-1/(2*gamma)) * (V_x(x) + x_y/T) * temp2 # Ny x Nx x d
        numerator = torch.sum(temp3 / normalizer.unsqueeze(1), dim=0) # Nx x d
        denominator = torch.sum(temp2 / normalizer.unsqueeze(1), dim=0)
        return numerator / denominator

    def w_true0(x): # true w at t=0
        x_y = x.unsqueeze(1) - y_grid.unsqueeze(0) # Nx x Ny x d
        temp = fcn1(x,y_grid,gamma,T) # Nx x Ny x 1
        numerator = - torch.sum(temp * x_y / T, dim=1) * dy # Nx x 1
        denominator = torch.sum(temp, dim=1) * dy
        return numerator / denominator - gamma * score0(x)

    def w_trueT(x): # true w at t=T
        return -V_x(x) - gamma * scoreT(x)

    # validation points: an equispaced grid on [-1, 1]
    x0_valid = torch.linspace(-1,1,N_valid,dtype=dtype,device=device).view(-1,1)
    w0_ref = w_true0(x0_valid)
    wT_ref = w_trueT(x0_valid)
    one = torch.ones_like(x0_valid[:,0:1])

    # ==================== training ====================
    warmup_steps = 50  # checkpoints are considered from this step on
    window = 10        # number of recent losses averaged for model selection
    loss_list = np.ones(n_step+1)*100
    best_loss = 100
    best_index = n_step
    has_checkpoint = False  # True once saved_net holds trained weights
    error_rows = []         # one length-7 row per logging step
    for step in range(n_step+1):
        optimizer.zero_grad()
        x = sample_x0(Nx,d)
        x.requires_grad_(True)
        z = score0(x)
        loss_cost = 0
        for t_idx in range(Nt):
            w, _, x_next, z_next = _euler_step(w_net, t_idx * dt, x, z, d, dt)
            # running cost uses the velocity and score at the start of the step
            loss_cost = loss_cost + 0.5* torch.mean(torch.sum((w+gamma*z)**2,dim=-1)) * dt
            x, z = x_next, z_next
        loss_cost = loss_cost + torch.mean(g(x))
        loss_list[step] = loss_cost.item()
        if step >= warmup_steps:
            ten_average = loss_list[step-window+1:step+1].mean()
            # Snapshot at the first eligible step as well, so that saved_net
            # never holds its random initialisation when it is restored.
            if step == warmup_steps or ten_average < best_loss:
                best_loss = ten_average
                best_index = step
                saved_net.load_state_dict(w_net.state_dict())
                has_checkpoint = True
        loss_cost.backward()
        optimizer.step()
        scheduler.step()
        if step % logging_freq == 0:
            w0_NN = w_net(0*one, x0_valid).detach()
            wT_NN = w_net(T*one, x0_valid).detach()
            error_w0 = torch.mean(torch.abs(w0_NN - w0_ref)).item()
            error_wT = torch.mean(torch.abs(wT_NN - wT_ref)).item()
            # transport the validation points, their density and score to time T
            x = x0_valid.requires_grad_(True)
            y = rho0(x)
            z = score0(x)
            for t_idx in range(Nt):
                _, div, x_next, z_next = _euler_step(w_net, t_idx * dt, x, z, d, dt)
                y = y - div * y * dt
                x, z = x_next, z_next
            # The references need values only; detaching avoids building an
            # autograd graph through the quadrature.
            x_T = x.detach()
            y_ref = rhoT(x_T)
            z_ref = scoreT(x_T)
            error_densityT = torch.mean(torch.abs(y.detach() - y_ref)).item()
            error_score = torch.mean(torch.abs(z.detach() - z_ref)).item()
            time_elapsed = time.time() - start_time
            error_rows.append(np.array([step,loss_cost.item(),error_score,error_w0,error_wT,error_densityT,time_elapsed]))
            if args.verbose:
                print("step %d, loss %.3f, errors:  %.3f,  %.3f, %.3f,  %.3f, time: %d" % 
                      (step, loss_cost.item(), error_score, error_w0, error_wT, error_densityT, time_elapsed))

    # one row per logging step: step, loss, four errors, elapsed time
    errors = np.array(error_rows).reshape(-1, 7)

    # save the best network
    if has_checkpoint:
        print("Best loss is %f at step %d" % (best_loss,best_index))
        w_net.load_state_dict(saved_net.state_dict())
    else:
        # Training ended before the first eligible step, so no snapshot
        # exists; keep the final trained weights.
        print("No checkpoint was recorded (fewer than %d steps); keeping the final weights" % (warmup_steps + 1))

    saved_results = {
        'network': w_net.state_dict(),
        'config': config,
        'errors': errors
    }
    # create the output folder first so a missing directory cannot discard the run
    out_dir = './results/WPONG'
    os.makedirs(out_dir, exist_ok=True)
    torch.save(saved_results, out_dir + '/WPONG'+str(d)+'d' + infix + str(args.random_seed)+'.pt')
    
# Run the training entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()