import torch
import torch.nn as nn
import os
import math

from nn_training import *
from nn_utils import *
from dnaseq import *

def positionalencoding1d(d_model, length):
	"""Build a fixed sinusoidal positional-encoding table.

	Adapted from https://github.com/wzlxjtu/PositionalEncoding2D

	:param d_model: encoding dimension; must be even because sin and cos
		channels are written in interleaved pairs.
	:param length: number of positions to encode (rows of the table).
	:return: float tensor of shape (length, d_model). Even columns hold
		sin(pos * w_i) and odd columns hold cos(pos * w_i), where
		w_i = exp(-2i * ln(10000) / d_model).
	:raises ValueError: if d_model is odd.
	"""
	if d_model % 2:
		raise ValueError("Cannot use sin/cos positional encoding with "
						 "odd dim (got dim={:d})".format(d_model))

	# One angular frequency per (sin, cos) channel pair.
	frequencies = torch.exp(
		torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
	# Outer product: position index (column vector) times frequency (row vector).
	angles = torch.arange(0, length).unsqueeze(1).float() * frequencies

	table = torch.zeros(length, d_model)
	table[:, 0::2] = torch.sin(angles)
	table[:, 1::2] = torch.cos(angles)
	return table

class SENNGCv2(nn.Module):
	"""Generalised VAR (GVAR) model built from self-explaining neural networks.

	For every lag slot k a small MLP (``coeff_nets[k]``) maps a hidden
	representation to a p x p coefficient matrix, and the prediction is
	``sum_k coeffs_k @ x_k``. Hence ``coeffs[:, k, i, j]`` is the weight of
	variable j (in lag slot k) when predicting variable i.
	"""

	def __init__(self, num_vars, order, hidden_layer_size, num_hidden_layers, device,
				 method="OLS"):
		"""
		Generalised VAR (GVAR) model based on self-explaining neural networks.
		@param num_vars: number of variables (p).
		@param order: model order (maximum lag, K); one coefficient network per lag.
		@param hidden_layer_size: number of units in the hidden layer. The first
			layer of every coefficient network expects num_vars * hidden_layer_size
			input features.
		@param num_hidden_layers: number of hidden layers.
		@param device: Torch device on which the predictions are allocated.
		@param method: fitting algorithm (currently, only "OLS" is supported).
		"""

		super(SENNGCv2, self).__init__()

		# Networks for amortising generalised coefficient matrices.
		self.coeff_nets = nn.ModuleList()

		# Instantiate coefficient networks. The nesting below (first layer wrapped
		# in its own Sequential, later layers flat) determines the state_dict keys,
		# so it is preserved for checkpoint compatibility.
		for _ in range(order):
			modules = [nn.Sequential(nn.Linear(num_vars * hidden_layer_size, hidden_layer_size), nn.ReLU())]
			for _ in range(num_hidden_layers - 1):
				modules.extend([nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU()])
			# Output layer: p*p entries, reshaped to a p x p matrix in forward().
			modules.append(nn.Linear(hidden_layer_size, num_vars ** 2))
			self.coeff_nets.append(nn.Sequential(*modules))

		# Some bookkeeping
		self.num_vars = num_vars
		self.order = order
		self.hidden_layer_size = hidden_layer_size
		self.num_hidden_layers = num_hidden_layers
		# Legacy (misnamed) attribute kept for callers that may still read it.
		self.num_hidden_layer_size = num_hidden_layers

		self.device = device

		self.method = method

	# Initialisation
	def init_weights(self):
		"""Xavier-normal weights and constant 0.1 biases for every Linear layer.

		Only nn.Linear modules are touched: self.modules() also yields this
		model, the ModuleList, the Sequentials and the ReLUs, none of which
		have weights.
		"""
		for m in self.modules():
			if isinstance(m, nn.Linear):
				nn.init.xavier_normal_(m.weight.data)
				m.bias.data.fill_(0.1)

	# Forward propagation,
	# returns predictions and generalised coefficients corresponding to each prediction
	def forward(self, inputs, hidden_states):
		"""Predict every variable from K lagged inputs.

		@param inputs: lagged inputs of shape (BS, K, p).
		@param hidden_states: per-lag features; hidden_states[:, k, :] must have
			p * hidden_layer_size features and is fed to coeff_nets[k].
		@return: (preds, coeffs). preds has shape (BS, p). coeffs has shape
			(BS, K, p, p), or is None when order == 0.
		@raise NotImplementedError: if method is not "OLS".
		"""
		if inputs[0, :, :].shape != torch.Size([self.order, self.num_vars]):
			print("WARNING: inputs should be of shape BS x K x p")

		if self.method == "BFT":
			raise NotImplementedError("Backfitting not implemented yet!")
		if self.method != "OLS":
			raise NotImplementedError("Unsupported fitting method!")

		batch_size = inputs.shape[0]
		preds = torch.zeros((batch_size, self.num_vars), device=self.device)
		coeffs_per_lag = []
		for k, coeff_net_k in enumerate(self.coeff_nets):
			coeffs_k = coeff_net_k(hidden_states[:, k, :])
			coeffs_k = torch.reshape(coeffs_k, (batch_size, self.num_vars, self.num_vars))
			coeffs_per_lag.append(coeffs_k)
			# squeeze only the trailing singleton column so that BS == 1 or p == 1
			# do not lose a dimension.
			preds = preds + torch.matmul(coeffs_k, inputs[:, k, :].unsqueeze(dim=2)).squeeze(dim=2)

		coeffs = torch.stack(coeffs_per_lag, dim=1) if coeffs_per_lag else None
		return preds, coeffs
		
class ConvLSTM(nn.Module):
	"""Temporal encoder: an unpadded 1-D convolution followed by a 2-layer LSTM.

	An input of length L produces L - kernel_size + 1 output time steps.
	"""

	def __init__(self,T,hidden_size,kernel_size=30):
		"""
		:param T: accepted for interface compatibility; not used.
		:param hidden_size: number of conv channels and LSTM hidden units.
		:param kernel_size: width of the convolution window.
		"""
		super(ConvLSTM, self).__init__()

		# Creation order is deliberately unchanged so seeded initialisation is identical.
		self.conv1 = nn.Conv1d(1, hidden_size, kernel_size, padding=0)
		self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=2,
							bidirectional=False, batch_first=True)
		self.relu = nn.ReLU()
		self.sigmoid = nn.Sigmoid()  # not used in forward()
		self.dropout = nn.Dropout(p=0.5)

	def forward(self,x):
		"""
		:param x: tensor of shape (batch, 1, length); the conv has one input channel.
		:return: tensor of shape (batch, length - kernel_size + 1, hidden_size),
			passed through a final ReLU.
		"""
		features = self.relu(self.dropout(self.conv1(x)))
		# (batch, channels, time) -> (batch, time, channels) for the batch_first LSTM.
		sequence, _ = self.lstm(features.transpose(1, 2))
		return self.relu(sequence)
		
class ConvLSTMAddEmbedding(nn.Module):
	"""ConvLSTM encoder whose per-time-step output is shifted by a given embedding.

	Same conv -> dropout -> ReLU -> LSTM -> ReLU pipeline as ConvLSTM; the
	embedding passed to forward() is tiled along time and added to every step.
	"""

	def __init__(self,T,hidden_size,kernel_size=30):
		"""
		:param T: accepted for interface compatibility; not used.
		:param hidden_size: number of conv channels and LSTM hidden units.
		:param kernel_size: width of the unpadded convolution window.
		"""
		super(ConvLSTMAddEmbedding, self).__init__()

		# Creation order is deliberately unchanged so seeded initialisation is identical.
		self.conv1 = nn.Conv1d(1, hidden_size, kernel_size, padding=0)
		self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=2,
							bidirectional=False, batch_first=True)
		self.relu = nn.ReLU()
		self.sigmoid = nn.Sigmoid()  # not used in forward()
		self.dropout = nn.Dropout(p=0.5)

	def forward(self,x,embedding):
		"""
		:param x: tensor of shape (batch, 1, length).
		:param embedding: tensor repeated along dim 1 to the output length; a
			(batch, 1, hidden_size) tensor adds the same vector to every step.
		:return: tensor of shape (batch, length - kernel_size + 1, hidden_size).
		"""
		features = self.relu(self.dropout(self.conv1(x)))
		sequence, _ = self.lstm(features.transpose(1, 2))
		sequence = self.relu(sequence)
		n_steps = sequence.shape[1]
		return sequence + embedding.repeat(1, n_steps, 1)

class AttentionGranger(nn.Module):
	"""GVAR-style model predicting RNA from lagged ATAC and RNA signals.

	ATAC and RNA sequences are encoded separately by ConvLSTM modules; their
	concatenated features drive a two-variable SENNGCv2 (variable 0 = ATAC,
	variable 1 = RNA) that is applied independently to every sample.
	"""

	def __init__(self,T,hidden_layer_size,kernel_size,num_hidden_layers,order,device):
		"""
		:param T: sequence length; stored and forwarded to the encoders (unused there).
		:param hidden_layer_size: hidden size of the encoders and coefficient networks.
		:param kernel_size: encoder convolution width.
		:param num_hidden_layers: hidden layers in each coefficient network.
		:param order: number of lags (K) of the GVAR model.
		:param device: torch device the lagged datasets are moved to.
		"""
		super(AttentionGranger, self).__init__()

		self.T = T
		self.hidden_layer_size = hidden_layer_size
		self.kernel_size = kernel_size
		self.num_hidden_layers = num_hidden_layers
		self.order = order
		self.device = device

		self.atac_module = ConvLSTM(T,hidden_layer_size,\
									kernel_size=kernel_size)
		self.rna_module = ConvLSTM(T,hidden_layer_size,\
								   kernel_size=kernel_size)
		self.GVAR = SENNGCv2(2,order=order,hidden_layer_size=hidden_layer_size,\
						   num_hidden_layers=num_hidden_layers,device=device)

	def forward(self,atac_x,rna_x):
		"""
		:param atac_x: ATAC signal, expected (batch, 1, length) as required by ConvLSTM.
		:param rna_x: RNA signal with the same shape as atac_x.
		:return: (preds, coeffs), stacked over the batch. preds holds the RNA
			predictions; coeffs holds the ATAC -> RNA coefficient for each
			window and lag.
		"""
		atac_out = self.atac_module(atac_x)
		rna_out = self.rna_module(rna_x)

		# Per-time-step features for both variables: (batch, time, 2 * hidden).
		out = torch.cat([atac_out,rna_out],dim=2)

		preds_list = []
		coeffs_list = []
		for idx in range(atac_x.shape[0]):
			# Drop the first kernel_size - 1 raw steps so the raw series aligns with
			# the unpadded convolution output.
			atac_inp = atac_x[idx,:,self.kernel_size-1:]
			rna_inp = rna_x[idx,:,self.kernel_size-1:]
			# (time, 2) with column 0 = ATAC and column 1 = RNA.
			inp = torch.cat([atac_inp,rna_inp],dim=0).T
			# construct_lagged_dataset is defined outside this file; presumably it
			# yields (n_windows, order, n_vars) lag windows as SENNGCv2 expects.
			inp = construct_lagged_dataset(inp,self.order).to(self.device)
			predictors = construct_lagged_dataset(out[idx],self.order).to(self.device)
			preds,coeffs = self.GVAR(inp,predictors)

			preds_list.append(preds[:,1]) # keep only predictions for RNA
			# SENNGCv2 predicts with coeffs_k @ x, so coeffs[..., i, j] weights
			# variable j when predicting variable i: ATAC (j=0) -> RNA (i=1) is
			# [1, 0]. The [0, 1] entry never influences the RNA predictions above.
			coeffs_list.append(coeffs[:,:,1,0]) # keep only coefficients for ATAC to RNA

		return torch.stack(preds_list),torch.stack(coeffs_list)


class AttentionGrangerDistance(nn.Module):
	"""AttentionGranger variant whose encoders add a positional (distance) encoding.

	Each sample carries an integer position that selects a row of a fixed
	sinusoidal table; that vector is added to every time step of both the ATAC
	and RNA encoder outputs before the two-variable SENNGCv2 (0 = ATAC,
	1 = RNA) is applied.
	"""

	def __init__(self,T,hidden_layer_size,kernel_size,num_hidden_layers,order,device,
				 num_positions=20):
		"""
		:param T: sequence length; stored and forwarded to the encoders (unused there).
		:param hidden_layer_size: hidden size of encoders, coefficient networks and
			positional encodings (must be even for the sin/cos table).
		:param kernel_size: encoder convolution width.
		:param num_hidden_layers: hidden layers in each coefficient network.
		:param order: number of lags (K) of the GVAR model.
		:param device: torch device the lagged datasets are moved to.
		:param num_positions: number of distinct positions in the encoding table
			(valid positions are 0 .. num_positions - 1). Defaults to the
			previously hard-coded 20.
		"""
		super(AttentionGrangerDistance, self).__init__()

		self.T = T
		self.hidden_layer_size = hidden_layer_size
		self.kernel_size = kernel_size
		self.num_hidden_layers = num_hidden_layers
		self.order = order
		self.device = device

		self.atac_module = ConvLSTMAddEmbedding(T,hidden_layer_size,\
									kernel_size=kernel_size)
		self.rna_module = ConvLSTMAddEmbedding(T,hidden_layer_size,\
								   kernel_size=kernel_size)
		self.GVAR = SENNGCv2(2,order=order,hidden_layer_size=hidden_layer_size,\
						   num_hidden_layers=num_hidden_layers,device=device)

		# Plain tensor attribute (not a parameter/buffer): it is not in the
		# state_dict and is NOT moved by model.to(device); forward() moves it.
		self.positional_embeddings = positionalencoding1d(hidden_layer_size, num_positions)

	def forward(self,atac_x,rna_x,positions):
		"""
		:param atac_x: ATAC signal, expected (batch, 1, length).
		:param rna_x: RNA signal with the same shape as atac_x.
		:param positions: one integer position per sample, used to index the
			positional-encoding table.
		:return: (preds, coeffs), stacked over the batch. preds holds the RNA
			predictions; coeffs holds the ATAC -> RNA coefficient for each
			window and lag.
		"""
		# Move the table to the input's device so indexing and the later addition
		# do not mix CPU and GPU tensors.
		table = self.positional_embeddings.to(atac_x.device)
		# (batch, 1, hidden): one encoding per sample, broadcast over time downstream.
		positional_encoding = torch.stack([table[pos] for pos in positions]).unsqueeze(1)

		atac_out = self.atac_module(atac_x,positional_encoding)
		rna_out = self.rna_module(rna_x,positional_encoding)

		out = torch.cat([atac_out,rna_out],dim=2)

		preds_list = []
		coeffs_list = []
		for idx in range(atac_x.shape[0]):
			# Align the raw series with the unpadded convolution output.
			atac_inp = atac_x[idx,:,self.kernel_size-1:]
			rna_inp = rna_x[idx,:,self.kernel_size-1:]
			# (time, 2) with column 0 = ATAC and column 1 = RNA.
			inp = torch.cat([atac_inp,rna_inp],dim=0).T
			inp = construct_lagged_dataset(inp,self.order).to(self.device)
			predictors = construct_lagged_dataset(out[idx],self.order).to(self.device)
			preds,coeffs = self.GVAR(inp,predictors)

			preds_list.append(preds[:,1]) # keep only predictions for RNA
			# coeffs[..., i, j] weights variable j when predicting variable i
			# (SENNGCv2 uses coeffs_k @ x), so ATAC (0) -> RNA (1) is [1, 0].
			coeffs_list.append(coeffs[:,:,1,0]) # keep only coefficients for ATAC to RNA

		return torch.stack(preds_list),torch.stack(coeffs_list)
		

class DNASeqModule(nn.Module):
	"""DNA sequence encoder: conv -> max-pool -> dropout -> ReLU -> 2-layer BiLSTM."""

	def __init__(self,kernel_size,out_channels,hidden_layer_size):
		"""
		:param kernel_size: width of the convolution and of the (non-overlapping)
			max-pool window.
		:param out_channels: number of convolution filters.
		:param hidden_layer_size: LSTM hidden units per direction.
		"""
		super(DNASeqModule, self).__init__()

		# Creation order is deliberately unchanged so seeded initialisation is identical.
		self.conv1 = nn.Conv1d(in_channels=4, out_channels=out_channels, kernel_size=kernel_size)
		self.maxpool = nn.MaxPool1d(kernel_size)
		self.bilstm = nn.LSTM(out_channels, hidden_layer_size, num_layers=2,
							  bidirectional=True, batch_first=True)
		self.relu = nn.ReLU()
		self.dropout = nn.Dropout(p=0.5)

	def forward(self,x):
		"""
		:param x: tensor of shape (batch, 4, length); presumably one-hot
			nucleotides (TODO confirm against dnaseq helpers).
		:return: BiLSTM outputs of shape (batch, pooled_length, 2 * hidden_layer_size).
		"""
		pooled = self.maxpool(self.conv1(x))
		# (batch, channels, positions) -> (batch, positions, channels) for the LSTM.
		activated = self.relu(self.dropout(pooled.transpose(1, 2)))
		encoded, _ = self.bilstm(activated)
		return encoded


class DNASeqSelfAttention(nn.Module):
	"""Peak/gene DNA similarity embedding.

	Peak and gene sequences are encoded by separate DNASeqModules. Their
	position-wise encodings are compared with pairwise cosine similarity, and
	for every gene position the best-matching peak similarity is projected to
	hidden_layer_size.
	"""

	def __init__(self,kernel_size,out_channels,hidden_layer_size,input_length):
		"""
		:param kernel_size: conv / max-pool width used by both DNA encoders.
		:param out_channels: conv filters in each DNA encoder.
		:param hidden_layer_size: encoder LSTM size and output embedding size.
		:param input_length: length of the DNA sequences fed to forward().
		:raises ValueError: if input_length is too short to survive conv + pool.
		"""
		super(DNASeqSelfAttention, self).__init__()

		self.dnaseq_peak = DNASeqModule(kernel_size,out_channels,hidden_layer_size)
		self.dnaseq_gene = DNASeqModule(kernel_size,out_channels,hidden_layer_size)
		# Number of positions the encoders emit; must equal the Linear input size.
		self.L_out = self._pooled_length(input_length, kernel_size)
		self.linear = nn.Linear(self.L_out,hidden_layer_size)

	@staticmethod
	def _pooled_length(input_length, kernel_size):
		"""Sequence length after an unpadded Conv1d(k) then MaxPool1d(k) (stride k)."""
		conv_length = input_length - kernel_size + 1
		if conv_length < kernel_size:
			raise ValueError("input_length={} is too short for kernel_size={}".format(
				input_length, kernel_size))
		# MaxPool1d output length: floor((L - k) / k) + 1.
		return (conv_length - kernel_size) // kernel_size + 1

	def forward(self,peak_dna,gene_dna):
		"""
		:param peak_dna: peak sequence batch, as accepted by DNASeqModule.
		:param gene_dna: gene sequence batch, as accepted by DNASeqModule.
		:return: embedding of size hidden_layer_size per sample.
		"""
		peak_h = self.dnaseq_peak(peak_dna)
		gene_h = self.dnaseq_gene(gene_dna)

		# pairwise_cosine_similarity comes from the star imports; assumed to return
		# a batched (batch, peak_positions, gene_positions) map. TODO confirm.
		attention_map = pairwise_cosine_similarity(peak_h,gene_h)
		max_values,_ = attention_map.max(1)

		return self.linear(max_values)


class DNASelfAttentionGranger(nn.Module):
	"""AttentionGranger variant conditioned on peak/gene DNA sequence similarity.

	A DNASeqSelfAttention embedding is added to every time step of the ATAC and
	RNA encoder outputs, which then drive a two-variable SENNGCv2 (0 = ATAC,
	1 = RNA) applied independently to every sample.
	"""

	def __init__(self,T,hidden_layer_size,kernel_size,num_hidden_layers,\
				 dna_length,dna_kernel_size,dna_channels,\
				 order,device):
		"""
		:param T: sequence length; stored and forwarded to the encoders (unused there).
		:param hidden_layer_size: hidden size shared by all sub-modules.
		:param kernel_size: ATAC/RNA encoder convolution width.
		:param num_hidden_layers: hidden layers in each coefficient network.
		:param dna_length: length of the peak/gene DNA sequences.
		:param dna_kernel_size: DNA encoder conv / pool width.
		:param dna_channels: DNA encoder conv filters.
		:param order: number of lags (K) of the GVAR model.
		:param device: torch device the lagged datasets are moved to.
		"""
		super(DNASelfAttentionGranger, self).__init__()

		self.T = T
		self.hidden_layer_size = hidden_layer_size
		self.kernel_size = kernel_size
		self.num_hidden_layers = num_hidden_layers
		self.dna_kernel_size = dna_kernel_size
		self.dna_channels = dna_channels
		self.order = order
		self.device = device

		self.dnaseq_module = DNASeqSelfAttention(dna_kernel_size,dna_channels,hidden_layer_size,dna_length)

		self.atac_module = ConvLSTMAddEmbedding(T,hidden_layer_size,\
									kernel_size=kernel_size)
		self.rna_module = ConvLSTMAddEmbedding(T,hidden_layer_size,\
								   kernel_size=kernel_size)
		self.GVAR = SENNGCv2(2,order=order,hidden_layer_size=hidden_layer_size,\
						   num_hidden_layers=num_hidden_layers,device=device)

	def forward(self,atac_x,rna_x,peak_dna,gene_dna):
		"""
		:param atac_x: ATAC signal, expected (batch, 1, length).
		:param rna_x: RNA signal with the same shape as atac_x.
		:param peak_dna: peak DNA sequences for DNASeqSelfAttention.
		:param gene_dna: gene DNA sequences for DNASeqSelfAttention.
		:return: (preds, coeffs), stacked over the batch. preds holds the RNA
			predictions; coeffs holds the ATAC -> RNA coefficient for each
			window and lag.
		"""
		# (batch, 1, hidden): one DNA embedding per sample, tiled over time downstream.
		self_attention_embeddings = self.dnaseq_module(peak_dna,gene_dna)
		self_attention_embeddings = self_attention_embeddings.unsqueeze(1)

		atac_out = self.atac_module(atac_x,self_attention_embeddings)
		rna_out = self.rna_module(rna_x,self_attention_embeddings)

		out = torch.cat([atac_out,rna_out],dim=2)

		preds_list = []
		coeffs_list = []
		for idx in range(atac_x.shape[0]):
			# Align the raw series with the unpadded convolution output.
			atac_inp = atac_x[idx,:,self.kernel_size-1:]
			rna_inp = rna_x[idx,:,self.kernel_size-1:]
			# (time, 2) with column 0 = ATAC and column 1 = RNA.
			inp = torch.cat([atac_inp,rna_inp],dim=0).T
			inp = construct_lagged_dataset(inp,self.order).to(self.device)
			predictors = construct_lagged_dataset(out[idx],self.order).to(self.device)
			preds,coeffs = self.GVAR(inp,predictors)

			preds_list.append(preds[:,1]) 		# keep only predictions for RNA
			# coeffs[..., i, j] weights variable j when predicting variable i
			# (SENNGCv2 uses coeffs_k @ x), so ATAC (0) -> RNA (1) is [1, 0].
			coeffs_list.append(coeffs[:,:,1,0]) # keep only coefficients for ATAC to RNA

		return torch.stack(preds_list),torch.stack(coeffs_list)


#### 

class GraphConvLayer(nn.Module):
	"""Dense graph convolution: propagate features with S, then project with W.

	Computes ``(x @ S) @ W`` where W is a learnable (num_nodes, hidden_size)
	matrix initialised with Xavier-normal.
	"""

	def __init__(self,num_nodes,hidden_size):
		super(GraphConvLayer, self).__init__()

		weight = torch.zeros((num_nodes, hidden_size))
		self.W = nn.Parameter(weight, requires_grad=True)
		nn.init.xavier_normal_(self.W.data)

		self.num_nodes = num_nodes
		self.hidden_size = hidden_size

	def forward(self,x,S):
		"""Return ``x @ S @ W``; x's last dim must match S's first dim."""
		propagated = x @ S
		return propagated @ self.W
	
class GraphConvLayer_v2(nn.Module):
	"""Graph propagation followed by a per-variable affine transform.

	Computes ``(x @ S) * W[idx] + b[idx]``, where W and b are learnable
	(n_vars, 1) tensors. The scalar scale and shift of each sample are selected
	by its variable index idx.
	"""

	def __init__(self,num_nodes,n_vars):
		super(GraphConvLayer_v2, self).__init__()

		self.W = nn.Parameter(torch.ones((n_vars, 1)), requires_grad=True)
		self.b = nn.Parameter(torch.ones(n_vars, 1), requires_grad=True)
		# Initialise W before b, as the RNG draw order affects seeded runs.
		for tensor in (self.W.data, self.b.data):
			nn.init.xavier_normal_(tensor)

		self.num_nodes = num_nodes

	def forward(self,x,S,idx):
		"""
		:param x: node features; x's last dim must match S's first dim.
		:param S: propagation matrix.
		:param idx: variable indices selecting rows of W and b; each selected
			(…, 1) row broadcasts across the propagated features.
		"""
		propagated = x @ S
		scale = self.W[idx]
		shift = self.b[idx]
		return propagated * scale + shift

class GraphGrangerModule(nn.Module):
	"""Graph-based Granger model comparing a full and a reduced RNA predictor.

	Three stacks of GraphConvLayer_v2 are built: ATAC, RNA (full model) and RNA
	(reduced model). The full output adds the ATAC contribution, scaled by a
	learnable per-pair coefficient C_a, to the RNA term. The reduced output uses
	RNA only.
	"""

	def __init__(self,num_nodes,num_hidden_layers,atac_idx_all,rna_idx_all,final_activation='exp'):
		"""
		:param num_nodes: number of graph nodes (stored by each layer).
		:param num_hidden_layers: graph layers per stack; must be >= 1.
		:param atac_idx_all: peak index of every (peak, gene) pair; the number of
			distinct values sizes the ATAC layers and its length sizes C_a.
		:param rna_idx_all: gene index of every pair; distinct values size the RNA layers.
		:param final_activation: 'exp' or 'sigmoid'; any other value leaves the
			outputs untransformed.
		:raises ValueError: if num_hidden_layers < 1.
		"""
		super(GraphGrangerModule, self).__init__()

		# forward() always uses the *_gcn_1 layers, so at least one is required.
		if num_hidden_layers < 1:
			raise ValueError("num_hidden_layers must be >= 1 (got {})".format(num_hidden_layers))

		self.num_hidden_layers = num_hidden_layers
		# NOTE(review): len(set(...)) assumes hashable scalar indices (e.g. ints);
		# elements of a torch tensor hash by identity and would all count as distinct.
		self.n_peaks = len(set(atac_idx_all))
		self.n_genes = len(set(rna_idx_all))
		self.n_pairs = len(atac_idx_all)

		self.final_activation = final_activation

		# Attribute names (atac_gcn_1, rna_gcn_1, rna_gcn_r_1, ...) define the
		# state_dict keys and are kept for checkpoint compatibility.
		for n in range(num_hidden_layers):
			setattr(self,'atac_gcn_{}'.format(n+1),GraphConvLayer_v2(num_nodes,self.n_peaks))
			setattr(self,'rna_gcn_{}'.format(n+1),GraphConvLayer_v2(num_nodes,self.n_genes))
			setattr(self,'rna_gcn_r_{}'.format(n+1),GraphConvLayer_v2(num_nodes,self.n_genes))

		# Per-pair weight of the ATAC contribution in the full model.
		self.C_a = nn.Parameter(torch.ones(self.n_pairs,1),requires_grad=True)
		nn.init.xavier_normal_(self.C_a.data)

		# Despite the attribute name this is a tanh nonlinearity; the name is kept
		# for backward compatibility with code that accesses `.relu`.
		self.relu = nn.Tanh()
		self.sigmoid = nn.Sigmoid()

	def forward(self,atac_x,rna_x,atac_idx,rna_idx,pair_idx,S_0,S_1):
		"""
		:param atac_x: ATAC node features.
		:param rna_x: RNA node features.
		:param atac_idx: peak index per sample (selects ATAC layer W/b rows).
		:param rna_idx: gene index per sample (selects RNA layer W/b rows).
		:param pair_idx: pair index per sample (selects rows of C_a).
		:param S_0: propagation matrix for the first layer.
		:param S_1: propagation matrix for all subsequent layers.
		:return: (full_output, reduced_output) after the final activation.
		"""
		atac_out = self.atac_gcn_1(atac_x,S_0,atac_idx)
		rna_out = self.rna_gcn_1(rna_x,S_0,rna_idx)
		rna_r_out = self.rna_gcn_r_1(rna_x,S_0,rna_idx)

		# Every layer's output is kept; the final representation is their mean.
		atac_out_list = [atac_out]
		rna_out_list = [rna_out]
		rna_r_out_list = [rna_r_out]

		for n in range(2, self.num_hidden_layers + 1):
			atac_out = getattr(self,'atac_gcn_{}'.format(n))(self.relu(atac_out),S_1,atac_idx)
			rna_out = getattr(self,'rna_gcn_{}'.format(n))(self.relu(rna_out),S_1,rna_idx)
			rna_r_out = getattr(self,'rna_gcn_r_{}'.format(n))(self.relu(rna_r_out),S_1,rna_idx)

			atac_out_list.append(atac_out)
			rna_out_list.append(rna_out)
			rna_r_out_list.append(rna_r_out)

		atac_out = torch.stack(atac_out_list,dim=0).mean(0)
		rna_out = torch.stack(rna_out_list,dim=0).mean(0)
		rna_r_out = torch.stack(rna_r_out_list,dim=0).mean(0)

		# Full model: RNA term plus the pair-weighted ATAC term.
		full_output = rna_out.squeeze() + atac_out.squeeze() * self.C_a[pair_idx]
		# Reduced model: RNA term only (no ATAC input).
		reduced_output = rna_r_out.squeeze()

		if self.final_activation == 'exp':
			full_output = torch.exp(full_output)
			reduced_output = torch.exp(reduced_output)
		elif self.final_activation == 'sigmoid':
			full_output = self.sigmoid(full_output)
			reduced_output = self.sigmoid(reduced_output)

		return full_output,reduced_output
