import torch
import numpy as np
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import pandas as pd
import os
import sklearn as sk
from sklearn import preprocessing, decomposition
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sb
import sys
sys.path.append("../../src")
import models, utils, evals
from tqdm import tqdm
import scipy as sp

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32)

# Simulation parameters; they select which dataset / weight files are loaded below.
c = 0.5
beta = 5.5

if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} SEED")
# The seed identifies which trained weights are evaluated (it appears in every
# weight and output filename). It also seeds torch so that the stochastic
# sampling below produces reproducible metrics.
seed = int(sys.argv[1])
torch.manual_seed(seed)

# NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
data = torch.load(f"sim_HSC_N_500_T_10_c_{c}_beta_{beta}.pkl", weights_only = False)

# Number of observation times; assumes t_idx labels are the integers 0..T-1 (TODO confirm).
T = int(data['t_idx'].max()) + 1
dim = data['x'].shape[1]
# One float32 tensor on `device` per observation time, holding that time's rows of data['x'].
X = [torch.tensor(data['x'][data['t_idx'] == i, ...], device = device, dtype = torch.float32) for i in range(T)]

# All evaluation tables and fate probabilities are written under ./evals.
os.makedirs("evals", exist_ok = True)

print("Loading score model")
# Noise levels exp(0), exp(-0.5), ..., exp(-2); the flows below use the smallest one, sigmas[-1].
sigmas = torch.linspace(0, -2, 5, device = device).exp()
s = models.NCScoreFunc(d = dim, hidden_sizes = [128, 128, 128], activation = torch.nn.ReLU, time_dependent = True).to(device)
# map_location lets checkpoints saved on a GPU load on a CPU-only machine (and vice versa).
s.load_state_dict(torch.load(f"weights/params_NCScoreFunc_default_c_{c}_seed_{seed}_final.pt", map_location = device))

print("Loading dyn models")
# Diffusion coefficient; the SDEs sampled later use noise scale D**0.5 = 0.5.
D = 0.5**2
# Observation times on `device`, one per snapshot in X.
ts = torch.linspace(0, data['t_final'], len(X)).to(device)
# Size of each snapshot relative to the first one (used when sampling initial states).
m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X])
hidden_sizes = [128, 128, 128]
hidden_sizes_mult = [128, 128, 128]

# The score model is handed over through a lambda rather than as `s` itself --
# presumably so it is not registered as a submodule (and therefore not expected
# in these models' state_dicts); TODO confirm against the models module.
v_upfi_mult = models.MultiplicativeNoiseFlowGrowth(dim, lambda t, x, sigma : s(t, x, sigma), D, sigmas[-1], 
                                                  kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes_mult}, 
                                                  kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes_mult},
                                                  kwargs_g = {'time_dependent' : False, 'hidden_sizes' : hidden_sizes_mult[:1]}).to(device)
v_upfi_mult.load_state_dict(torch.load(f'weights/params_MULT_UPFI_ODEFlowGrowth_default_c_{c}_seed_{seed}_final.pt', map_location = device))

v_pfi_mult = models.MultiplicativeNoiseFlow(dim, lambda t, x, sigma : s(t, x, sigma), D, sigmas[-1], 
                                                  kwargs_u = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes_mult}, 
                                                  kwargs_v = {'output_activation' : torch.nn.Softplus, 'time_dependent' : False, 'hidden_sizes' : hidden_sizes_mult}).to(device)
v_pfi_mult.load_state_dict(torch.load(f'weights/params_MULT_PFI_VectorField_default_c_{c}_seed_{seed}_final.pt', map_location = device))

v_upfi = models.ODEFlowGrowth(d = dim, kwargs_v = {'hidden_sizes' : hidden_sizes, 'time_dependent' : False}, 
                                       kwargs_g = {'hidden_sizes' : hidden_sizes[:1], 'time_dependent' : False}).to(device)
v_upfi.load_state_dict(torch.load(f'weights/params_UPFI_ODEFlowGrowth_default_c_{c}_seed_{seed}_final.pt', map_location = device))

v_pfi = models.VectorField(d = dim, hidden_sizes = hidden_sizes, time_dependent = True).to(device)
v_pfi.load_state_dict(torch.load(f'weights/params_PFI_VectorField_default_c_{c}_seed_{seed}_final.pt', map_location = device))

import torchsde
import seaborn as sb

print("Vector field evaluation")
# Reference vector field `f` from the simulation file, row-aligned with data['x'];
# `g` is loaded too but not used in this section.
f_true = torch.tensor(data['f'], dtype = torch.float32)
g_true = torch.tensor(data['g'], dtype = torch.float32)
_x = torch.tensor(data['x'], dtype = torch.float32).to(device)


def _eval_at_observation_times(field):
    """Evaluate ``field(t, x)`` on every row of ``_x`` at that row's observation time.

    The output keeps the row order of ``data['x']``, so it lines up row-for-row with
    ``f_true`` and with the multiplicative-noise estimates computed on ``_x``. The
    previous per-time ``vstack`` only matched when the data were sorted by ``t_idx``.
    Rows whose ``t_idx`` is not in ``range(T)`` are left as NaN so they cannot
    silently pollute the metric. The result is returned on the CPU.
    """
    t_idx = torch.as_tensor(data['t_idx'], device = device)
    out = torch.full_like(_x, float("nan"))
    for i in range(T):
        rows = t_idx == i
        out[rows] = field(ts[i], _x[rows])
    return out.cpu()


with torch.no_grad():
    u_est_upfi_mult = v_upfi_mult.u.net(_x).cpu()
    v_est_upfi_mult = v_upfi_mult.v.net(_x).cpu()
    u_est_pfi_mult = v_pfi_mult.u.net(_x).cpu()
    v_est_pfi_mult = v_pfi_mult.v.net(_x).cpu()
    vf_pfi = _eval_at_observation_times(v_pfi)
    vf_upfi = _eval_at_observation_times(v_upfi.v_net)

# Mean L2 distance to f_true; for the multiplicative models the estimate compared is u - v.
pd.DataFrame({"PFI_additive" : utils.l2_dist(vf_pfi, f_true).mean().item(),
              "PFI_mult" : utils.l2_dist(u_est_pfi_mult - v_est_pfi_mult, f_true).mean().item(),
              "UPFI_additive" : utils.l2_dist(vf_upfi, f_true).mean().item(),
              "UPFI_mult" : utils.l2_dist(u_est_upfi_mult - v_est_upfi_mult, f_true).mean().item()
              }, index = [0]).to_csv(os.path.join('evals', f"df_vf_l2_dist_seed_{seed}.csv"))

# Sample paths
print("Sampling paths")
# 1024 initial states drawn from the snapshots. Column 0 of x0_mass is an extra
# coordinate (presumably the UPFI mass/weight -- it gets no noise below) that is
# stripped for every other model.
x0_mass = utils.sample_batch_upfi(X, m_ratios.to(device), batch_size = 1024, replacement=True, add_noise=True)[0]
x0 = x0_mass[..., 1:]
sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)
# Per-coordinate noise for UPFI: zero on column 0, sqrt(D) on each state dimension.
# Both pieces are float tensors, so torch.cat needs no int -> float promotion.
upfi_sigma = torch.cat([torch.zeros(1), torch.full((dim, ), D**0.5)]).to(device)
sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = upfi_sigma)
sde_pfi_mult = models.MultiplicativeNoiseSDE(v_pfi_mult.u, v_pfi_mult.v, sigma = D**0.5)
sde_upfi_mult = models.MultiplicativeNoiseSDE(v_upfi_mult.u, v_upfi_mult.v, sigma = D**0.5)

# torchsde.sdeint returns trajectories shaped (len(ts), batch, d). The call order is
# unchanged so the random draws consumed by each model stay the same.
with torch.no_grad():
    xs_t_pfi = torchsde.sdeint(sde_pfi, x0, ts, method = "euler").cpu()
    xs_t_upfi = torchsde.sdeint(sde_upfi, x0_mass, ts, method = "euler").cpu()[..., 1:]
    xs_t_pfi_mult = torchsde.sdeint(sde_pfi_mult, x0, ts, method = "euler").cpu()
    xs_t_upfi_mult = torchsde.sdeint(sde_upfi_mult, x0, ts, method = "euler").cpu()
# Flattened (len(ts) * batch, d) copies of each trajectory set (already on the CPU).
xs_t_pfi_ = xs_t_pfi.reshape(-1, xs_t_pfi.shape[-1])
xs_t_upfi_ = xs_t_upfi.reshape(-1, xs_t_upfi.shape[-1])
xs_t_pfi_mult_ = xs_t_pfi_mult.reshape(-1, xs_t_pfi_mult.shape[-1])
xs_t_upfi_mult_ = xs_t_upfi_mult.reshape(-1, xs_t_upfi_mult.shape[-1])

# Energy distance to the reference paths; trajectories are passed as (batch, time, d).
sampled_paths = {'UPFI_additive' : xs_t_upfi,
                 'UPFI_mult' : xs_t_upfi_mult,
                 'PFI_additive' : xs_t_pfi,
                 'PFI_mult' : xs_t_pfi_mult}
pd.DataFrame({name : evals.energy_distance_paths(paths.permute((1, 0, 2)).numpy(), data['x_paths'])
              for name, paths in sampled_paths.items()},
             index = [0, ]).to_csv(os.path.join('evals', f"df_energy_distance_paths_seed_{seed}.csv"))

# Fate probabilities
print("Computing fate probabilities")


def _propagate_to_final_time(sde, x, i, *, has_mass = False):
    """Return cells ``x`` (observed at time index ``i``) pushed forward to ``ts[-1]``.

    ``sde`` is integrated with torchsde's Euler scheme over ``ts[[i, -1]]`` under
    ``torch.no_grad()``: only the endpoint is needed, so no autograd graph is built.
    Cells at the last time index are returned unchanged. With ``has_mass`` set,
    ``x`` is first padded with ``utils.pad_zeros_upfi`` (presumably adding the
    UPFI mass coordinate -- confirm), and column 0 is dropped from the endpoint.
    The result is always on the CPU.
    """
    if i >= T - 1:
        return x.cpu()
    x_init = utils.pad_zeros_upfi(x) if has_mass else x
    with torch.no_grad():
        x_final = torchsde.sdeint(sde, x_init, ts[[i, -1]], method = "euler")[-1]
    if has_mass:
        x_final = x_final[..., 1:]
    return x_final.cpu()


# (name, sde, has_mass) in the original evaluation order, so the per-time-point
# sequence of random draws is unchanged.
fate_models = [("upfi", sde_upfi, True),
               ("upfi_mult", sde_upfi_mult, False),
               ("pfi", sde_pfi, False),
               ("pfi_mult", sde_pfi_mult, False)]
fate_probs = {name : [] for name, _, _ in fate_models}
for i in tqdm(range(T)):
    for name, sde, has_mass in fate_models:
        # Default arguments bind the current sde / i / has_mass into the callback.
        fate_probs[name].append(evals.get_centroid_probs(
            X[i],
            lambda x, sde = sde, i = i, has_mass = has_mass: _propagate_to_final_time(sde, x, i, has_mass = has_mass),
            data['centroids'],
        ))
probs_upfi = torch.vstack(fate_probs["upfi"])
probs_upfi_mult = torch.vstack(fate_probs["upfi_mult"])
probs_pfi = torch.vstack(fate_probs["pfi"])
probs_pfi_mult = torch.vstack(fate_probs["pfi_mult"])

print("Saving fate probabilities")
torch.save({'probs_upfi' : probs_upfi,
    'probs_upfi_mult' : probs_upfi_mult,
    'probs_pfi' : probs_pfi,
    'probs_pfi_mult' : probs_pfi_mult
            }, os.path.join('evals', f"fate_probs_seed_{seed}.pkl"))

probs_true = data['probs']
# Column order of both output tables (followed by "what").
fate_estimates = {"UPFI_additive" : probs_upfi,
                  "UPFI_mult" : probs_upfi_mult,
                  "PFI_additive" : probs_pfi,
                  "PFI_mult" : probs_pfi_mult}
if probs_true.shape[1] == 2: # use correlation measure
    # Two fates: correlate the probability of the first fate with the reference.
    correlations = [(sp.stats.pearsonr, 'pearson'), (sp.stats.spearmanr, 'spearman'), (sp.stats.kendalltau, 'kendall')]
    rows = [{**{name : corr(probs[:, 0], probs_true[:, 0]).statistic for name, probs in fate_estimates.items()},
             "what" : label}
            for corr, label in correlations]
    pd.DataFrame(rows).to_csv(os.path.join('evals', f"df_fate_pearsonr_seed_{seed}.csv"))
else:
    # NOTE(review): this is sum_k |p_k - q_k| averaged over cells, i.e. the L1
    # distance; if rows are probability vectors, the total-variation distance is half
    # of it. The value is kept as-is so results stay comparable with earlier runs.
    row = {**{name : (probs - probs_true).abs().sum(-1).mean().item() for name, probs in fate_estimates.items()},
           "what" : 'TV'}
    pd.DataFrame([row]).to_csv(os.path.join('evals', f"df_fate_tv_seed_{seed}.csv"))
