# Flow matching for double moon
import numpy as np
import torch
import torch.nn as nn
import json
from torch.autograd import grad
import matplotlib.pyplot as plt
import argparse
import network
import time
import os

# Precision and compute device shared by every tensor created in this script.
dtype = torch.float32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def parse_args():
    """Parse the command-line options of the double-moon training script.

    Options:
        --config                     path to the JSON config file.
        --save_var / --no-save_var   save the trained network, loss history
                                     and final samples (off by default).
        --random_seed                seed for numpy and torch (default 0).
        --verbose / --no-verbose     print the loss while training (on by default).
        --net                        override config["net_type"].
        --T                          override config["T"].
        --retrain                    seed suffix of a saved run to resume from,
                                     e.g. ``--retrain 0``.
    """
    p = argparse.ArgumentParser()
    flag = argparse.BooleanOptionalAction
    p.add_argument("--config", default="./configs/FMDM2d.json", type=str)
    p.add_argument("--save_var", default=False, action=flag)
    p.add_argument("--random_seed", default=0, type=int)
    p.add_argument("--verbose", default=True, action=flag)
    # Optional overrides: None means "keep the value from the config / no retrain".
    for opt_name, opt_type in (("--net", str), ("--T", float), ("--retrain", str)):
        p.add_argument(opt_name, default=None, type=opt_type)
    return p.parse_args()

def divergence(y, x, d):
    """Return sum_k dy[:, k]/dx[:, k] as an N x 1 tensor.

    y and x are N x d and y must be a differentiable function of x. The
    autograd graph is kept (create_graph=True) so the result can itself be
    differentiated again.
    """
    ones = torch.ones_like(y[:, 0])
    # one backward pass per output column, keeping only its diagonal entry
    diagonal_terms = (
        grad(y[:, k], x, grad_outputs=ones, retain_graph=True, create_graph=True)[0][:, k:k + 1]
        for k in range(d)
    )
    return sum(diagonal_terms)

def Jacobian(y, x, dy, dx):
    """Return the batched Jacobian of y w.r.t. x as an N x dy x dx tensor.

    Args:
        y: N x dy tensor, a differentiable function of ``x``.
        x: N x dx tensor that requires grad.
        dy: number of leading columns of ``y`` to differentiate.
        dx: number of columns of ``x`` (used for the empty ``dy == 0`` result).

    Returns:
        Tensor J with J[n, i, :] = d y[n, i] / d x[n, :]. Its dtype and device
        follow the autograd gradients (i.e. ``x``), not a module-level setting.

    Row i comes from one backward pass of ``y[:, i]`` summed over the batch. That
    equals the per-sample gradient only if sample n of y depends on sample n of x
    alone; this is presumably true for the networks used here, but confirm.
    create_graph=True keeps J differentiable, because it feeds the training loss.
    """
    ones = torch.ones_like(y[:, 0])
    rows = [
        grad(y[:, i], x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]  # N x dx
        for i in range(dy)
    ]
    if not rows:
        return y.new_zeros(y.size(0), 0, dx)
    # stack instead of writing into a pre-allocated buffer, so no dtype/device cast happens
    return torch.stack(rows, dim=1)

def V(x): # potential function, x is N x 2
    """Double-moon potential V(x) = 2(|x| - 3)^2 - 2 log(exp(-2(x1-3)^2) + exp(-2(x1+3)^2)).

    Args:
        x: N x 2 tensor of points. Only the first column enters the log term.

    Returns:
        N x 1 tensor of potential values.

    The log-sum-exp is evaluated with torch.logaddexp. The naive
    log(exp(a) + exp(b)) underflows to log(0) = -inf (so V = +inf) once |x1|
    is roughly above 10 in float32.

    NOTE(review): V_x in this file uses a coefficient of 4 on its double-well
    term. That is the gradient of a log term with coefficient 1, not the 2 used
    here, so one of the two is presumably off by a factor of 2. Confirm which
    is intended.
    """
    x1 = x[:, 0:1]  # N x 1
    x_norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)  # N x 1
    log_mix = torch.logaddexp(-2 * (x1 - 3) ** 2, -2 * (x1 + 3) ** 2)  # N x 1
    return 2 * (x_norm - 3) ** 2 - 2 * log_mix  # N x 1

def V_x(x): # gradient of V
    """Gradient field of the double-moon potential, used in the training loss.

    Args:
        x: N x 2 tensor of points.

    Returns:
        N x 2 tensor  4(|x| - 3) x/|x| + [4(x1 - 3 tanh(12 x1)), 0].

    The double-well part equals 4((x1-3)e1 + (x1+3)e2)/(e1+e2), with
    e1 = exp(-2(x1-3)^2) and e2 = exp(-2(x1+3)^2). Because e2/e1 = exp(-24 x1),
    that ratio is exactly x1 - 3 tanh(12 x1). The tanh form avoids the 0/0 = NaN
    the ratio produces once both exponentials underflow (|x1| roughly above 10
    in float32). At the origin x/|x| is undefined, and the radial term is taken
    as 0 there instead of NaN.

    NOTE(review): V in this file has coefficient 2 on its log term, whose
    gradient would give 8 (not 4) on the double-well part. The values here are
    kept unchanged so training behaves as before; confirm which is intended.
    """
    x1 = x[:, 0:1]  # N x 1
    x_norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)  # N x 1
    # clamped norm: unit vector is 0 (not NaN) at the origin, unchanged elsewhere
    unit = x / x_norm.clamp_min(torch.finfo(x.dtype).tiny)  # N x 2
    radial = 4 * (x_norm - 3) * unit  # N x 2
    well = 4 * (x1 - 3 * torch.tanh(12 * x1))  # N x 1
    return radial + torch.cat([well, torch.zeros_like(x1)], dim=-1)  # N x 2

def _rollout_loss(w_net, x, Nt, dt, gamma, d):
    """Push particles through Nt explicit-Euler steps of w_net and accumulate the loss.

    Args:
        w_net: velocity network, called as ``w_net(t, x)`` with t of shape N x 1.
        x: N x d starting particles. Must require grad, because divergence and
            Jacobian are taken w.r.t. it.
        Nt: number of Euler steps.
        dt: step size.
        gamma: weight of the score term in the loss.
        d: spatial dimension.

    Returns:
        (loss, x_final). loss is the dt-weighted sum over steps of
        0.5 * mean |w + gamma*z + V_x(x)|^2, and x_final is the particles after
        Nt steps. The score estimate z is advanced with
        z <- z - (J_w^T z + grad(div w)) * dt.
    """
    # -x is the score of N(0, I), the law of torch.randn samples. In retrain mode
    # x starts from saved samples, for which this is presumably an approximation (confirm).
    z = -x
    loss_cost = 0
    for t_idx in range(Nt):
        t = t_idx * dt
        t_tensor = torch.ones_like(x[:, 0:1]) * t
        w = w_net(t_tensor, x)
        div = divergence(w, x, d)
        # NOTE(review): without create_graph=True, grad_div is treated as a constant
        # when back-propagating into w_net, unlike the Jacobian term below. This may
        # be a deliberate cost saving (it avoids third derivatives); confirm.
        grad_div = grad(div, x, grad_outputs=torch.ones_like(div), retain_graph=True)[0]
        loss_cost = loss_cost + 0.5 * torch.mean(torch.sum((w + gamma * z + V_x(x)) ** 2, dim=-1)) * dt
        Jcbi = Jacobian(w, x, d, d)  # Nx x d x d
        Jwz = (Jcbi.transpose(1, 2) @ z.unsqueeze(-1)).squeeze(-1)  # Nx x d, i.e. J^T z
        x = x + w * dt
        z = z - (Jwz + grad_div) * dt
    return loss_cost, x


def main():
    """Train the double-moon velocity network and optionally save the results.

    Hyper-parameters come from the JSON file given by --config; --T and --net
    override the matching entries. With --retrain SEED, training resumes from
    './results/FM/FMDB2d{T}T{SEED}.pt': the network weights are loaded, the
    particles start from the samples stored there, and 200 steps are run with
    weight decay 1e-1. With --save_var, the network, the loss history
    (step, loss, elapsed seconds) and the final particles are written under
    './results/FM/'.
    """
    start_time = time.time()
    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config) as f:
        config = json.load(f)

    d = 2
    Nx = config['Nx']
    gamma = config['gamma']
    T = config['T']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    net_width = config['net_width']
    logging_freq = config['logging_freq']
    if args.T is not None:
        T = args.T
        config['T'] = T
    dt = T / Nt

    if args.net is not None:
        config["net_type"] = args.net
    w_net = getattr(network, config["net_type"])(d, d, net_width)
    w_net.type(dtype).to(device)

    retrain = args.retrain is not None
    if retrain:
        load_dir = './results/FM/FMDB2d' + str(T) + 'T' + args.retrain + '.pt'
        # map_location lets a GPU-written checkpoint be resumed on a CPU-only machine.
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
        # reject the numpy 'errors' array stored in these files. Pass
        # weights_only=False if needed, and only for trusted local files.
        results = torch.load(load_dir, map_location=device)
        w_net.load_state_dict(results['network'])
        # This script saves the final particles under 'xT'; fall back to 'xL' in case
        # a checkpoint was written with that key.
        x_start = results['xT'] if 'xT' in results else results['xL']
        x_start = x_start.to(device=device, dtype=dtype)
        n_step = 200
        config['n_step'] = n_step

    optimizer = torch.optim.Adam(w_net.parameters(), lr=lr, weight_decay=1e-1 if retrain else 0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # ==================== training ====================
    errors = []  # rows of [step, loss, seconds since start]
    for step in range(n_step + 1):
        optimizer.zero_grad()
        if retrain:
            # clone so the loaded tensor does not become a grad-accumulating leaf
            x = x_start.clone()
        else:
            x = torch.randn(Nx, d, dtype=dtype, device=device)
        x.requires_grad_(True)
        loss_cost, x = _rollout_loss(w_net, x, Nt, dt, gamma, d)
        loss_cost.backward()
        optimizer.step()
        scheduler.step()
        if step % logging_freq == 0:
            errors.append([step, loss_cost.item(), time.time() - start_time])
            if args.verbose:
                print('step:', step, 'loss:', loss_cost.item())

    if args.save_var:
        saved_results = {
            'network': w_net.state_dict(),
            'errors': np.array(errors),
            'config': config,
            'xT': x.detach(),
        }
        # create the output directory so a long run does not fail at the very end
        os.makedirs('./results/FM', exist_ok=True)
        if retrain:
            save_path = './results/FM/FMDB2d' + str(2 * T) + 'Tretrain' + str(args.random_seed) + '.pt'
        else:
            save_path = './results/FM/FMDB2d' + str(T) + 'T' + str(args.random_seed) + '.pt'
        torch.save(saved_results, save_path)

if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    main()