import numpy as np
import torch
import json
import argparse
from torch.nn import Parameter

# Flow matching with regularization (finite-difference time discretization).
# The state is an Ornstein-Uhlenbeck process dX_t = -a X_t dt + sqrt(2*gamma) dW_t,
# i.e. drift -a*x and diffusion coefficient (2*gamma)**0.5.

# Shared numerical settings for the whole script.
dtype = torch.float32  # same object as torch.float
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
T = 1.0  # final time; main() discretizes the interval [0, T]

def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --config                    path to the JSON hyper-parameter file.
        --save_var / --no-save_var  save the trained coefficients (off by default).
        --random_seed               seed for NumPy and PyTorch.
        --verbose / --no-verbose    print progress during training (on by default).
        --reg                       when given, overrides the regularization weight
                                    read from the config file.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="./configs/FMregFD1d.json")
    cli.add_argument("--save_var", action=argparse.BooleanOptionalAction, default=False)
    cli.add_argument("--random_seed", type=int, default=0)
    # BooleanOptionalAction also registers --no-verbose to silence logging.
    cli.add_argument("--verbose", action=argparse.BooleanOptionalAction, default=True)
    cli.add_argument("--reg", type=float, default=None)
    return cli.parse_args()

def _rollout_losses(A, B, C, x, z, Q, a, gamma, dt):
    """Push particles forward on the time grid and accumulate the two training losses.

    At grid point i the learned potential is psi_i(x) = 0.5 x^T A[i] x + B[i] x + C[i]
    (only the symmetric part of A[i] matters), and particles move with the velocity
    w = A_sym[i] x + B[i] - a x, integrated with explicit Euler steps of size dt.

    Args:
        A: (Nt+1, d, d) quadratic coefficients.
        B: (Nt+1, d) linear coefficients.
        C: (Nt+1, 1) constant coefficients.
        x: (Nx, d) particles at t=0.
        z: (Nx, d) score -(x - mu0)/Sigma0 of the initial Gaussian at the particles.
        Q: (Nx, d, d) per-particle matrix whose trace enters the regularizer;
           initialized by the caller to -I/Sigma0.
        a: OU drift rate.
        gamma: OU diffusion parameter.
        dt: time step.

    Returns:
        (loss_cost, loss_reg) as 0-dim tensors.
        loss_cost is the time-integrated mean of 0.5*|w - (-a x - gamma z)|^2.
        loss_reg is the time-integrated mean absolute residual at the interior grid
        points i = 1..Nt-1, which are the points where the central difference in time
        is available. It is therefore exactly 0 (still a tensor) when Nt == 1.
    """
    Nt = A.shape[0] - 1
    d = x.shape[-1]
    # Tensors (not Python ints) so .item() and autograd work even when no term is added.
    loss_cost = x.new_zeros(())
    loss_reg = x.new_zeros(())
    for i in range(Nt):
        Ai_sym = 0.5 * (A[i] + A[i].transpose(0, 1))
        Aix = torch.matmul(x, Ai_sym)
        w = Aix + B[i] - a * x  # Nx x d
        loss_cost = loss_cost + 0.5 * torch.mean(torch.sum((w + gamma * z + a * x) ** 2, dim=-1)) * dt
        if i > 0:  # no regularization at t=0: the central difference needs index i-1
            temp = torch.matmul(x, A[i + 1] - A[i - 1])
            # Central difference (psi_{i+1} - psi_{i-1}) / (2 dt) evaluated at x.
            psi_t = (torch.sum((0.5 * temp + B[i + 1] - B[i - 1]) * x, dim=-1, keepdim=True)
                     + C[i + 1] - C[i - 1]) / (2 * dt)  # Nx x 1
            Delta_rho = torch.einsum('ijj->i', Q).unsqueeze(1)  # Nx x 1, trace of Q
            grad_psi = Aix + B[i]  # Nx x d
            term23 = torch.sum(grad_psi * (0.5 * grad_psi - a * x), dim=-1, keepdim=True)  # Nx x 1
            residual = psi_t + term23 + gamma * d * a + gamma ** 2 * (
                Delta_rho + 0.5 * torch.sum(z ** 2, dim=-1, keepdim=True))
            loss_reg = loss_reg + torch.mean(torch.abs(residual)) * dt
        # Explicit Euler updates of the particles, z and Q, all using step-i quantities.
        x = x + w * dt
        z = z + (a * z - torch.matmul(z, Ai_sym)) * dt
        QA = torch.matmul(Q, Ai_sym)
        Q = Q - (QA + QA.transpose(1, 2) - 2 * a * Q) * dt
    return loss_cost, loss_reg


def main():
    """Train the quadratic potential coefficients A, B, C by regularized flow matching.

    Hyper-parameters come from the JSON file given by ``--config``. When ``--reg`` is
    given, it overrides ``config['reg']`` and is written back into the config.

    Training uses Adam with a MultiStepLR schedule. Every ``logging_freq`` steps the
    losses are recorded, together with the mean errors of sym(A) and B against the
    closed-form Ornstein-Uhlenbeck reference.

    With ``--save_var``, A, B, C, the error history (rows of [step, loss, loss_cost,
    loss_reg, error_A, error_B]) and the config are written to
    ``./results/FM/FMregFD{d}d{reg}seed{seed}.pt``. The directory is created if missing.
    """
    import os  # only needed to create the output directory

    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config) as f:
        config = json.load(f)

    d = config['d']
    gamma = config['gamma']
    a = config['a']
    mu0 = config['mu0']
    Sigma0 = config['Sigma0']
    sigma0 = Sigma0 ** 0.5
    stable_var = gamma / a  # stationary variance of the OU process
    Nx = config['Nx']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    logging_freq = config['logging_freq']
    reg = config['reg']
    if args.reg is not None:
        reg = args.reg
        config['reg'] = reg  # keep the saved config in sync with the value actually used
    print('regularization weight', reg)
    infix = str(reg)

    dt = T / Nt
    t_list = torch.linspace(0, T, Nt + 1, device=device, dtype=dtype)  # Nt+1
    # Closed-form mean and variance of the OU marginal on the time grid.
    mu_t = mu0 * torch.exp(-a * t_list)  # Nt+1
    Sigma_t = (Sigma0 - stable_var) * torch.exp(-2 * a * t_list) + stable_var  # Nt+1

    # Reference coefficients; used only for the logged error metrics.
    A_true = (torch.eye(d, dtype=dtype, device=device).unsqueeze(0).repeat(Nt + 1, 1, 1)
              * gamma / Sigma_t.reshape(-1, 1, 1))  # Nt+1 x d x d
    B_true = (-gamma * mu_t / Sigma_t).reshape(-1, 1).repeat(1, d)  # Nt+1 x d

    # Parameter() already sets requires_grad=True.
    A = Parameter(torch.randn(Nt + 1, d, d, dtype=dtype, device=device))
    B = Parameter(torch.randn(Nt + 1, d, dtype=dtype, device=device))
    C = Parameter(torch.randn(Nt + 1, 1, dtype=dtype, device=device))

    optimizer = torch.optim.Adam([A, B, C], lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # ==================== training ====================
    history = []  # one row per logged step; stacked once at the end
    for step in range(n_step + 1):
        optimizer.zero_grad()
        # Fresh samples from the initial Gaussian N(mu0, Sigma0 I).
        x = mu0 + torch.randn(Nx, d, device=device, dtype=dtype) * sigma0  # Nx x d
        z = -(x - mu0) / Sigma0  # Nx x d
        Q = (-torch.eye(d, dtype=dtype, device=device) / Sigma0).repeat(Nx, 1, 1)  # Nx x d x d
        loss_cost, loss_reg = _rollout_losses(A, B, C, x, z, Q, a, gamma, dt)
        loss = loss_cost + reg * loss_reg
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % logging_freq == 0:
            with torch.no_grad():  # diagnostics only; no autograd graph needed
                A_sym = 0.5 * (A + A.transpose(1, 2))
                error_A = torch.mean(torch.norm(A_sym - A_true, dim=(1, 2), p='fro')).item()
                error_B = torch.mean(torch.norm(B - B_true, dim=1)).item()
            row = [step, loss.item(), loss_cost.item(), loss_reg.item(), error_A, error_B]
            history.append(row)
            if args.verbose:
                print('step %d, loss %.3e, loss_cost %.3e, loss_reg %.3e, A %.3e, B %.3e' %
                      tuple(row))

    if args.save_var:
        save_dir = './results/FM'
        os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
        saved_results = {
            'A': A,
            'B': B,
            'C': C,
            'errors': np.array(history, dtype=float).reshape(-1, 6),
            'config': config,
        }
        torch.save(saved_results, f'{save_dir}/FMregFD{d}d{infix}seed{args.random_seed}.pt')
 
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()