import numpy as np
import torch
import argparse 
import sys
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
import random
import os
import torchsde
from torchdiffeq import odeint
import evals
import pandas as pd

sys.path.append("../../src/")
import models, train, utils

# Command-line configuration. Checkpoints are read from --outdir, evaluation CSVs are written to --evaldir.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default="data.pkl")
# NOTE: --T is not read anywhere in this script; T is derived from the number of snapshots in --data.
parser.add_argument('--T', type=int, default=5)
parser.add_argument('--suffix', type=str, default = '')
parser.add_argument('--outdir', type=str, default = ".")
parser.add_argument('--figdir', type=str, default = ".")  # not read anywhere in this script
parser.add_argument('--evaldir', type=str, default = ".")
parser.add_argument('--device', type=str, default = "cuda:0")
parser.add_argument('--hidden_sizes_score', nargs='+', type=int, default=[64, 64, 64])
# UPFI
parser.add_argument('--train_upfi_g_type', type=str, default='aut')
parser.add_argument('--train_upfi_v_type', type=str, default='aut')
parser.add_argument('--hidden_sizes_upfi', nargs='+', type = int, default = [64, 64, 64])
parser.add_argument('--hidden_sizes_upfi_g', nargs='+', type = int, default = [64, ])
# PFI
parser.add_argument('--train_pfi_v_type', type=str, default='nonaut')
parser.add_argument('--hidden_sizes_pfi', nargs='+', type=int, default=[64, 64, 64])
# ODE
parser.add_argument('--hidden_sizes_ode', nargs='+', type=int, default=[64, 64, 64])
# TIGON
parser.add_argument('--hidden_sizes_tigon', nargs='+', type=int, default=[64, 64, 64])
# DeepRUOT
parser.add_argument('--deepruot_suffix', type=str, default='')
# other
parser.add_argument('--score_logsigma_min', type=float, default=-2)
parser.add_argument('--score_logsigma_max', type=float, default=0)
# Number of noise levels. It is passed to torch.linspace as `steps`, which must be an int;
# with type=float any value given on the command line would make linspace raise.
parser.add_argument('--score_logsigma_steps', type=int, default=5)
parser.add_argument('--D', type=float, default=0.25)
parser.add_argument('--dt_ratio', type=int, default=2)
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()

# Seed every RNG used below so that the stochastic simulations are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

device = torch.device(args.device)
torch.set_default_dtype(torch.float32)

# NOTE(review): weights_only=False unpickles arbitrary Python objects - only pass trusted data files.
data = torch.load(args.data, weights_only = False)
dim = data['x'].shape[1]
# CLI flag value -> `time_dependent` argument of the model constructors.
time_dependent_map = dict(aut = False, nonaut = True)

# One float32 tensor of cells per snapshot, ordered by time index (np.unique returns sorted values).
X = [
    torch.tensor(data['x'][data['t_idx'] == idx, :], device = device, dtype = torch.float32)
    for idx in np.unique(data['t_idx'])
]
T = len(X)
D = args.D
# The script places the snapshots evenly on [0, t_final].
ts = torch.linspace(0, data['t_final'], T, device = device)
# Fixed-step Euler settings shared by every torchdiffeq odeint call in this script.
odeint_options = dict(
    method = 'euler',
    options = dict(step_size = data['t_final'] / (args.dt_ratio * T)),
)

# Load the trained models. map_location remaps the stored tensors onto `device`, so checkpoints
# saved on another device (e.g. a GPU) can still be restored when running with a different --device.
# Load score
s = models.NCScoreFunc(d = dim, hidden_sizes = args.hidden_sizes_score, activation = torch.nn.ReLU, time_dependent = True).to(device)
s.load_state_dict(torch.load(os.path.join(args.outdir, f"params_NCScoreFunc_{args.suffix}_final.pt"), map_location = device))
# Noise levels exp(logsigma) from score_logsigma_max down to score_logsigma_min; linspace needs an int step count.
sigmas = torch.linspace(args.score_logsigma_max, args.score_logsigma_min, int(args.score_logsigma_steps), device = device).exp()
# Load UPFI (velocity network + growth network)
v_upfi = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : args.hidden_sizes_upfi, 'time_dependent' : time_dependent_map[args.train_upfi_v_type]},
                                       kwargs_g = {'hidden_sizes' : args.hidden_sizes_upfi_g, 'time_dependent' : time_dependent_map[args.train_upfi_g_type]}).to(device)
v_upfi.load_state_dict(torch.load(os.path.join(args.outdir, f"params_UPFI_ODEFlowGrowth_{args.suffix}_final.pt"), map_location = device))
# Load PFI
v_pfi = models.VectorField(d = dim, hidden_sizes = args.hidden_sizes_pfi, time_dependent = time_dependent_map[args.train_pfi_v_type]).to(device)
v_pfi.load_state_dict(torch.load(os.path.join(args.outdir, f"params_PFI_VectorField_{args.suffix}_final.pt"), map_location = device))
# Load ODE
v_ode = models.ODEFlowGrowthCoupled(d = dim, hidden_sizes = args.hidden_sizes_ode, time_dependent = True).to(device)
v_ode.load_state_dict(torch.load(os.path.join(args.outdir, f"params_ODE_ODEFlowGrowthCoupled_{args.suffix}_final.pt"), map_location = device))
# Load TIGON
v_tigon = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : args.hidden_sizes_tigon, 'time_dependent' : True},
                                       kwargs_g = {'hidden_sizes' : args.hidden_sizes_tigon, 'time_dependent' : True}).to(device)
v_tigon.load_state_dict(torch.load(os.path.join(args.outdir, f"params_TIGON_ODEFlowGrowth_{args.suffix}_final.pt"), map_location = device))

# Load DeepRUOT (f_net exposes the v_net / g_net used below; sf2m_score_model is its score network)
import DeepRUOT.models
f_net = DeepRUOT.models.FNet(in_out_dim=dim, hidden_dim=128, n_hiddens=3, activation='leakyrelu').to(device)
sf2m_score_model=DeepRUOT.models.scoreNet2(in_out_dim=dim, hidden_dim=128,  activation='leakyrelu').float().to(device)
f_net.load_state_dict(torch.load(f"deepRUOT/model_result_{args.deepruot_suffix}", map_location = device))
sf2m_score_model.load_state_dict(torch.load(f"deepRUOT/score_model_result_{args.deepruot_suffix}", map_location = device))

# Marginal interpolation: simulate every method from the first snapshot and
# record its state at each observed time point in `ts`.
v_pfi.to(device)
v_upfi.to(device)
v_ode.to(device)
v_tigon.to(device)
x0 = X[0]
# Augmented initial state for UPFI / TIGON / ODE. Downstream code drops column 0 of
# these trajectories, so pad_zeros_upfi presumably prepends a zero column (TODO confirm).
x0_mass = utils.pad_zeros_upfi(x0)
sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)
# Per-coordinate noise scale: 0 on the leading extra column, sqrt(D) on each of the `dim` coordinates.
sde_upfi = models.SDE(
    lambda t, x: v_upfi(t, x),
    sigma = torch.cat([torch.zeros(1), torch.full((dim, ), D**0.5)]).to(device),
)
with torch.no_grad():
    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = "euler").cpu()
    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = "euler").cpu()
    # TIGON and ODE are deterministic flows, integrated with the shared Euler settings.
    xs_t_tigon = odeint(v_tigon, x0_mass, ts, **odeint_options).cpu()
    xs_t_ode = odeint(v_ode, x0_mass, ts, **odeint_options).cpu()

class _SDE(torch.nn.Module):
    """torchsde-compatible SDE that simulates DeepRUOT on an augmented state.

    The state ``x`` carries one extra leading column followed by the
    coordinates ``y = x[..., 1:]``. The drift's first output column is the
    growth rate ``growth(t, y)``; the remaining columns are
    ``drift(t, y) + score.compute_gradient(t, y)``. The diffusion is diagonal
    and scaled by ``sigma``.
    """

    # Class attributes read by torchsde.sdeint to choose the solver variant.
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, growth, score, sigma=1.0):
        super().__init__()
        self.drift = ode_drift
        self.growth = growth
        self.score = score
        self.sigma = sigma

    def f(self, t, x):
        """Return the drift ``[growth, velocity + score gradient]`` for state ``x`` at time ``t``."""
        coords = x[..., 1:]
        velocity = self.drift(t, coords)
        rate = self.growth(t, coords)
        # The score network receives a per-sample time column instead of the scalar `t`.
        t_column = t.expand(coords.shape[0], 1)
        score_term = self.score.compute_gradient(t_column, coords)
        return torch.hstack([rate, velocity + score_term])

    def g(self, t, x):
        """Return the diagonal diffusion: ones over the coordinates, padded via pad_zeros_upfi, times sigma."""
        ones = torch.ones_like(x[..., 1:])
        return utils.pad_zeros_upfi(ones) * self.sigma
# DeepRUOT is evaluated on a clock rescaled by (T-1): model time = physical time * (T-1).
# The noise variance is divided by (T-1) to match that rescaling.
sde_deepruot = _SDE(
    ode_drift = f_net.v_net,
    growth = f_net.g_net,
    score = sf2m_score_model,
    sigma = (D / (T - 1)) ** 0.5,
)
# Not run under torch.no_grad(): _SDE.f calls score.compute_gradient, which may rely on autograd (TODO confirm).
xs_t_deepruot = torchsde.sdeint(
    sde_deepruot,
    x0_mass.to(device),
    ts = ts.to(device) * (T - 1),
    dt = 0.01 * (T - 1),
).detach().cpu()

# Energy distance between simulated and ground-truth paths, one value per method.
# Augmented trajectories drop their leading column, and each trajectory is permuted
# from (time, cell, dim) to (cell, time, dim) before the comparison.
df_err_paths = pd.DataFrame(pd.Series({
    _method : evals.energy_distance_paths(_traj.permute((1, 0, 2)).numpy(), data['x_paths'])
    for _method, _traj in (
        ("UPFI", xs_t_upfi[..., 1:]),
        ("PFI", xs_t_pfi),
        ("ODE", xs_t_ode[..., 1:]),
        ("TIGON", xs_t_tigon[..., 1:]),
        ("DeepRUOT", xs_t_deepruot[..., 1:]),
    )
}))
df_err_paths.to_csv(os.path.join(args.evaldir, f"df_err_paths_{args.suffix}.csv"))

v_pfi.to(device)
v_upfi.to(device)
v_ode.to(device)
v_tigon.to(device)


def _on_snapshots(fn):
    """Apply ``fn(t, x)`` to every (time, snapshot) pair and stack the results row-wise on the CPU."""
    return torch.vstack([fn(t_k, x_k) for t_k, x_k in zip(ts, X)]).cpu()


with torch.no_grad():
    # Inferred velocity fields at the observed cells. DeepRUOT is queried on its
    # (T-1)-rescaled clock and its output multiplied by (T-1).
    vf_pfi = _on_snapshots(v_pfi)
    vf_deepruot = _on_snapshots(lambda t, x: (T - 1) * f_net.v_net(t * (T - 1), x))
    vf_upfi = _on_snapshots(v_upfi.v_net)
    vf_tigon = _on_snapshots(v_tigon.v_net)
    vf_ode = _on_snapshots(v_ode.dF)
    # Inferred growth rates at the observed cells.
    g_upfi = _on_snapshots(v_upfi.g_net)
    g_tigon = _on_snapshots(v_tigon.g_net)
    g_ode = _on_snapshots(v_ode.F_net)
    g_deepruot = _on_snapshots(lambda t, x: (T - 1) * f_net.g_net(t * (T - 1), x))

vf_gt = torch.tensor(data['v_true'], dtype = torch.float32)
df_vectorfield = pd.DataFrame({"Psi" : data['potential_true']})
# Per-cell discrepancy to the true velocity: cosine distance and Euclidean norm of the difference.
_vf_metrics = [("cos", utils.cos_dist), ("l2", lambda u, v: (u - v).norm(2, 1))]
_vf_estimates = [("PFI", vf_pfi), ("DeepRUOT", vf_deepruot), ("UPFI", vf_upfi), ("TIGON", vf_tigon), ("ODE", vf_ode)]
for _what, _metric in _vf_metrics:
    for _method, _vf in _vf_estimates:
        df_vectorfield.loc[:, f"dv_{_method}_{_what}"] = _metric(vf_gt, _vf).numpy()
df_vectorfield.to_csv(os.path.join(args.evaldir, f"df_vectorfield_{args.suffix}.csv"))

# Fate probabilities: push the cells of every snapshot forward to the final time
# point with each model and let evals.get_centroid_probs summarise the endpoints
# relative to the reference centroids.
import functools
import sklearn as sk
sde_deepruot.to(device)
_centroids = torch.tensor(data['centroids'], dtype = torch.float32)
_centroid_cols = [f"{k}" for k in range(_centroids.shape[0])]


def _push_upfi(x0, t_span):
    """Integrate the UPFI SDE from x0 over t_span."""
    return torchsde.sdeint(sde_upfi, x0, t_span, method = "euler")


def _push_pfi(x0, t_span):
    """Integrate the PFI SDE from x0 over t_span."""
    return torchsde.sdeint(sde_pfi, x0, t_span, method = "euler")


def _push_deepruot(x0, t_span):
    """Integrate the DeepRUOT SDE; its clock (and step size) is rescaled by (T-1)."""
    return torchsde.sdeint(sde_deepruot, x0, t_span*(T-1), dt = 0.01*(T-1), method = "euler")


def _push_ode(x0, t_span):
    """Integrate the ODE baseline from x0 over t_span."""
    return odeint(v_ode, x0, t_span, **odeint_options)


def _push_tigon(x0, t_span):
    """Integrate TIGON from x0 over t_span."""
    return odeint(v_tigon, x0, t_span, **odeint_options)


def _final_state(i, x, push, mass_augmented, no_grad):
    """Return the coordinates of cells ``x`` (observed at ``ts[i]``) after integrating to ``ts[-1]``.

    Args:
        i: index of the snapshot the cells come from.
        x: cell coordinates at ``ts[i]``.
        push: integrator ``push(x0, t_span)`` returning the trajectory evaluated at ``t_span``.
        mass_augmented: if True, integrate ``utils.pad_zeros_upfi(x)`` and drop the
            leading column of the result.
        no_grad: integrate without recording an autograd graph.

    Returns:
        Detached CPU tensor of final coordinates. Cells of the last snapshot are
        returned as-is (moved to CPU) without integration.
    """
    if i >= T - 1:
        return x.cpu()
    x0 = utils.pad_zeros_upfi(x) if mass_augmented else x
    t_span = ts[[i, -1]]
    if no_grad:
        # Avoid retaining a graph through every solver step - only values are needed.
        with torch.no_grad():
            x_end = push(x0, t_span)[-1, ...]
    else:
        x_end = push(x0, t_span)[-1, ...]
    if mass_augmented:
        x_end = x_end[..., 1:]
    return x_end.detach().cpu()


# (name, integrator, augmented state, integrate under no_grad, extra get_centroid_probs kwargs).
# DeepRUOT keeps autograd enabled: its drift calls score.compute_gradient, which may rely on it.
# ODE / TIGON are deterministic flows and use a single sample (n_sample = 1).
# The order matches the original call order, so seeded RNG consumption is unchanged.
_fate_methods = [
    ("UPFI", _push_upfi, True, True, {}),
    ("PFI", _push_pfi, False, True, {}),
    ("DeepRUOT", _push_deepruot, True, False, {}),
    ("ODE", _push_ode, True, True, {"n_sample" : 1}),
    ("TIGON", _push_tigon, True, True, {"n_sample" : 1}),
]
_fate_probs = {_name : [] for _name, *_ in _fate_methods}
for i in tqdm(range(T)):
    for _name, _push, _mass, _no_grad, _kw in _fate_methods:
        # partial binds i and the method settings now, so the callback never sees a later iteration's values.
        _fn = functools.partial(_final_state, i, push = _push, mass_augmented = _mass, no_grad = _no_grad)
        _fate_probs[_name].append(evals.get_centroid_probs(X[i], _fn, _centroids, **_kw))

df_probs_upfi = pd.DataFrame(torch.vstack(_fate_probs["UPFI"]), columns = _centroid_cols)
df_probs_pfi = pd.DataFrame(torch.vstack(_fate_probs["PFI"]), columns = _centroid_cols)
df_probs_deepruot = pd.DataFrame(torch.vstack(_fate_probs["DeepRUOT"]), columns = _centroid_cols)
df_probs_ode = pd.DataFrame(torch.vstack(_fate_probs["ODE"]), columns = _centroid_cols)
df_probs_tigon = pd.DataFrame(torch.vstack(_fate_probs["TIGON"]), columns = _centroid_cols)

df_probs = pd.concat([df_probs_upfi, df_probs_pfi, df_probs_deepruot, df_probs_ode, df_probs_tigon], axis = 1, keys = ["UPFI", "PFI", "DeepRUOT", "ODE", "TIGON"])
df_probs.to_csv(os.path.join(args.evaldir, f"df_probs_{args.suffix}.csv"))

# Agreement between inferred and ground-truth fate probabilities, measured with
# Pearson, Spearman and Kendall correlations. Only centroid column 0 is compared.
import scipy as sp
import scipy.stats  # load explicitly rather than relying on lazy attribute access through `sp.stats`

_probs_true = data['probs'][:, 0]
_probs_by_method = {
    "UPFI" : df_probs_upfi,
    "PFI" : df_probs_pfi,
    "DeepRUOT" : df_probs_deepruot,
    "ODE" : df_probs_ode,
    "TIGON" : df_probs_tigon,
}
pd.DataFrame([
    {**{_method : _f(_df.iloc[:, 0], _probs_true).statistic for _method, _df in _probs_by_method.items()}, "what" : _s}
    for (_f, _s) in zip([sp.stats.pearsonr, sp.stats.spearmanr, sp.stats.kendalltau], ['pearson', 'spearman', 'kendall'])
]).to_csv(os.path.join(args.evaldir, f"df_fate_corr_{args.suffix}.csv"))
