import numpy as np
import torch
from torch import nn, optim
import argparse 
import sys
from tqdm import tqdm
import geomloss
from torchdiffeq import odeint 
import matplotlib.pyplot as plt
import math
import random
import os

import sys
sys.path.append("../../src")
import models, train, utils

def _build_parser():
    """Build the command-line parser for this training script.

    Options are declared below as ``(flag, add_argument kwargs)`` pairs and are
    registered in exactly the listed order, so ``--help`` lists them that way.
    """
    cli_options = (
        # general I/O, device and score-network shape
        ('--data', dict(type=str, default="data.pkl")),
        ('--T', dict(type=int, default=5)),
        ('--suffix', dict(type=str, default='')),
        ('--outdir', dict(type=str, default=".")),
        ('--figdir', dict(type=str, default=".")),
        ('--device', dict(type=str, default="cuda:0")),
        ('--hidden_sizes_score', dict(nargs='+', type=int, default=[64, 64, 64])),
        # UPFI
        ('--train_mult_upfi', dict(action='store_true')),
        ('--train_mult_upfi_batch', dict(type=int, default=64)),
        ('--train_mult_upfi_iters', dict(type=int, default=1_000)),
        ('--train_mult_upfi_lr', dict(type=float, default=3e-3)),
        ('--train_mult_upfi_teacherforcing_its', dict(type=int, default=0)),
        ('--train_mult_upfi_ckpt', dict(type=int, default=1000)),
        ('--train_mult_upfi_reg', dict(type=str, default="vf")),
        ('--hidden_sizes_mult_upfi', dict(nargs='+', type=int, default=[64, 64, 64])),
        # PFI
        ('--train_mult_pfi', dict(action='store_true')),
        ('--train_mult_pfi_batch', dict(type=int, default=64)),
        ('--train_mult_pfi_iters', dict(type=int, default=1_000)),
        ('--train_mult_pfi_lr', dict(type=float, default=3e-3)),
        ('--train_mult_pfi_teacherforcing_its', dict(type=int, default=0)),
        ('--train_mult_pfi_ckpt', dict(type=int, default=1000)),
        ('--train_mult_pfi_reg', dict(type=str, default="vf")),
        ('--hidden_sizes_mult_pfi', dict(nargs='+', type=int, default=[64, 64, 64])),
        # other
        ('--print_iter', dict(type=int, default=100)),
        ('--reach', dict(type=float, default=5.)),
        ('--score_logsigma_min', dict(type=float, default=-2)),
        ('--score_logsigma_max', dict(type=float, default=0)),
        ('--score_logsigma_steps', dict(type=float, default=5)),
        ('--sigma_anneal_iters', dict(type=int, default=2000)),
        ('--alpha_wfr', dict(type=float, default=0.1)),
        ('--reg_wfr', dict(type=float, default=0.001)),
        ('--D', dict(type=float, default=0.25)),
        ('--seed', dict(type=int, default=42)),
    )
    p = argparse.ArgumentParser()
    for flag, opts in cli_options:
        p.add_argument(flag, **opts)
    return p


parser = _build_parser()
args = parser.parse_args()


# Seed every RNG used below (Python, NumPy, torch) so runs are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

device = torch.device(args.device)
torch.set_default_dtype(torch.float32)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects (which can execute code) -- only point --data at trusted files.
data = torch.load(args.data, weights_only = False)
# presumably data['x'] is (n_samples, dim) and data['t_idx'] gives each row's
# snapshot index -- confirm against the data generator.
dim = data['x'].shape[1]

# Score fitting
# One float32 tensor per snapshot, ordered by increasing t_idx.
X = [torch.tensor(data['x'][data['t_idx'] == i, :], device = device, dtype = torch.float32) for i in np.sort(np.unique(data['t_idx']))]
if len(X) < 2:
    # dt below is ts[1] - ts[0]; fail with a clear message instead of an IndexError.
    raise ValueError(f"need at least 2 time points in {args.data}, found {len(X)}")
# Snapshot times, evenly spaced on [0, t_final].
ts = torch.linspace(0, data['t_final'], len(X), device = device)

s = models.NCScoreFunc(d = dim, hidden_sizes = args.hidden_sizes_score, activation = torch.nn.ReLU, time_dependent = True).to(device)
# Noise levels exp(logsigma), running from score_logsigma_max to score_logsigma_min.
# torch.linspace requires an int step count, but --score_logsigma_steps is parsed
# with type=float, so any value given on the command line must be cast.
sigmas = torch.linspace(args.score_logsigma_max, args.score_logsigma_min, int(args.score_logsigma_steps), device = device).exp()
# Load the pre-trained score network. map_location lets a checkpoint saved on a
# GPU be loaded on whatever --device is requested (e.g. cpu).
s.load_state_dict(torch.load(os.path.join(args.outdir, f"params_NCScoreFunc_{args.suffix}_final.pt"), map_location = device))

odeint_options = {'method' : 'euler', 'options' : {'step_size' : 0.1}}
samplesloss_options = {'loss' : 'sinkhorn', 'reach' : args.reach}
# Snapshot spacing (uniform, since ts comes from linspace).
dt = ts[1]-ts[0]
# Sample count of each snapshot relative to the first snapshot.
m_ratios = [x.shape[0] / X[0].shape[0] for x in X]

## other parameters
D = args.D
T = len(X)

# Re-seed so the MULT_UPFI model starts from the same RNG state regardless of
# what was drawn while loading data and building the score model.
random.seed(args.seed)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)


# MultUPFI
# The lambda looks up the module-level score network `s` at call time.
# sigmas[0] is the first entry of `sigmas`, i.e. exp(score_logsigma_max).
v_mult_upfi = models.MultiplicativeNoiseFlowGrowth(dim,
                                                   lambda t, x, sigma: s(t, x, sigma),
                                                   D,
                                                   sigmas[0],
                                                   kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : args.hidden_sizes_mult_upfi},
                                                   kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : args.hidden_sizes_mult_upfi}).to(device)

if args.train_mult_upfi:
    # Optimise the u, v and g sub-networks, each as its own param group.
    opt_mult_upfi = optim.AdamW([{'params' : v_mult_upfi.u.parameters(), 'lr' : args.train_mult_upfi_lr},
                                {'params' : v_mult_upfi.v.parameters(), 'lr' : args.train_mult_upfi_lr},
                                {'params' : v_mult_upfi.g.parameters(), 'lr' : args.train_mult_upfi_lr}])
    print("Training MULT_UPFI")
    trace_mult_upfi = train.train_multinoise_upfi(v_mult_upfi, opt_mult_upfi, s, sigmas,
                                  {'X' : X, 't' : ts, 'm_ratios' : torch.tensor(m_ratios, device = device), 'dt' : dt},
                                  {'D' : D, 'alpha_wfr' : args.alpha_wfr, 'reg_wfr' : args.reg_wfr},
                                  args.train_mult_upfi_batch,
                                  options = {'iters' : args.train_mult_upfi_iters,
                                             'print_iter' : args.print_iter,
                                             'checkpoint_iter' : args.train_mult_upfi_ckpt,
                                             'checkpoint_file' : f"params_MULT_UPFI_ODEFlowGrowth_{args.suffix}",
                                             'save_final' : True,
                                             'save_file' : f"params_MULT_UPFI_ODEFlowGrowth_{args.suffix}",
                                             'anneal_sigma_iters' : args.sigma_anneal_iters,
                                             'reg_kind' : args.train_mult_upfi_reg, 
                                             'outdir' : args.outdir},
                                  sample_batch_options = {'replacement' : True},
                                  odeint_options = odeint_options,
                                  samplesloss_options = samplesloss_options)
    # Save the loss trace on a log scale, then close the figure so pyplot
    # does not keep it alive for the rest of the script.
    fig = plt.figure(figsize = (3, 3))
    plt.plot(trace_mult_upfi); plt.yscale('log'); plt.xlabel("Iteration"); plt.ylabel("Loss")
    plt.title("MULT_UPFI"); plt.savefig(os.path.join(args.figdir, f"trace_MULT_UPFI_ODEFlowGrowth_{args.suffix}.pdf"))
    plt.close(fig)

# Re-seed so the MULT_PFI model starts from the same RNG state regardless of
# whether MULT_UPFI was trained above.
random.seed(args.seed)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

# PFI
# Same construction as MULT_UPFI but via MultiplicativeNoiseFlow; the lambda looks
# up the module-level score network `s` at call time.
v_mult_pfi = models.MultiplicativeNoiseFlow(dim,
                                            lambda t, x, sigma: s(t, x, sigma),
                                            D,
                                            sigmas[0],
                                            kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : args.hidden_sizes_mult_pfi},
                                            kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : args.hidden_sizes_mult_pfi}
                                            ).to(device)
if args.train_mult_pfi:
    print("Training MULT_PFI")
    opt_mult_pfi = optim.AdamW(v_mult_pfi.parameters(), lr = args.train_mult_pfi_lr)
    # NOTE(review): the 'reg' entry reuses --reg_wfr; confirm that is intended for PFI.
    trace_mult_pfi = train.train_multinoise_pfi(v_mult_pfi, opt_mult_pfi, s, sigmas,
                                  {'X' : X, 't' : ts, 'dt' : dt},
                                  {'D' : D, 'reg' : args.reg_wfr},
                                  args.train_mult_pfi_batch,
                                  options = {'iters' : args.train_mult_pfi_iters,
                                             'print_iter' : args.print_iter,
                                             'checkpoint_iter' : args.train_mult_pfi_ckpt,
                                             'checkpoint_file' : f"params_MULT_PFI_VectorField_{args.suffix}",
                                             'save_final' : True,
                                             'save_file' : f"params_MULT_PFI_VectorField_{args.suffix}",
                                             'anneal_sigma_iters' : args.sigma_anneal_iters,
                                             'reg_kind' : args.train_mult_pfi_reg, 
                                             'outdir' : args.outdir},
                                  sample_batch_options = {'replacement' : True},
                                  odeint_options = odeint_options)
    # Plot THIS run's loss trace (previously plotted trace_mult_upfi, which is the
    # wrong curve, or undefined when only --train_mult_pfi is given).
    fig = plt.figure(figsize = (3, 3))
    plt.plot(trace_mult_pfi); plt.yscale('log'); plt.xlabel("Iteration"); plt.ylabel("Loss")
    plt.title("MULT_PFI"); plt.savefig(os.path.join(args.figdir, f"trace_MULT_PFI_VectorField_{args.suffix}.pdf"))
    plt.close(fig)
