import numpy as np
import torch
from torch import nn, optim
import argparse 
import sys
from tqdm import tqdm
import geomloss
from torchdiffeq import odeint 
import matplotlib.pyplot as plt
import math
import random
import os

import models, train, utils

def int_or_none(value):
    """Parse a command-line string as an integer or as ``None``.

    Intended as an argparse ``type=`` callable.

    Args:
        value: Raw string taken from the command line.

    Returns:
        ``None`` when ``value`` spells ``none`` (any letter case), otherwise
        ``int(value)``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is neither ``none`` nor
            something ``int()`` accepts. The underlying ``ValueError`` is
            attached as ``__cause__``.
    """
    if value.lower() == 'none':
        return None
    try:
        return int(value)
    except ValueError as err:
        # Chain explicitly so the traceback reads as a deliberate conversion
        # rather than an error raised while handling another error.
        raise argparse.ArgumentTypeError(
            f"'{value}' is not an integer or 'none'"
        ) from err

# Command-line interface. Each method (score, UPFI, PFI, ODE, TIGON) has a
# `--train_<method>` switch plus its own batch size, iteration count, learning
# rate, checkpoint interval and hidden-layer sizes.
# NOTE(review): `--T` and the `--train_*_resume` options are parsed but not read
# anywhere in the visible part of this script -- confirm whether they are used.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default="data.pkl")
parser.add_argument('--T', type=int, default=5)
parser.add_argument('--suffix', type=str, default='')
parser.add_argument('--outdir', type=str, default=".")
parser.add_argument('--figdir', type=str, default=".")
parser.add_argument('--device', type=str, default="cuda:0")
# Score
parser.add_argument('--train_score', action='store_true')
parser.add_argument('--train_score_batch', type=int, default=64)
parser.add_argument('--train_score_iters', type=int, default=10_000)
parser.add_argument('--train_score_lr', type=float, default=3e-3)
parser.add_argument('--train_score_ckpt', type=int, default=1000)
parser.add_argument('--hidden_sizes_score', nargs='+', type=int, default=[64, 64, 64])
# UPFI
parser.add_argument('--train_upfi', action='store_true')
parser.add_argument('--train_upfi_batch', type=int, default=64)
# The *_type options are looked up in time_dependent_map ('aut' / 'nonaut');
# restricting the choices reports a typo at parse time instead of a KeyError.
parser.add_argument('--train_upfi_g_type', type=str, default='aut', choices=['aut', 'nonaut'])
parser.add_argument('--train_upfi_v_type', type=str, default='aut', choices=['aut', 'nonaut'])
parser.add_argument('--train_upfi_iters', type=int, default=1_000)
parser.add_argument('--train_upfi_lr', type=float, default=3e-3)
parser.add_argument('--train_upfi_teacherforcing_its', type=int, default=0)
parser.add_argument('--train_upfi_ckpt', type=int, default=1000)
parser.add_argument('--train_upfi_reg', type=str, default='vf')
parser.add_argument('--train_upfi_resume', type=int, default=0)
parser.add_argument('--hidden_sizes_upfi', nargs='+', type=int, default=[64, 64, 64])
parser.add_argument('--hidden_sizes_upfi_g', nargs='+', type=int, default=[64, ])
# PFI
parser.add_argument('--train_pfi', action='store_true')
parser.add_argument('--train_pfi_batch', type=int, default=64)
parser.add_argument('--train_pfi_v_type', type=str, default='nonaut', choices=['aut', 'nonaut'])
parser.add_argument('--train_pfi_iters', type=int, default=1_000)
parser.add_argument('--train_pfi_lr', type=float, default=3e-3)
parser.add_argument('--train_pfi_teacherforcing_its', type=int, default=0)
parser.add_argument('--train_pfi_ckpt', type=int, default=1000)
parser.add_argument('--train_pfi_reg', type=str, default='vf')
parser.add_argument('--train_pfi_resume', type=int, default=0)
parser.add_argument('--hidden_sizes_pfi', nargs='+', type=int, default=[64, 64, 64])
# ODE
parser.add_argument('--train_ode', action='store_true')
parser.add_argument('--train_ode_batch', type=int, default=64)
parser.add_argument('--train_ode_iters', type=int, default=1_000)
parser.add_argument('--train_ode_lr', type=float, default=3e-3)
parser.add_argument('--train_ode_teacherforcing_its', type=int, default=0)
parser.add_argument('--train_ode_ckpt', type=int, default=1000)
parser.add_argument('--train_ode_resume', type=int, default=0)
parser.add_argument('--hidden_sizes_ode', nargs='+', type=int, default=[64, 64, 64])
# TIGON
parser.add_argument('--train_tigon', action='store_true')
parser.add_argument('--train_tigon_batch', type=int, default=64)
parser.add_argument('--train_tigon_iters', type=int, default=1_000)
parser.add_argument('--train_tigon_lr', type=float, default=3e-3)
parser.add_argument('--train_tigon_teacherforcing_its', type=int, default=0)
parser.add_argument('--train_tigon_ckpt', type=int, default=1000)
parser.add_argument('--train_tigon_reg', type=str, default='vf')
parser.add_argument('--train_tigon_resume', type=int, default=0)
parser.add_argument('--hidden_sizes_tigon', nargs='+', type=int, default=[64, 64, 64])
# other
parser.add_argument('--print_iter', type=int, default=100)
parser.add_argument('--reach', type=float, default=5.)
parser.add_argument('--score_logsigma_min', type=float, default=-2)
parser.add_argument('--score_logsigma_max', type=float, default=0)
# Passed as `steps` to torch.linspace, which needs an integer count; parsing it
# as float made every explicitly supplied value fail there.
parser.add_argument('--score_logsigma_steps', type=int, default=5)
# Accepts an integer or the literal 'none'.
parser.add_argument('--sigma_anneal_iters', type=int_or_none, default=2000)
parser.add_argument('--alpha_wfr', type=float, default=0.1)
parser.add_argument('--reg_wfr', type=float, default=0.001)
parser.add_argument('--D', type=float, default=0.25)
parser.add_argument('--dt_ratio', type=int, default=2)
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()

# Seed Python's, NumPy's and PyTorch's RNGs from --seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

device = torch.device(args.device)
torch.set_default_dtype(torch.float32)

# Make sure the output locations exist before any training starts, so a missing
# directory cannot fail a run only when a checkpoint or figure is written.
for _dir in (args.outdir, args.figdir):
    if _dir:  # '' means the current directory; os.makedirs('') would raise
        os.makedirs(_dir, exist_ok=True)

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only point --data at files from a trusted source.
data = torch.load(args.data, weights_only = False)
# Second axis of data['x'] is used as the model input dimension.
dim = data['x'].shape[1]

# Score fitting
# Split the rows of data['x'] into one tensor per distinct time index (sorted
# ascending) and assign them evenly spaced times on [0, t_final].
X = [torch.tensor(data['x'][data['t_idx'] == i, :], device = device, dtype = torch.float32) for i in np.sort(np.unique(data['t_idx']))]
ts = torch.linspace(0, data['t_final'], len(X), device = device)

s = models.NCScoreFunc(d = dim, hidden_sizes = args.hidden_sizes_score, activation = torch.nn.ReLU, time_dependent = True).to(device)
# Noise levels: exp of evenly spaced log-sigmas running from the max down to the min.
sigmas = torch.linspace(args.score_logsigma_max, args.score_logsigma_min, args.score_logsigma_steps, device = device).exp()
if args.train_score:
    s_opt = optim.AdamW(s.parameters(), lr = args.train_score_lr)
    s_trace = train.train_denoising_score(s, s_opt, sigmas, {'X' : X, 't' : ts}, args.train_score_batch,
                                          options = {'iters' : args.train_score_iters,
                                                     'print_iter' : args.print_iter,
                                                     'checkpoint_iter' : args.train_score_ckpt,
                                                     'checkpoint_file' : f"params_NCScoreFunc_{args.suffix}",
                                                     'save_final' : True,
                                                     'save_file' : f"params_NCScoreFunc_{args.suffix}",
                                                     'outdir' : args.outdir},
                                          sample_batch_options = {'replacement' : True, })
    # Plot score trace
    plt.figure(figsize = (3, 3))
    plt.plot(s_trace); plt.yscale('log'); plt.xlabel("Iteration"); plt.ylabel("Loss")
    plt.title("Score fitting"); plt.savefig(os.path.join(args.figdir, f"trace_NCScoreFunc_{args.suffix}.pdf"))
else:
    # Reuse a previously saved score model (presumably the file written by the
    # training branch with 'save_final' -- confirm the '_final.pt' naming in train).
    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint saved on one device can be restored on another.
    s.load_state_dict(torch.load(os.path.join(args.outdir, f"params_NCScoreFunc_{args.suffix}_final.pt"), map_location = device))

# Settings shared by the training sections below.
samplesloss_options = dict(loss = 'sinkhorn', reach = args.reach)
# Spacing between the first two time points of `ts`.
dt = ts[1] - ts[0]
s = s.to(device)
# Sample count of every snapshot relative to the first snapshot.
_n_first = X[0].shape[0]
m_ratios = [snapshot.shape[0] / _n_first for snapshot in X]

## other parameters
D = args.D
T = len(X)
# Fixed-step Euler integration with step t_final / (dt_ratio * T).
_euler_step = data['t_final'] / (args.dt_ratio * T)
odeint_options = dict(method = 'euler', options = dict(step_size = _euler_step))

# Re-seed all three RNGs before the next model is constructed.
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(args.seed)

# CLI keyword -> value of the models' `time_dependent` argument.
time_dependent_map = dict(aut = False, nonaut = True)

# UPFI
# kwargs_v / kwargs_g configure the model's v_net and g_net; each one is
# time-dependent or not according to its own --train_upfi_*_type option.
_upfi_kwargs_v = {'hidden_sizes' : args.hidden_sizes_upfi,
                  'time_dependent' : time_dependent_map[args.train_upfi_v_type]}
_upfi_kwargs_g = {'hidden_sizes' : args.hidden_sizes_upfi_g,
                  'time_dependent' : time_dependent_map[args.train_upfi_g_type]}
v_upfi = models.ODEFlowGrowth(d = dim, kwargs_v = _upfi_kwargs_v, kwargs_g = _upfi_kwargs_g).to(device)

if args.train_upfi:
    # One parameter group per sub-network, both at the same learning rate.
    opt_upfi = optim.AdamW([{'params' : _net.parameters(), 'lr' : args.train_upfi_lr}
                            for _net in (v_upfi.v_net, v_upfi.g_net)])
    # sched_upfi = StepLR(opt_upfi, step_size=500, gamma=0.5)
    print("Training UPFI")
    _upfi_stem = f"params_UPFI_ODEFlowGrowth_{args.suffix}"
    _options = dict(iters = args.train_upfi_iters,
                    print_iter = args.print_iter,
                    checkpoint_iter = args.train_upfi_ckpt,
                    checkpoint_file = _upfi_stem,
                    save_final = True,
                    save_file = _upfi_stem,
                    anneal_sigma_iters = args.sigma_anneal_iters,
                    reg_kind = args.train_upfi_reg,
                    outdir = args.outdir,
                    teacher_forcing_iter = args.train_upfi_teacherforcing_its)
    print(f"options = {_options}")
    _upfi_data = {'X' : X, 't' : ts, 'm_ratios' : torch.tensor(m_ratios).to(device), 'dt' : dt}
    _upfi_params = {'D' : D, 'alpha_wfr' : args.alpha_wfr, 'reg_wfr' : args.reg_wfr}
    trace_upfi = train.train_upfi(
        v_upfi, opt_upfi, s, sigmas,
        _upfi_data,
        _upfi_params,
        args.train_upfi_batch,
        options = _options,
        sample_batch_options = {'replacement' : True, 'add_noise' : False},
        odeint_options = odeint_options,
        samplesloss_options = samplesloss_options,
    )
    # Save the training trace on a log-scaled y axis.
    plt.figure(figsize = (3, 3))
    plt.plot(trace_upfi)
    plt.yscale('log')
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title("UPFI")
    plt.savefig(os.path.join(args.figdir, f"trace_UPFI_ODEFlowGrowth_{args.suffix}.pdf"))

# Re-seed all three RNGs before the PFI model is constructed.
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(args.seed)

# PFI
_pfi_time_dependent = time_dependent_map[args.train_pfi_v_type]
v_pfi = models.VectorField(d = dim,
                           hidden_sizes = args.hidden_sizes_pfi,
                           time_dependent = _pfi_time_dependent).to(device)
if args.train_pfi:
    print("Training PFI")
    _pfi_stem = f"params_PFI_VectorField_{args.suffix}"
    _options = dict(iters = args.train_pfi_iters,
                    print_iter = args.print_iter,
                    checkpoint_iter = args.train_pfi_ckpt,
                    checkpoint_file = _pfi_stem,
                    save_final = True,
                    save_file = _pfi_stem,
                    anneal_sigma_iters = args.sigma_anneal_iters,
                    reg_kind = args.train_pfi_reg,
                    outdir = args.outdir,
                    teacher_forcing_iter = args.train_pfi_teacherforcing_its)
    print(f"options = {_options}")
    opt_pfi = optim.AdamW(v_pfi.parameters(), lr = args.train_pfi_lr)
    trace_pfi = train.train_pfi(
        v_pfi, opt_pfi, s, sigmas,
        {'X' : X, 't' : ts, 'dt' : dt},
        {'D' : D, 'reg' : args.reg_wfr},
        args.train_pfi_batch,
        options = _options,
        sample_batch_options = {'replacement' : True, 'add_noise' : False},
        odeint_options = odeint_options,
    )
    # Save the training trace on a log-scaled y axis.
    plt.figure(figsize = (3, 3))
    plt.plot(trace_pfi)
    plt.yscale('log')
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title("PFI")
    plt.savefig(os.path.join(args.figdir, f"trace_PFI_VectorField_{args.suffix}.pdf"))

# Re-seed all three RNGs before the ODE model is constructed.
for _seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
    _seed_fn(args.seed)

# ODE
v_ode = models.ODEFlowGrowthCoupled(d = dim,
                                    hidden_sizes = args.hidden_sizes_ode,
                                    time_dependent = True).to(device)
if args.train_ode:
    print("Training ODE")
    _ode_stem = f"params_ODE_ODEFlowGrowthCoupled_{args.suffix}"
    _options = dict(iters = args.train_ode_iters,
                    print_iter = args.print_iter,
                    checkpoint_iter = args.train_ode_ckpt,
                    checkpoint_file = _ode_stem,
                    save_final = True,
                    save_file = _ode_stem,
                    outdir = args.outdir,
                    teacher_forcing_iter = args.train_ode_teacherforcing_its)
    print(f"options = {_options}")
    opt_ode = optim.AdamW(v_ode.parameters(), lr = args.train_ode_lr)
    _ode_data = {'X' : X, 't' : ts, 'm_ratios' : torch.tensor(m_ratios).to(device), 'dt' : dt}
    _ode_params = {'alpha_wfr' : args.alpha_wfr, 'reg_wfr' : args.reg_wfr}
    trace_ode = train.train_ode(
        v_ode, opt_ode,
        _ode_data,
        _ode_params,
        args.train_ode_batch,
        options = _options,
        sample_batch_options = {'replacement' : True, 'add_noise' : False},
        odeint_options = odeint_options,
        samplesloss_options = samplesloss_options,
    )
    # Save the training trace on a log-scaled y axis.
    plt.figure(figsize = (3, 3))
    plt.plot(trace_ode)
    plt.yscale('log')
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title("ODE")
    plt.savefig(os.path.join(args.figdir, f"trace_ODE_ODEFlowGrowthCoupled_{args.suffix}.pdf"))

# TIGON
# Re-seed before building the model, as the UPFI, PFI and ODE sections do, so
# that TIGON's initialisation and training do not depend on whether ODE
# training ran (and consumed random numbers) earlier in the script.
random.seed(args.seed)
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

# Both sub-networks use the same hidden sizes and are time-dependent.
v_tigon = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : args.hidden_sizes_tigon, 'time_dependent' : True},
                                       kwargs_g = {'hidden_sizes' : args.hidden_sizes_tigon, 'time_dependent' : True}).to(device)
if args.train_tigon:
    # One parameter group per sub-network, both at the same learning rate.
    opt_tigon = optim.AdamW([{'params' : v_tigon.v_net.parameters(), 'lr' : args.train_tigon_lr}, {'params' : v_tigon.g_net.parameters(), 'lr' : args.train_tigon_lr}])
    # sched_tigon = StepLR(opt_tigon, step_size=500, gamma=0.5)
    print("Training TIGON")
    _options={'iters' : args.train_tigon_iters,
              'print_iter' : args.print_iter,
              'checkpoint_iter' : args.train_tigon_ckpt,
              'checkpoint_file' : f"params_TIGON_ODEFlowGrowth_{args.suffix}",
              'save_final' : True,
              'save_file' : f"params_TIGON_ODEFlowGrowth_{args.suffix}",
              'reg_kind' : args.train_tigon_reg, 
              'outdir' : args.outdir,
              'teacher_forcing_iter' : args.train_tigon_teacherforcing_its}
    print(f"options = {_options}")
    trace_tigon = train.train_tigon(v_tigon, opt_tigon, 
                                  {'X' : X, 't' : ts, 'm_ratios' : torch.tensor(m_ratios).to(device), 'dt' : dt},
                                  {'alpha_wfr' : args.alpha_wfr, 'reg_wfr' : args.reg_wfr},
                                  args.train_tigon_batch,
                                  options = _options,
                                  sample_batch_options = {'replacement' : True,  'add_noise' : False}, 
                                  odeint_options = odeint_options,
                                  samplesloss_options = samplesloss_options)
    # Save the training trace on a log-scaled y axis.
    plt.figure(figsize = (3, 3))
    plt.plot(trace_tigon); plt.yscale('log'); plt.xlabel("Iteration"); plt.ylabel("Loss")
    plt.title("TIGON"); plt.savefig(os.path.join(args.figdir, f"trace_TIGON_ODEFlowGrowth_{args.suffix}.pdf"))
