import torch
from torch import nn, optim
import autograd
import autograd.numpy as np
import torchdiffeq
from torchdiffeq import odeint
import geomloss
from tqdm import tqdm
import importlib
import math
import torchsde
import sys
sys.path.append("../../src/")
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
torch.set_default_dtype(torch.float32)

# The random seed is the first command-line argument; it seeds torch and the
# random generator of the module imported as ``np``.
seed = int(sys.argv[1])
torch.manual_seed(seed)
np.random.seed(seed)

# Experiment settings.
num_iter = 5_000  # used for both 'iters' and 'teacher_forcing_iter'
D = 0.25  # handed to the trainers under the key 'D'
beta = 5.5  # appears in the data file name
c = 0.5  # appears in the data and score-weight file names
T = 10  # appears in the data file name; also sets the Euler step size
N = 500  # appears in the data file name
suffix = f"default_seed_{seed}"  # tag used in checkpoint / weight file names
alpha_wfr = 1  # handed to train_upfi under the key 'alpha_wfr'
reg_wfr = 0.001  # handed to train_pfi as 'reg' and to train_upfi as 'reg_wfr'

# Load the dataset; its file name encodes N, T, c and beta.
data = torch.load(f"sim_HSC_N_{N}_T_{T}_c_{c}_beta_{beta}.pkl")
dim = data['x'].shape[1]

# One float32 tensor per distinct value of data['t_idx'], in ascending order,
# holding the rows of data['x'] that carry that time index.
_t_labels = np.sort(np.unique(data['t_idx']))
X = [
    torch.tensor(data['x'][data['t_idx'] == _label, :], device=device, dtype=torch.float32)
    for _label in _t_labels
]

# Evenly spaced time stamps on [0, t_final], one per entry of X.
ts = torch.linspace(0, data['t_final'], len(X), device=device)

import models, train
# Five noise levels: exp() of evenly spaced exponents from 0 down to -2.
sigmas = torch.linspace(0, -2, 5, device = device).exp()
# Time-dependent score network with three hidden layers of width 128.
s = models.NCScoreFunc(d = dim, hidden_sizes = [128, 128, 128], activation = torch.nn.ReLU, time_dependent = True).to(device)
# Restore the weights stored for this (c, seed) pair. map_location remaps the
# stored tensors onto `device`, so a checkpoint written on a GPU still loads
# when this script has fallen back to the CPU.
s.load_state_dict(torch.load(f"weights/params_NCScoreFunc_default_c_{c}_seed_{seed}_final.pt", map_location = device))

# Vector-field model (two hidden layers of width 64) and its Adam optimiser.
v_pfi = models.NGMVectorField(dim, hidden_sizes=[64, 64], GL_reg=0.03).to(device)
opt_pfi = optim.Adam(v_pfi.parameters(), lr=3e-3)

# Arguments for train.train_pfi, spelled out one group at a time.
_pfi_data = {'X': X, 't': ts}
_pfi_params = {'D': D, 'reg': reg_wfr, 'reg_ngm_l1': 1e-3, 'reg_ngm_l2': 0}
_pfi_options = {
    'iters': num_iter,
    'print_iter': 100,
    'reg_kind': 'vf',
    'checkpoint_iter': 1000,
    'checkpoint_file': f"params_PFI_NGMVectorField_{suffix}",
    'save_final': True,
    'save_file': f'params_PFI_NGMVectorField_{suffix}',
    'anneal_sigma_iters': None,
    'outdir': 'weights/',
    'teacher_forcing_iter': num_iter,
}
_pfi_sampling = {'replacement': True, 'add_noise': False}
_pfi_loss = {'loss': 'sinkhorn'}
# Fixed-step Euler integration with step size t_final / (2 * T).
_pfi_odeint = {'method': 'euler', 'options': {'step_size': data['t_final'] / (2 * T)}}

trace_pfi = train.train_pfi(
    v_pfi,
    opt_pfi,
    s,
    sigmas,
    _pfi_data,
    _pfi_params,
    256,  # presumably the batch size -- confirm against train.train_pfi
    options=_pfi_options,
    sample_batch_options=_pfi_sampling,
    samplesloss_options=_pfi_loss,
    odeint_options=_pfi_odeint,
)

# Number of samples in each snapshot relative to the first snapshot.
m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X]).float()
# Flow-plus-growth model: the flow part is an NGMVectorField with two hidden
# layers of width 64, the growth part uses one hidden layer of width 64.
v_upfi = models.ODEFlowGrowth(dim, v_mod = models.NGMVectorField, kwargs_v = {'hidden_sizes' : [64, 64], 'GL_reg' : 0.03}, kwargs_g = {'hidden_sizes' : [64, ]}).to(device)
opt_upfi = optim.Adam(v_upfi.parameters(), lr = 3e-3)
# m_ratios is already a tensor, so it is moved with .to(device) instead of
# being re-wrapped in torch.tensor(...), which warns about copy-constructing
# from a tensor and makes a redundant copy.
trace_upfi = train.train_upfi(v_upfi, opt_upfi, s, sigmas,
                              {'X' : X, 't' : ts, 'm_ratios' : m_ratios.to(device),},
                              {'D' : D, 'alpha_wfr' : alpha_wfr, 'reg_wfr' : reg_wfr, 'reg_ngm_l1' : 1e-3, 'reg_ngm_l2' : 0},
                              256,  # presumably the batch size -- confirm against train.train_upfi
                              options = {'iters' : num_iter, 'print_iter' : 100, 'reg_kind' : 'vf', 'checkpoint_iter' : 1000, 'checkpoint_file' : f"params_UPFI_ODEFlowGrowth_NGM_{suffix}",
                                       'save_final' : True, 'save_file' : f"params_UPFI_ODEFlowGrowth_NGM_{suffix}",
                                       'anneal_sigma_iters' : None, 'outdir' : 'weights/',
                                      'teacher_forcing_iter' : num_iter},
                              sample_batch_options = {'replacement' : True,  'add_noise' : False},
                              samplesloss_options={'loss' : 'sinkhorn', 'reach' : 5.0},
                              # Fixed-step Euler integration with step size t_final / (2 * T).
                              odeint_options={'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (2*T)}})

## Reference network 
import pandas as pd
def load_boolODE_reference_network(path, genes):
    """Build a signed gene-by-gene matrix from an edge-list CSV.

    The file is read with ``pandas.read_csv`` and its first three columns are
    used by position: two gene labels followed by a sign, ``"+"`` or ``"-"``.
    For each row, the entry at (row = second-column gene, column =
    first-column gene) is set to ``+1`` or ``-1``. All other entries are 0.
    NOTE(review): presumably the first column is the regulating gene and the
    second its target, as in BoolODE reference networks -- confirm.

    Args:
        path: Anything ``pandas.read_csv`` accepts (file path or buffer).
        genes: Gene labels used as both the row and column index.

    Returns:
        A ``pandas.DataFrame`` with ``genes`` as index and columns.

    Raises:
        KeyError: If a sign value is neither ``"+"`` nor ``"-"``.
    """
    edges = pd.read_csv(path)
    n_genes = len(genes)
    A_ref = pd.DataFrame(np.zeros((n_genes, n_genes), int), index=genes, columns=genes)
    # Built once rather than on every row.
    sign_value = {"+": 1, "-": -1}
    # Walk the three columns in lockstep instead of indexing row by row.
    for col_gene, row_gene, sign in zip(edges.iloc[:, 0], edges.iloc[:, 1], edges.iloc[:, 2]):
        A_ref.loc[row_gene, col_gene] = sign_value[sign]
    return A_ref
# The gene labels are the row index of the expression table.
df = pd.read_csv("../../data/HSC/ExpressionData.csv", index_col=0)
genes = df.index
A_ref = load_boolODE_reference_network("../../data/HSC/refNetwork.csv", genes)

# Reorder rows and columns of the reference matrix to this fixed gene order.
genes_reord = [
    'Gata1', 'Gata2', 'Fog1', 'Eklf', 'Fli1', 'Scl',
    'Cebpa', 'Pu1', 'cJun', 'EgrNab', 'Gfi1',
]
A_ref = A_ref.loc[genes_reord, genes_reord]

## NGM result 
def maskdiag(A):
    """Return ``A`` with its diagonal zeroed out.

    ``A`` is multiplied elementwise by a square mask of side ``A.shape[0]``
    that is 0 on the diagonal and 1 everywhere else.
    """
    off_diagonal = 1 - np.eye(A.shape[0])
    return A * off_diagonal
# All samples as one float32 tensor, plus a copy on the compute device.
_x = torch.tensor(data['x'], dtype=torch.float32)
_x_dev = _x.to(device)

# Transposed causal-graph matrices reported by the fitted networks.
A_pfi = v_pfi.net.net.causal_graph(w_threshold=0).T
A_upfi = v_upfi.v_net.net.net.causal_graph(w_threshold=0).T

# Jacobian of each vector field with respect to its input (the first call
# argument is passed as None), evaluated per sample with vmap, averaged over
# the samples and transposed.
_jac_pfi = torch.func.jacrev(lambda x: v_pfi(None, x))
_jac_upfi = torch.func.jacrev(lambda x: v_upfi.v_net(None, x))
A_pfi_jac = torch.vmap(_jac_pfi)(_x_dev).mean(0).detach().cpu().T
A_upfi_jac = torch.vmap(_jac_upfi)(_x_dev).mean(0).detach().cpu().T

# AUPRC result
import os
import sklearn as sk
from sklearn import metrics

res = {}
# Ground truth is the absolute value of the off-diagonal reference entries.
# It is the same for every method, so it is computed once outside the loop.
y = np.abs(maskdiag(A_ref).values).flatten()
for (A, what) in zip([A_pfi, A_upfi, A_pfi_jac, A_upfi_jac], ['pfi', 'upfi', 'pfi_jac', 'upfi_jac']):
    # Scores: absolute off-diagonal entries of the estimated matrix.
    yhat = maskdiag(np.abs(A)).flatten()
    # precision-recall
    prec, rec, thresh_pr = sk.metrics.precision_recall_curve(y, yhat)
    avg_prec = sk.metrics.average_precision_score(y, yhat)
    # ROC curve
    fpr, tpr, thresh_roc = sk.metrics.roc_curve(y, yhat)
    auc = sk.metrics.roc_auc_score(y, yhat)
    # The two curves return different threshold arrays. Previously the ROC
    # thresholds overwrote the precision-recall ones before saving. 'thresh'
    # keeps its old value (the ROC thresholds) for existing readers, and both
    # arrays are now stored under explicit keys.
    res[what] = {'prec' : prec, 'rec' : rec, 'thresh' : thresh_roc,
                 'thresh_pr' : thresh_pr, 'thresh_roc' : thresh_roc,
                 'avg_prec' : avg_prec, 'fpr' : fpr, 'tpr' : tpr, 'auroc' : auc}
# Make sure the output directory exists so the results are not lost at the
# very end of the run.
os.makedirs("evals", exist_ok=True)
torch.save(res, f"evals/auprc_results_ngm_seed_{seed}.pkl")
