import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import csv
import scanpy as sc
import h5py
import argparse
import time

import sys
sys.path.append('{}/tools/GVAR'.format(base_dir))
from training import *

from statsmodels.tsa.vector_ar.var_model import VAR
from statsmodels.tsa.stattools import grangercausalitytests
from statsmodels.tsa.stattools import adfuller
from scipy.stats import ranksums
from scipy.sparse import csr_matrix

from nn_models import *
from nn_training import *
from nn_utils import *

def run_GVAR(data):
	"""Run the GVAR training procedure on ``data`` with fixed hyperparameters.

	``data`` is forwarded unchanged to ``training_procedure`` (star-imported
	from the project's ``training`` module). The first element of its return
	value is treated as a 2-D array and the last column, without its final
	entry, is returned; the second element is discarded.

	NOTE(review): presumably the returned slice holds the estimated influence
	of every other series on the last one -- confirm against the GVAR code.
	"""
	hyperparams = dict(
		order=5,              # max lag order
		hidden_layer_size=8,  # hidden layer size
		lmbd=1,               # sparsity-inducing penalty
		gamma=1,              # time smoothing penalty
		num_hidden_layers=2,
	)

	estimate, _ = training_procedure(data,end_epoch=5,batch_size=512,
		verbose=False,**hyperparams)

	# last column of the estimate, excluding its final row
	return estimate[:-1,-1]

def run_granger(data,maxlag=10,stat_test='lrtest'):
	"""Granger-causality p-value at the AIC-optimal lag.

	Parameters
	----------
	data : array-like with two columns
		Passed directly to ``statsmodels.tsa.stattools.grangercausalitytests``,
		which tests whether the second column Granger-causes the first.
	maxlag : int, default 10
		Largest lag to fit; every lag in ``1..maxlag`` is a candidate.
	stat_test : str, default 'lrtest'
		Key of the statsmodels test whose p-value is reported.

	Returns
	-------
	float
		p-value of ``stat_test`` at the lag with the smallest AIC.
	"""
	results = grangercausalitytests(data,maxlag=maxlag,verbose=False)

	# statsmodels fits every lag in 1..maxlag, so the search must include
	# maxlag itself (the previous range(1,maxlag) silently dropped it).
	lags = np.arange(1,maxlag+1)

	# results[lag][1][1] is the OLS fit of the unrestricted model. The penalty
	# term 4*lag counts two coefficients per lag; constant offsets are omitted
	# because they do not change the argmin.
	aic_by_lag = [4*lag-2*results[lag][1][1].llf for lag in lags]
	opt_lag = lags[np.argmin(aic_by_lag)]

	# each test entry is a tuple whose second element is the p-value
	return results[opt_lag][0][stat_test][1]

def grouped_corr(data):
	"""Spearman correlation of every column of ``data`` with its last column.

	Parameters
	----------
	data : 2-D array-like
		Observations in rows, variables in columns (at least two columns).

	Returns
	-------
	numpy.ndarray
		1-D array with one correlation per column except the last, in column
		order (length ``n_columns - 1``).
	"""
	from scipy.stats import spearmanr

	rho,_ = spearmanr(data)
	rho = np.asarray(rho)

	# With exactly two columns scipy returns a single scalar rather than a
	# correlation matrix; wrap it so callers always get a 1-D array.
	if rho.ndim == 0:
		return rho.reshape(1)

	return rho[:-1,-1]

if __name__ == '__main__':
	# Command-line driver. Behaviour is selected by substrings of --method:
	#   'baseline' -> pairwise Granger-causality tests on pseudotime-ordered cells
	#   'graph'    -> train/evaluate GraphGrangerModule on transition matrices
	# Further substrings ('sketch', 'rna_counts', 'max_scale', 'binarize',
	# 'bin', 'all', 'hv_genes', 'mseloss', 'trial1'..'trial4') toggle the
	# options handled below.

	parser = argparse.ArgumentParser()
	parser.add_argument('-m','--method',dest='method',type=str,default='baseline')
	parser.add_argument('-l','--lmbd',dest='lmbd',type=float,default=1)
	# The default must be a string: args.dataset is used in substring checks
	# ('citation' in args.dataset), which raise TypeError on an int.
	parser.add_argument('-d','--dataset',dest='dataset',type=str,default='1')
	parser.add_argument('-nn','--n_neighbors',dest='n_neighbors',type=int,default=15)
	parser.add_argument('-nl','--n_layers',dest='n_layers',type=int,default=5)

	args = parser.parse_args()

	print('Loading data...')
	print(args.method)

	# NOTE(review): base_dir is not defined in this file's visible source.
	data_dir = '{}/datasets/{}'.format(base_dir,args.dataset)
	save_dir = '{}/results/tests_nn/{}'.format(base_dir,args.dataset)

	# 'sketch' methods read the *.sketch.h5ad files instead of the full ones
	h5ad_suffix = 'sketch.h5ad' if 'sketch' in args.method else 'h5ad'
	atac_adata = sc.read(os.path.join(data_dir,
		'{}.atac.{}'.format(args.dataset,h5ad_suffix)))
	rna_adata = sc.read(os.path.join(data_dir,
		'{}.rna.{}'.format(args.dataset,h5ad_suffix)))

	# index -> gene name lookup (not used further in this script)
	idx_gene_dict = {i:g for i,g in enumerate(rna_adata.var.index.values)}

	if 'rna_counts' in args.method:
		# exp(x)-1 followed by per-cell rescaling to the median of obs['n_counts']
		# (presumably rna_adata.X is stored log1p-transformed -- TODO confirm)
		X = np.exp(rna_adata.X.toarray())-1
		n_counts = X.sum(1)
		X = (X.T/n_counts*np.median(rna_adata.obs['n_counts'])).T
		rna_adata.X = csr_matrix(X)
	elif 'max_scale' in args.method:
		# divide each gene by its maximum; all-zero genes are left untouched
		X_max = rna_adata.X.max(0).toarray().squeeze()
		X_max[X_max == 0] = 1
		rna_adata.X = csr_matrix(rna_adata.X / X_max)

	# 'citation' datasets skip ATAC preprocessing entirely
	if 'citation' not in args.dataset:
		if 'binarize' in args.method:
			atac_adata.X = atac_adata.X.astype(bool).astype(float)
		else:
			sc.pp.normalize_total(atac_adata,target_sum=1e4)
			sc.pp.log1p(atac_adata)

	torch.cuda.empty_cache()

	if 'baseline' in args.method:

		# create the output directory up front so the final save cannot fail
		# on a missing path after the whole loop has run
		os.makedirs(save_dir,exist_ok=True)

		# ndmin=1 keeps a single-pair index file as a 1-D array
		atac_idx = np.loadtxt(os.path.join(data_dir,'atac_idx.txt'),dtype=int,ndmin=1)
		rna_idx = np.loadtxt(os.path.join(data_dir,'rna_idx.txt'),dtype=int,ndmin=1)

		print('{} examples'.format(len(rna_idx)))

		# NOTE(review): 'bin' is also a substring of 'binarize', so binarized
		# baseline runs are binned as well -- confirm this is intended.
		use_bins = 'bin' in args.method
		if use_bins:
			print('binning...')

		# order cells by diffusion pseudotime
		sorted_inds = np.argsort(rna_adata.obs['dpt_pseudotime'].values)

		# about 100 bins; clamp to 1 so fewer than 100 cells does not give a
		# zero step, which range() rejects
		bin_size = max(1,len(sorted_inds)//100)

		# GVAR scoring via run_GVAR is currently disabled; only Granger
		# p-values are computed and saved.
		granger = []
		for gene_ind,atac_ind in zip(rna_idx,atac_idx):
			# two columns per pair: gene expression first, peak accessibility second
			atac_X = atac_adata.X[:,atac_ind].toarray()
			data = np.concatenate([rna_adata.X[:,[gene_ind]].T.toarray(),atac_X.T]).T
			data = data[sorted_inds]

			# bin data: average consecutive windows of bin_size cells
			if use_bins:
				data = np.array([data[i:i+bin_size].mean(0)
					for i in range(0,data.shape[0],bin_size)])

			granger.append(run_granger(data))

			# progress as the fraction of pairs processed
			if len(granger) % 1000 == 0:
				print(len(granger)/len(atac_idx))

		granger = np.array(granger)

		# NOTE(review): the file name contains '.bin.' even when no binning
		# was applied; kept as-is so downstream readers keep working.
		np.savetxt(os.path.join(save_dir,'{}.bin.results.txt'.format(args.method)),granger)


	elif 'graph' in args.method:

		device = "cuda" if torch.cuda.is_available() else "cpu"

		# create the output directory before training so the results and
		# weights can always be written afterwards
		os.makedirs(save_dir,exist_ok=True)

		# choose which peak-gene index files to load
		if '-5000-' in args.dataset or 'all' in args.method:
			idx_prefix = ''
		elif 'hv_genes' in args.method:
			idx_prefix = 'hv_'
		else:
			idx_prefix = 'sample_'

		atac_idx = np.loadtxt(os.path.join(data_dir,
			'{}atac_idx.txt'.format(idx_prefix)),dtype=int,ndmin=1)
		rna_idx = np.loadtxt(os.path.join(data_dir,
			'{}rna_idx.txt'.format(idx_prefix)),dtype=int,ndmin=1)

		# limit to rna, atac indices used: keep only the referenced columns
		# (sorted, unique) and remap each pair index to its position in the
		# reduced matrices
		sorted_atac_idx,atac_idx = np.unique(atac_idx,return_inverse=True)
		sorted_rna_idx,rna_idx = np.unique(rna_idx,return_inverse=True)

		atac_X = atac_adata.X[:,sorted_atac_idx].toarray()
		rna_X = rna_adata.X[:,sorted_rna_idx].toarray()

		print(atac_X.shape,rna_X.shape,rna_X.max(),atac_X.max())

		print('KNN + Pseudotime...')
		if 'citation' in args.dataset:
			from scipy.io import mmread
			# transposed adjacency matrix, each column scaled by its sum
			# (zero sums are replaced by 1 to avoid division by zero)
			S_0 = mmread(os.path.join(data_dir,'citation.adj_mat.mtx')).T.toarray()
			S_0_sum = S_0.sum(0)
			S_0_sum[S_0_sum == 0] = 1
			S_0 = S_0/S_0_sum
			S_1 = S_0.copy()
		else:
			# construct_transition_matrices is a project helper
			# (NOTE(review): presumably builds KNN/pseudotime-based
			# transition matrices -- confirm in nn_utils)
			dpt = '-5000-' not in args.dataset
			S_0,S_1 = construct_transition_matrices(rna_adata,args.n_neighbors,backward=True,dpt=dpt)

		num_nodes = rna_X.shape[0]
		n_peaks = atac_X.shape[1]
		n_genes = rna_X.shape[1]
		num_hidden_layers = args.n_layers

		S_0 = torch.FloatTensor(S_0).to(device)
		S_1 = torch.FloatTensor(S_1).to(device)

		if args.dataset == 'citation':
			final_activation = 'linear'
		elif args.dataset == 'citation2':
			final_activation = 'sigmoid'
		else:
			final_activation = 'exp'

		# 'trialK' in the method name (K = 1..4) fixes the torch seed to K;
		# the first match wins
		for seed in range(1,5):
			if 'trial{}'.format(seed) in args.method:
				print('seed {}'.format(seed))
				torch.manual_seed(seed)
				break

		model = GraphGrangerModule(num_nodes,num_hidden_layers,atac_idx,rna_idx,
			final_activation=final_activation)
		model.to(device)

		batch_size = 1024

		initial_learning_rate = 0.001
		beta_1 = 0.9
		beta_2 = 0.999
		lmbd = args.lmbd
		max_epochs = 20 # if 'citation' not in args.dataset else 200

		# train_prop=1. is passed to the project's train_test_split
		# (presumably puts every pair in the training set -- TODO confirm)
		train_idx,test_idx = train_test_split(len(rna_idx),train_prop=1.)

		optimizer = torch.optim.Adam(params=model.parameters(),
			lr=initial_learning_rate, betas=(beta_1, beta_2))

		if 'mseloss' in args.method:
			criterion = nn.MSELoss()
		elif args.dataset == 'citation2':
			criterion = nn.BCELoss()
		elif args.dataset == 'citation':
			criterion = nn.MSELoss()
		else:
			criterion = nn.PoissonNLLLoss(log_input=False,full=True) #nn.MSELoss()

		start = time.time()
		print('Training...{} examples'.format(len(rna_idx)))
		train_model(model,args.method,rna_X,atac_X,rna_idx,atac_idx,
			train_idx,test_idx,optimizer,lmbd,device,max_epochs,batch_size,
			criterion=criterion,early_stop=True,tol=0.1/len(rna_idx),verbose=True,
			S_0=S_0,S_1=S_1)

		print('Testing...')
		# inference pass over every peak-gene pair with statistics enabled
		_,results_dict = run_epoch_graph('Inference',batch_size,
			rna_X,atac_X,rna_idx,atac_idx,np.arange(len(atac_idx)),S_0,S_1,model,
			optimizer,lmbd,device,criterion=criterion,verbose=True,train=False,statistics=True)

		# one text file per entry of the returned results dictionary
		for k,result in results_dict.items():
			np.savetxt(os.path.join(save_dir,'{}.{}layers.nn{}.{}.p.txt'.format(
				args.method,args.n_layers,args.n_neighbors,k)),result)

		# save model (moved to CPU first)
		model.to('cpu')
		torch.save(model.state_dict(),os.path.join(save_dir,'{}.{}layers.nn{}.model_weights.pth'.format(
			args.method,args.n_layers,args.n_neighbors)))

		print('Total Time: {} seconds'.format(time.time()-start))

