import numpy as np
import torch
import json
import argparse
from torch.nn import Parameter

# Time horizon used by the whole script: dt = T / Nt and t in [0, T].
T = 1.0
# All tensors are created in single precision (torch.float is torch.float32).
dtype = torch.float32
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def parse_args():
    """Parse the command-line options of the training script.

    Options:
        --config:       path of the JSON configuration file
                        (default ``./configs/WPOregFD1d.json``).
        --save_var / --no-save_var:
                        whether results are written to disk (default off).
        --random_seed:  integer seed (default 0).
        --verbose / --no-verbose:
                        whether progress is printed (default on).
        --reg:          optional float; ``None`` when not given.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    # BooleanOptionalAction registers both --flag and --no-flag.
    toggle = argparse.BooleanOptionalAction

    cli.add_argument("--config", type=str, default="./configs/WPOregFD1d.json")
    cli.add_argument("--save_var", action=toggle, default=False)
    cli.add_argument("--random_seed", type=int, default=0)
    # Silence the progress output with: --no-verbose
    cli.add_argument('--verbose', action=toggle, default=True)
    cli.add_argument("--reg", type=float, default=None)

    return cli.parse_args()

def sample_x0(N,d):
    """Draw ``N`` start points in ``d`` dimensions.

    Each coordinate is an independent zero-mean normal sample scaled by
    ``sqrt(2*(T+1))``; the result is an ``N x d`` tensor created with the
    module-level ``dtype`` on the module-level ``device``.
    """
    scale = np.sqrt(2*(T+1))
    standard_normal = torch.randn(N, d, dtype=dtype, device=device)
    # Keep the tensor on the left so the product is evaluated by torch.
    return standard_normal * scale

def main():
    """Train the time-discretised parameters A, B, C and optionally save them.

    Hyper-parameters are read from the JSON file given by ``--config``
    (keys ``d``, ``gamma``, ``Nx``, ``Nt``, ``lr``, ``n_step``,
    ``milestones``, ``decay``, ``logging_freq``, ``reg``); a ``--reg`` value
    on the command line overrides the configured regularization weight.

    Every optimisation step samples ``Nx`` start points, advances them with
    ``Nt`` explicit Euler steps of size ``dt = T / Nt`` under the affine field
    ``w_i(x) = x @ sym(A[i]) + B[i]`` and minimises
    ``loss_cost + reg * loss_reg`` with Adam and a ``MultiStepLR`` schedule.

    Every ``logging_freq`` steps the losses and the mean errors of ``sym(A)``
    and ``B`` against the reference values ``A_true[i] = -I / (2 (T - t_i + 1))``
    and ``B_true = 0`` are recorded (and printed with ``--verbose``).  With
    ``--save_var`` the parameters, the error log and the config are written
    to ``./results/WPOreg/``; the directory is created when missing.
    """
    # Local import keeps this block self-contained; only the save step needs it.
    import os

    args = parse_args()
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)

    d = config['d']
    gamma = config['gamma']
    Nx = config['Nx']
    Nt = config['Nt']
    lr = config['lr']
    n_step = config['n_step']
    milestones = config['milestones']
    decay = config['decay']
    logging_freq = config['logging_freq']
    reg = config['reg']
    if args.reg is not None:
        # The command-line value wins and is stored so the saved config matches the run.
        reg = args.reg
        config['reg'] = reg
    print('regularization weight', reg)
    infix = str(reg)

    dt = T / Nt
    t_list = torch.linspace(0,T,Nt+1,device=device,dtype=dtype)

    def score(t, x):
        return -x / (2*(T-t+1))

    eye = torch.eye(d, dtype=dtype, device=device)

    # Reference values used only for the logged error metrics.
    A_true = -eye.unsqueeze(0).repeat(Nt+1,1,1) \
        / (2*(T-t_list+1).unsqueeze(-1).unsqueeze(-1)) # Nt+1 x d x d
    B_true = torch.zeros(Nt+1,d, dtype=dtype, device=device) # Nt+1 x d

    A = Parameter(torch.randn(Nt+1,d,d,dtype=dtype,device=device, requires_grad=True))
    B = Parameter(torch.randn(Nt+1,d,dtype=dtype,device=device, requires_grad=True))
    C = Parameter(torch.randn(Nt+1,1,dtype=dtype,device=device, requires_grad=True))

    optimizer = torch.optim.Adam([A,B,C], lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=decay)

    # Initial value of Q is the same for every sample and every step, so build it once.
    Q0 = (-eye/(2*gamma*(T+1))).repeat(Nx,1,1) # Nx x d x d

    # One row per logged step: step, loss, loss_cost, loss_reg, error_A, error_B.
    error_rows = []

    # ==================== training ====================
    for step in range(n_step+1):
        optimizer.zero_grad()
        x = sample_x0(Nx,d) # Nx x d
        # NOTE(review): nothing below differentiates w.r.t. x; kept from the
        # original so the autograd graph is unchanged - confirm before removing.
        x.requires_grad_(True)
        z = score(0,x) # Nx x d
        Q = Q0

        # 0-dim tensors (not Python numbers) so .item() also works when the
        # regularization loop body never runs (Nt == 1).
        loss_cost = torch.zeros((), dtype=dtype, device=device)
        loss_reg = torch.zeros((), dtype=dtype, device=device)

        for i in range(Nt):
            Ai_sym = 0.5*(A[i] + A[i].transpose(0,1)) # d x d
            w = torch.matmul(x,Ai_sym) + B[i] # Nx x d
            loss_cost = loss_cost + 0.5* torch.mean(torch.sum((w + gamma*z)**2,dim=-1)) * dt

            if i > 0:
                # No regularization at t=0: the central difference in time
                # needs the neighbours i-1 and i+1.
                temp = torch.matmul(x,A[i+1] - A[i-1])
                psi_t = (torch.sum((0.5*temp + B[i+1]-B[i-1])*x, dim=-1, keepdim=True)
                         + C[i+1] - C[i-1]) / (2*dt)
                Delta_rho = torch.einsum('ijj->i',Q).unsqueeze(1) # Nx x 1, trace of Q
                residual = (psi_t + 0.5* torch.sum(w**2,dim=-1,keepdim=True)
                            + gamma**2 * (Delta_rho + 0.5 * torch.sum(z**2,dim=-1,keepdim=True)))
                loss_reg = loss_reg + torch.mean(torch.abs(residual)) * dt

            # Explicit Euler updates; z and Q use the same symmetrised matrix as x.
            x = x + w * dt
            z = z - torch.matmul(z,Ai_sym)*dt
            QA = torch.matmul(Q,Ai_sym) # Nx x d x d
            Q = Q - dt * (QA + QA.transpose(1,2))

        loss_cost = loss_cost + torch.mean(torch.sum(x**2,dim=-1))/2 # terminal cost
        loss = loss_cost + reg * loss_reg
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % logging_freq == 0:
            # Metrics are diagnostics only; do not record them for autograd.
            with torch.no_grad():
                A_sym = 0.5*(A + A.transpose(1,2))
                error_A = torch.mean(torch.norm(A_sym - A_true, dim=(1, 2), p='fro')).item()
                error_B = torch.mean(torch.norm(B - B_true, dim=1)).item()
            error_rows.append([step, loss.item(), loss_cost.item(), loss_reg.item(), error_A, error_B])
            if args.verbose:
                print('step %d, loss %.3e, loss_cost %.3e, loss_reg %.3e, A %.3e, B %.3e' %
                      (step, loss.item(), loss_cost.item(), loss_reg.item(), error_A, error_B))

    if args.save_var:
        # Convert the log once instead of concatenating an array at every logged step.
        errors = np.array(error_rows, dtype=float).reshape(-1, 6)
        saved_results = {
            'A': A,
            'B': B,
            'C': C,
            'errors': errors,
            'config': config
        }
        out_dir = './results/WPOreg'
        # Create the target directory so a finished run cannot fail at the save step.
        os.makedirs(out_dir, exist_ok=True)
        torch.save(saved_results, out_dir+'/WPOregFD'+str(d)+'d'+infix+'seed'+str(args.random_seed)+'.pt')

# Run the training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()