import torch
import sys
sys.path.append("../../src")
import models, evals, utils
from torch import optim
import torchdiffeq
from torchdiffeq import odeint
import torchsde
import geomloss
from tqdm import tqdm
device = torch.device('cuda:0')

import sklearn as sk
import sklearn.decomposition
import sklearn.preprocessing
import numpy as np
import matplotlib.pyplot as plt
import importlib

import scanpy as sc
import anndata as ad
import pandas as pd
import scipy as sp

# ---- Run configuration and data loading ----
# Validate the command line before any expensive work, so that a missing
# argument fails immediately with a usage message rather than with an
# IndexError after the data file has already been read.
if len(sys.argv) < 2:
    raise SystemExit(f"usage: {sys.argv[0]} SEED")
# `seed` selects which trained weight files are loaded and tags the output
# file names; it is NOT used to seed the random number generators below.
seed = int(sys.argv[1])

# RNG state is fixed at 0 for every run, independent of `seed`.
np.random.seed(0)
torch.manual_seed(0)

print("Loading data")
# NOTE(review): torch.load unpickles arbitrary objects; only point this at a
# trusted file.
data = torch.load("data_pca.pkl")

# Same device as the one assigned next to the imports; restated here so the
# configuration section reads on its own.
device = torch.device('cuda:0')

T = 3       # number of time indices iterated over (values of data['t_idx'])
D = 0.25    # D**0.5 is used as the SDE noise scale further down
dim = 10    # passed as `d` to every model and embedded in weight file names
ts = torch.linspace(0, 1, T)  # T evenly spaced times in [0, 1]
# One tensor of cells per time index, moved to the compute device.
X = [data['x'][data['t_idx'] == i].to(device) for i in range(T)]

def _load_weights(model, tag):
    """Load the trained parameters identified by `tag` into `model`.

    The checkpoint path is built from `tag` together with the module-level
    `dim` and `seed`. `map_location=device` makes the checkpoint tensors
    deserialize onto the device used by this script, so weights saved from a
    different device can still be loaded.

    Returns the same `model` object, to allow construction and loading in a
    single expression.
    """
    path = f"weights/params_{tag}_default_additive_pcadim_{dim}_seed_{seed}_final.pt"
    model.load_state_dict(torch.load(path, map_location=device))
    return model


print("Loading score model")
s = _load_weights(
    models.NCScoreFunc(d=dim, hidden_sizes=[128, 128, 128],
                       activation=torch.nn.ReLU, time_dependent=True).to(device),
    "NCScoreFunc")
# Five values exp(0), ..., exp(-2), evenly spaced in the exponent.
sigmas = torch.linspace(0, -2, 5, device=device).exp()

print("Loading dyn models")
hidden_sizes = [128, 128, 128]
# UPFI: velocity and growth networks are both time-independent.
v_upfi = _load_weights(
    models.ODEFlowGrowth(d=dim,
                         kwargs_v={'hidden_sizes': hidden_sizes, 'time_dependent': False},
                         kwargs_g={'hidden_sizes': hidden_sizes, 'time_dependent': False}).to(device),
    "UPFI_ODEFlowGrowth")
# PFI: a single time-dependent vector field.
v_pfi = _load_weights(
    models.VectorField(d=dim, hidden_sizes=hidden_sizes, time_dependent=True).to(device),
    "PFI_VectorField")
# ODE: coupled flow/growth model, time-dependent.
v_ode = _load_weights(
    models.ODEFlowGrowthCoupled(d=dim, hidden_sizes=hidden_sizes, time_dependent=True).to(device),
    "ODE_ODEFlowGrowthCoupled")
# TIGON: same architecture class as UPFI but with time-dependent networks.
v_tigon = _load_weights(
    models.ODEFlowGrowth(d=dim,
                         kwargs_v={'hidden_sizes': hidden_sizes, 'time_dependent': True},
                         kwargs_g={'hidden_sizes': hidden_sizes, 'time_dependent': True}).to(device),
    "TIGON_ODEFlowGrowth")

## Get vector field reconstruction
print("Evaluating learned vector fields")


def _eval_in_data_order(field):
    """Evaluate `field(t, x)` per snapshot and return rows in `data['x']` order.

    Each snapshot `X[i]` is evaluated at its own time `ts[i]`. The outputs are
    then written back to the row positions where `data['t_idx'] == i`, so the
    result can be indexed with masks built from `data`. Simply stacking the
    per-snapshot outputs would only line up with such masks when `data` is
    already sorted by `t_idx`; for sorted data both approaches coincide.

    Rows whose `t_idx` is outside `range(T)` are left as NaN.
    """
    n_cells = data['x'].shape[0]
    out = None
    for i in range(T):
        v_i = field(ts[i], X[i]).detach().cpu()
        if out is None:
            out = torch.full((n_cells,) + tuple(v_i.shape[1:]), float('nan'), dtype=v_i.dtype)
        # Normalise the mask to a CPU bool tensor, whatever container
        # data['t_idx'] happens to be stored in.
        mask = torch.as_tensor(np.asarray(data['t_idx'] == i), dtype=torch.bool)
        out[mask] = v_i
    return out


# UPFI's velocity net is time-independent, so it is evaluated on all cells at
# once with no time argument; its rows are already in data order.
_v_upfi = v_upfi.v_net(None, data['x'].to(device)).detach().cpu()
_v_pfi = _eval_in_data_order(v_pfi)
_v_tigon = _eval_in_data_order(v_tigon.v_net)
_v_ode = _eval_in_data_order(v_ode.dF)

# Cosine distance between data['v'] and each learned field, per cell,
# grouped by cell type and time index.
# NOTE(review): data['v'] is presumably the reference (RNA velocity) field -- confirm.
_fields = {"UPFI": _v_upfi, "PFI": _v_pfi, "ODE": _v_ode, "TIGON": _v_tigon}
_dfs = []
for c in ["Undifferentiated", "Monocyte", "Neutrophil"]:
    for t in range(T):
        idx = (data["celltype"] == c) & (data['t_idx'] == t)
        _cols = {name: utils.cos_dist(data['v'][idx], v[idx]) for name, v in _fields.items()}
        _cols["celltype"] = c
        _cols["t"] = t
        _dfs.append(pd.DataFrame(_cols))
# Long format: one row per (cell, method), then average within each
# (t, celltype, method) group.
_df = pd.concat(_dfs).melt(id_vars=['t', 'celltype'], value_vars=['UPFI', 'PFI', "ODE", "TIGON"])
_df_mean = _df.groupby(['t', 'celltype', 'variable']).mean()
_df_mean.to_csv(f"df_rnavelo_cos_mean_seed_{seed}.csv")

## Get fate probabilities
print("Calculating fate probabilities")
# Fixed-step Euler for the deterministic models.
odeint_options = {'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (5*T)}}
# Cell count of each snapshot relative to the first one.
m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X]).float()
# NOTE(review): `_ts`, `x0_mass` and `x0` are not used below. They are kept
# because the sampling call may advance RNG state that the SDE simulations
# depend on -- confirm before removing.
_ts = torch.linspace(0, 1, 25)
x0_mass = utils.sample_batch_upfi(X, m_ratios.to(device), 1024)[0]
x0 = x0_mass[..., 1:]
sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)
# UPFI state has one extra leading coordinate, which gets zero noise; the
# remaining `dim` coordinates use the same noise scale as PFI.
sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device))

# Mean position of Neutrophil and Monocyte cells (in that row order) at
# time index 2.
_centroids = torch.vstack([data['x'][(data['celltype'] == "Neutrophil") & (data['t_idx'] == 2), :].mean(0),
    data['x'][(data['celltype'] == "Monocyte") & (data['t_idx'] == 2), :].mean(0)])
# Moved to the device once instead of on every call inside the loop.
_centroids_dev = _centroids.to(device)

bs = 250
# Batches of at most `bs` cells from the first snapshot. torch.split never
# yields an empty trailing chunk when the cell count is an exact multiple of
# `bs`, unlike slicing over `range(N // bs + 1)`.
X_pca = torch.split(X[0], bs)


def _endpoint_upfi(x):
    """Simulate the UPFI SDE from `x`; return the final state without its first coordinate."""
    return torchsde.sdeint(sde_upfi, utils.pad_zeros_upfi(x), ts, method = "euler")[-1, ..., 1:]


def _endpoint_pfi(x):
    """Simulate the PFI SDE from `x` and return the final state."""
    return torchsde.sdeint(sde_pfi, x, ts, method = "euler")[-1, ...]


def _endpoint_ode(x):
    """Integrate the coupled ODE model's field from `x` and return the final state."""
    return odeint(v_ode.dF, x, ts, **odeint_options)[-1, ...]


def _endpoint_tigon(x):
    """Integrate the TIGON velocity net from `x` and return the final state."""
    return odeint(v_tigon.v_net, x, ts, **odeint_options)[-1, ...]


probs_upfi, probs_pfi, probs_ode, probs_tigon = [], [], [], []
for _x_batch in tqdm(X_pca):
    # Stochastic models: 25 samples per cell; deterministic models: 1.
    probs_upfi.append(evals.get_centroid_probs(_x_batch, _endpoint_upfi, _centroids_dev, n_sample = 25))
    probs_pfi.append(evals.get_centroid_probs(_x_batch, _endpoint_pfi, _centroids_dev, n_sample = 25))
    probs_ode.append(evals.get_centroid_probs(_x_batch, _endpoint_ode, _centroids_dev, n_sample = 1))
    probs_tigon.append(evals.get_centroid_probs(_x_batch, _endpoint_tigon, _centroids_dev, n_sample = 1))
# Stack the per-batch results back into one tensor per method.
probs_upfi = torch.vstack(probs_upfi)
probs_pfi = torch.vstack(probs_pfi)
probs_ode = torch.vstack(probs_ode)
probs_tigon = torch.vstack(probs_tigon)

# Calculate lineage tracing fates: for every clone, the fraction of its
# later-timepoint cells that are Monocyte / Neutrophil, then attach those
# fractions to the clone's cells.
# NOTE(review): the code below accesses `df.cloneid`, so data['cloneid'] is
# presumably a Series/frame whose (only) column is named "cloneid" and whose
# index holds the cell ids found in data['id'] -- confirm.
df = pd.DataFrame(data['cloneid'])
df['celltype'] = data['celltype']
df['t'] = data['t_idx']
# Cells after the first time index that carry a clone label
# (cloneid > -1; presumably -1 marks "no clone" -- confirm).
df_subset = df.loc[(df.t > 0) & (df.cloneid > -1), :].copy()

# Indicator columns, initialised to 0.
df_subset.loc[:, "undiff"] = 0; df_subset.loc[:, "mon"] = 0; df_subset.loc[:, "neu"] = 0
# The "undiff" indicator is deliberately left at 0 (assignment disabled), so
# Undifferentiated cells do not contribute to the clone fate fractions:
# df_subset.loc[df_subset.celltype == "Undifferentiated", "undiff"] = 1
df_subset.loc[df_subset.celltype == "Monocyte", "mon"] = 1
df_subset.loc[df_subset.celltype == "Neutrophil", "neu"] = 1
# Keep rows with at least one non-zero value in the columns from position 3
# onward. NOTE(review): this assumes those columns are exactly the three
# indicators, i.e. `df` had three columns before they were added.
df_subset = df_subset.iloc[(df_subset.iloc[:, 3:].values.sum(1) > 0), :]
# Per-clone counts of each indicator ...
df_clone = df_subset.loc[:, ['mon', 'neu', 'undiff', 'cloneid'],].groupby('cloneid').sum()
# ... normalised so every clone's row sums to 1.
df_clone = df_clone / df_clone.values.sum(-1)[:, None]
# Same column order as selected above: mon, neu, undiff.
df_clone.columns = ['pmon', 'pneu', 'pundiff']

# SPRING coordinates indexed by cell id.
df_spring = pd.DataFrame(data['x_spring'], index = data['id'], columns = ["SPRINGx", "SPRINGy"])
# Attach each later-timepoint cell's clone fractions and SPRING coordinates.
# The clone rows are looked up by cloneid and then re-indexed to the cell
# index so that the column-wise concat aligns row by row.
df_subset = pd.concat([df_subset, df_clone.loc[df_subset.cloneid].reset_index().set_index(df_subset.index), df_spring.loc[df_subset.index]], axis = 1)

# First-timepoint cells with a clone label, temporarily indexed by cloneid.
# reset_index() moves the original cell index into a column; the later
# set_index('index') relies on that column being named "index".
df_subset_ = df.loc[(df.t == 0) & (df.cloneid > -1), :].copy().reset_index().set_index('cloneid')
# Keep only cells whose clone has fate fractions computed above.
df_subset_ = df_subset_[df_subset_.index.isin(df_clone.index)]
# Attach the clone's fate fractions, then restore the cell index.
df_subset_ = pd.concat([df_subset_, df_clone.loc[df_subset_.index]], axis = 1).reset_index()
df_subset_ = df_subset_.set_index('index')
df_subset_ = pd.concat([df_subset_, df_spring.loc[df_subset_.index]], axis = 1)
# Ids of all cells at the first time index; used as the row index of the
# model fate-probability tables.
_idx=pd.Series(data['id'])[data['t_idx'] == 0]
# Column labels of the fate-probability tables. The order (neu, mon) matches
# the row order of `_centroids`; presumably evals.get_centroid_probs returns
# its columns in centroid order -- confirm.
_col = ["neu", "mon"]

print("Saving fate probabilities")

# Raw probability tensors for all four methods, together with the cell ids.
_raw_fate_probs = {
    'probs_upfi' : probs_upfi,
    'probs_pfi' : probs_pfi,
    'probs_ode' : probs_ode,
    'probs_tigon' : probs_tigon,
    'id' : data['id'],
}
torch.save(_raw_fate_probs, f"fate_probs_seed_{seed}.pkl")

# One (neu, mon) table per method, placed side by side under a top-level
# column key naming the method.
_probs_by_method = {"PFI" : probs_pfi, "UPFI" : probs_upfi, "ODE" : probs_ode, "TIGON" : probs_tigon}
_method_frames = [pd.DataFrame(_p.cpu(), index = _idx, columns = _col) for _p in _probs_by_method.values()]
df_probs = pd.concat(_method_frames, keys = list(_probs_by_method), axis = 1)
# Restrict to the first-timepoint cells that have lineage-tracing fates.
df_probs = df_probs.loc[df_subset_.index, :]
df = pd.concat([df_probs, df_subset_], axis = 1)


def _flat_column_name(col):
    """Join tuple column labels with '_' (dropping trailing '_'); pass others through."""
    if isinstance(col, tuple):
        return '_'.join(col).rstrip('_')
    return col


df.columns = [_flat_column_name(col) for col in df.columns.values]
df.to_csv(f"df_fate_seed_{seed}.csv")


def _neu_correlation_row(stat_fn, label):
    """Correlate each method's neutrophil probability with `df.pneu` using `stat_fn`."""
    neu_columns = {"PFI" : "PFI_neu", "UPFI" : "UPFI_neu", "ODE" : "ODE_neu", "TIGON" : "TIGON_neu"}
    row = {method : stat_fn(df.loc[:, column], df.pneu).statistic for method, column in neu_columns.items()}
    row["what"] = label
    return row


# One row per correlation measure; the file name says "pearsonr" but the
# table holds Pearson, Spearman and Kendall statistics.
_correlations = [
    (sp.stats.pearsonr, 'pearson'),
    (sp.stats.spearmanr, 'spearman'),
    (sp.stats.kendalltau, 'kendall'),
]
pd.DataFrame([_neu_correlation_row(_fn, _name) for (_fn, _name) in _correlations]).to_csv(f"df_fate_pearsonr_seed_{seed}.csv")

