import torch
import torch.nn as nn
import numpy as np
import time
from dnaseq import *
from nn_utils import *

from scipy.stats import ranksums

def _relative_decrease(previous, current):
	"""Return the fractional drop from ``previous`` to ``current``.

	Positive values mean the loss went down. A ``previous`` of exactly zero
	yields 0.0 instead of dividing by zero (there is nothing left to improve).
	"""
	if previous == 0:
		return 0.0
	return (previous - current) / abs(previous)


def train_model(model,method,rna_X,atac_X,rna_idx,atac_idx,\
				train_idx,test_idx,\
				optimizer,lmbd,device,num_epochs,batch_size,\
				criterion=nn.MSELoss(),early_stop=True,tol=0.01,verbose=True,\
				positions=None,dnaseq=None,S_0=None,S_1=None,eval_start=18):
	"""Train ``model`` epoch by epoch with ``run_epoch_graph`` and optional early stopping.

	Every epoch runs one training pass over the ``train_idx`` pairs. From epoch
	``eval_start`` onward, each epoch also runs an evaluation pass over the
	``test_idx`` pairs and records that loss. Once three evaluation losses exist,
	training stops if the last two relative decreases are both below ``tol``.

	Args:
		model: model invoked by ``run_epoch_graph``.
		method: training method name; only names containing ``'graph'`` are supported.
		rna_X, atac_X: data matrices, column-indexed by ``run_epoch_graph``.
		rna_idx, atac_idx: per-pair gene / peak column indices.
		train_idx, test_idx: positions into the pair lists used for training and
			evaluation respectively.
		optimizer, lmbd, device, batch_size, criterion: forwarded to ``run_epoch_graph``.
		num_epochs: maximum number of epochs.
		early_stop: enable the early-stopping rule described above.
		tol: relative-improvement threshold for early stopping.
		verbose: forwarded to ``run_epoch_graph`` to control per-epoch printing.
		positions, dnaseq: accepted for interface compatibility; unused here.
		S_0, S_1: forwarded to the model through ``run_epoch_graph``.
		eval_start: first (0-based) epoch at which the test loss is computed.

	Raises:
		ValueError: if ``method`` does not contain ``'graph'``.
	"""
	if 'graph' not in method:
		# No other epoch routine is wired up here; previously a non-graph method
		# silently skipped training and later crashed with a NameError.
		raise ValueError("train_model only supports 'graph' methods, got {!r}".format(method))

	# run_epoch_graph slices these and passes them to torch.from_numpy, so they
	# must be ndarrays.
	train_idx = np.asarray(train_idx)
	test_idx = np.asarray(test_idx)
	train_atac_idx = np.asarray(atac_idx)[train_idx]
	train_rna_idx = np.asarray(rna_idx)[train_idx]
	test_atac_idx = np.asarray(atac_idx)[test_idx]
	test_rna_idx = np.asarray(rna_idx)[test_idx]

	test_loss_list = []

	# NOTE: anomaly detection is a global autograd debugging switch that slows
	# training noticeably; kept to preserve existing behavior.
	torch.autograd.set_detect_anomaly(True)

	for epoch_no in range(num_epochs):
		run_epoch_graph(epoch_no,batch_size,rna_X,atac_X,
			train_rna_idx,train_atac_idx,train_idx,S_0,S_1,model,
			optimizer,lmbd,device,criterion=criterion,verbose=verbose,train=True)

		if epoch_no < eval_start:
			continue

		# Evaluate on the held-out pairs (this previously re-used the training pairs).
		test_loss,_ = run_epoch_graph(epoch_no,batch_size,rna_X,atac_X,
			test_rna_idx,test_atac_idx,test_idx,S_0,S_1,model,
			optimizer,lmbd,device,criterion=criterion,verbose=verbose,train=False)
		test_loss_list.append(test_loss)

		# Two consecutive changes need three recorded losses.
		if early_stop and len(test_loss_list) >= 3:
			change_1 = _relative_decrease(test_loss_list[-2], test_loss_list[-1])
			change_2 = _relative_decrease(test_loss_list[-3], test_loss_list[-2])
			if change_1 < tol and change_2 < tol:
				break

def run_epoch(epoch_no,batch_size,rna_X,atac_X,rna_idx,atac_idx,model,
			  optimizer,lmbd,device,criterion=nn.MSELoss(),verbose=True,train=True,
			  positions=None,dnaseq=None):
	"""Run one pass of the coefficient model over the (gene, peak) pairs.

	Pairs are processed in consecutive chunks of ``batch_size``. For each chunk
	the selected columns of ``rna_X`` / ``atac_X`` are transposed, given an extra
	axis at position 1, and fed to ``model``. Extra model inputs are added when
	available: ``positions`` (sliced per batch) or one-hot DNA sequences built from
	``dnaseq['peak']`` / ``dnaseq['gene']``. The model returns ``(preds, coeffs)``.
	Targets come from ``create_outcome_time_series``.

	Loss per batch is ``criterion(preds, targets) + lmbd * L1 + (100 * lmbd) * smooth``.
	Here ``L1`` is the mean L1 norm of ``coeffs`` over dim 1, and ``smooth`` is the
	mean L2 norm over dim 1 of the differences between neighboring ``coeffs[:, k, :]``.

	When ``train`` is True the optimizer steps after each batch. Otherwise gradients
	are not tracked and the per-batch coefficients are collected.

	Returns:
		The summed loss when ``train`` is True, otherwise ``(summed_loss, coeffs)``.
		``coeffs`` is the per-batch coefficient arrays concatenated along axis 0.
	"""
	incurred_loss = 0
	incurred_penalty = 0
	incurred_smoothness_penalty = 0
	coeffs_list = []

	if dnaseq is not None:
		# Encode every sequence once so each batch only needs a slice.
		peak_dna_ohe = np.array([one_hot_encode(dnaseq['peak'][idx]) for idx in atac_idx])
		gene_dna_ohe = np.array([one_hot_encode(dnaseq['gene'][idx]) for idx in rna_idx])
		peak_dna_ohe = torch.from_numpy(peak_dna_ohe).float().to(device)
		gene_dna_ohe = torch.from_numpy(gene_dna_ohe).float().to(device)

	start = time.time()
	# No autograd graph is needed when only evaluating.
	with torch.set_grad_enabled(train):
		for batch_start in range(0, len(rna_idx), batch_size):
			batch = slice(batch_start, batch_start + batch_size)

			batch_rna_x = np.expand_dims(rna_X[:,rna_idx[batch]].T, axis=1)
			batch_atac_x = np.expand_dims(atac_X[:,atac_idx[batch]].T, axis=1)
			batch_rna_x = torch.from_numpy(batch_rna_x).float().to(device)
			batch_atac_x = torch.from_numpy(batch_atac_x).float().to(device)

			if positions is not None:
				preds,coeffs = model(batch_atac_x,batch_rna_x,positions[batch])
			elif dnaseq is not None:
				preds,coeffs = model(batch_atac_x,batch_rna_x,peak_dna_ohe[batch],gene_dna_ohe[batch])
			else:
				preds,coeffs = model(batch_atac_x,batch_rna_x)

			targets = create_outcome_time_series(batch_rna_x.squeeze(),model.kernel_size,1,model.order).squeeze()

			if not train:
				coeffs_list.append(coeffs.detach().cpu().numpy())

			base_loss = criterion(preds, targets)

			# Sparsity-inducing penalty term.
			penalty = torch.mean(torch.mean(torch.norm(coeffs, dim=1, p=1), dim=0))

			# Smoothing penalty on differences between neighboring entries along dim 1.
			diff = coeffs[:,1:,:]-coeffs[:,:-1,:]
			penalty_smooth = torch.mean(torch.mean(torch.norm(diff, dim=1, p=2), dim=0))

			gamma = lmbd * 100
			loss = base_loss + lmbd * penalty + gamma * penalty_smooth

			incurred_loss += loss.item()
			incurred_penalty += lmbd * penalty.item()
			incurred_smoothness_penalty += gamma * penalty_smooth.item()

			if train:
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
	end = time.time()

	if verbose:
		mode = 'train' if train else 'test '
		print('Epoch {} ({:.2f} seconds): {} loss {:.2f},\tsparsity loss {:.2f},\tsmooth loss {:.2f}'.format(epoch_no,\
			end-start,mode,incurred_loss,incurred_penalty,incurred_smoothness_penalty))

	if train:
		return incurred_loss
	return incurred_loss,np.concatenate(coeffs_list)


def run_epoch_graph(epoch_no,batch_size,rna_X,atac_X,rna_idx,atac_idx,pair_idx,S_0,S_1,model,
			  optimizer,lmbd,device,criterion=nn.MSELoss(),verbose=True,train=True,statistics=False,
			  positions=None,dnaseq=None):
	"""Run one pass of the graph model over the (gene, peak) pairs.

	Pairs are processed in consecutive chunks of ``batch_size``. For each chunk the
	selected columns of ``rna_X`` / ``atac_X`` are transposed and passed to ``model``.
	The batch's peak, gene and pair indices, plus ``S_0`` and ``S_1``, are passed as
	well. The model returns ``(full_preds, red_preds)``. Both are scored against the
	batch RNA values, and the batch loss is the sum of the two ``criterion`` values.

	``rna_idx``, ``atac_idx`` and ``pair_idx`` must be numpy arrays, because their
	slices are converted with ``torch.from_numpy``. ``lmbd``, ``positions`` and
	``dnaseq`` are currently unused. The sparsity penalty is disabled, so the
	printed "sparsity loss" is always 0.

	When ``train`` is True the optimizer steps after each batch. Otherwise gradients
	are not tracked. If ``statistics`` is also set, each row's squared errors of the
	full and reduced predictions are compared with ``welch_ttest`` and ``ranksums``.
	The ratio of summed reduced to summed full error is recorded as well.

	Returns:
		The summed loss when ``train`` is True. Otherwise ``(summed_loss, results)``,
		where ``results`` maps ``'ttest'``, ``'wilcoxon'`` and ``'lr'`` to numpy arrays
		(empty unless ``statistics`` was set).
	"""
	t_list = []
	w_list = []
	lr_list = []

	incurred_loss = 0
	incurred_penalty = 0  # sparsity penalty currently disabled

	stats_criterion = nn.MSELoss(reduction='none')

	start = time.time()
	# No autograd graph is needed when only evaluating.
	with torch.set_grad_enabled(train):
		for batch_start in range(0, len(rna_idx), batch_size):
			batch = slice(batch_start, batch_start + batch_size)

			batch_rna_x = torch.from_numpy(rna_X[:,rna_idx[batch]].T).float().to(device)
			batch_atac_x = torch.from_numpy(atac_X[:,atac_idx[batch]].T).float().to(device)

			batch_atac_idx = torch.from_numpy(atac_idx[batch]).long().to(device)
			batch_rna_idx = torch.from_numpy(rna_idx[batch]).long().to(device)
			batch_pair_idx = torch.from_numpy(pair_idx[batch]).long().to(device)

			full_preds,red_preds = model(batch_atac_x,batch_rna_x,
				batch_atac_idx,batch_rna_idx,batch_pair_idx,S_0,S_1)

			# Both predictions are scored against the batch RNA values themselves.
			targets = batch_rna_x
			loss = criterion(full_preds, targets) + criterion(red_preds, targets)
			incurred_loss += loss.item()

			if train:
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
			elif statistics:
				full_err = stats_criterion(full_preds,targets).detach().cpu().numpy()
				red_err = stats_criterion(red_preds,targets).detach().cpu().numpy()

				for full_row, red_row in zip(full_err, red_err):
					_,_,ttest_p = welch_ttest(full_row,red_row,alternative='lesser')
					_,wilcoxon_p = ranksums(full_row,red_row,alternative='less')
					t_list.append(ttest_p)
					w_list.append(wilcoxon_p)
					lr_list.append(red_row.sum()/full_row.sum())
	end = time.time()

	if verbose:
		mode = 'train' if train else 'test '
		print('Epoch {} ({:.2f} seconds): {} loss {:.2f},\tsparsity loss {:.2f}'.format(epoch_no,\
			end-start,mode,incurred_loss,incurred_penalty))

	if train:
		return incurred_loss

	results_dict = {'ttest': t_list,'wilcoxon': w_list, 'lr': lr_list}
	results_dict = {k: np.array(v) for k,v in results_dict.items()}
	return incurred_loss,results_dict
