import numpy as np
import os
import torch
from scipy.stats import t
from scipy.sparse import csr_matrix,identity
import scanpy as sc

def construct_lagged_dataset(data,order):
	"""Stack ``order`` shifted copies of ``data`` along a new axis 1.

	Copy ``k`` (k = 0 .. order-1) is the slice ``data[k:k - order]`` taken
	along the first axis. Every copy therefore has ``len(data) - order`` rows,
	and ``result[i, k] == data[i + k]``. The stop index is always negative, so
	the final row of ``data`` never appears in the output (presumably it is
	reserved as the prediction target; confirm with callers).

	Args:
		data: tensor sliced along its first axis.
		order: number of shifted copies. Must be >= 1, because
			``torch.stack`` rejects an empty list.

	Returns:
		Tensor of shape ``(len(data) - order, order, *data.shape[1:])``.
	"""
	windows = []
	for lag in range(order):
		windows.append(data[lag:lag - order])
	return torch.stack(windows, dim=1)

def create_outcome_time_series(data,kernel_size,stride,order):
	"""Trim (and optionally downsample) an outcome series along axis 1.

	The first ``kernel_size + order - 1`` columns of ``data`` are dropped.
	With ``stride == 1`` the remaining columns are returned unchanged.
	Otherwise they are averaged over consecutive, non-overlapping windows of
	``stride`` columns, and the last window may be shorter than ``stride``.

	Args:
		data: tensor indexed as ``data[:, t]``. The windowed branch transposes
			the stacked result with ``.T``, so a 2-D tensor is assumed there.
		kernel_size: together with ``order``, sets the number of leading
			columns to skip.
		stride: window length and step. Must be >= 1.
		order: see ``kernel_size``.

	Returns:
		Tensor with the same number of rows as ``data`` and one column per
		kept step (``stride == 1``) or per window. It has zero columns when
		the offset reaches or exceeds ``data.shape[1]``.

	Raises:
		ValueError: if ``stride < 1``.
	"""
	if stride < 1:
		raise ValueError(f"stride must be >= 1, got {stride}")

	start = kernel_size + order - 1

	if stride == 1:
		return data[:, start:]

	windows = [data[:, idx:idx + stride].mean(1) for idx in range(start, data.shape[1], stride)]
	if not windows:
		# torch.stack rejects an empty list; match the stride == 1 branch,
		# which yields a zero-column tensor in this situation.
		return data[:, start:]
	return torch.stack(windows).T

def pairwise_cosine_similarity(a,b,eps=1e-8):
	"""Batched cosine similarity between every row of ``a`` and every row of ``b``.

	Args:
		a: 3-D tensor ``(batch, n, d)`` (``torch.bmm`` requires 3-D inputs).
		b: 3-D tensor ``(batch, m, d)``.
		eps: lower bound applied to each row norm before dividing. Without it
			an all-zero row produces 0/0 = NaN. A zero row now yields
			similarity 0. Rows with norm >= eps are unaffected.

	Returns:
		Tensor ``(batch, n, m)`` where ``out[k, i, j]`` is the cosine
		similarity of ``a[k, i]`` and ``b[k, j]``.
	"""
	a_norm = a / a.norm(dim=2, keepdim=True).clamp(min=eps)
	b_norm = b / b.norm(dim=2, keepdim=True).clamp(min=eps)

	return torch.bmm(a_norm, b_norm.transpose(1, 2))


def corr2_coeff(A, B):
	"""Pearson correlation between every row of ``A`` and every row of ``B``.

	Args:
		A: 2-D array ``(p, t)``.
		B: 2-D array ``(q, t)`` with the same number of columns as ``A``.

	Returns:
		Array ``(p, q)`` whose entry ``[i, j]`` is the correlation of ``A[i]``
		and ``B[j]``. A row with zero variance makes its entries 0/0 = NaN.
	"""
	# centre each row on its own mean
	centred_a = A - A.mean(axis=1, keepdims=True)
	centred_b = B - B.mean(axis=1, keepdims=True)

	# per-row sum of squared deviations
	sq_norm_a = (centred_a ** 2).sum(axis=1)
	sq_norm_b = (centred_b ** 2).sum(axis=1)

	covariance = centred_a @ centred_b.T
	return covariance / np.sqrt(np.outer(sq_norm_a, sq_norm_b))

def coeff_to_causality(coeffs):
	"""Reduce a coefficient array to one score per index of its first axis.

	Takes the absolute value, then the median over axis 1, then the maximum
	over the next axis (original axis 2). For a 3-D input ``(P, T, L)`` the
	result has shape ``(P,)``.
	"""
	magnitude = np.abs(coeffs)
	median_magnitude = np.median(magnitude, axis=1)
	return np.max(median_magnitude, axis=1)

def get_lags(coeffs):
	"""Return the index (along original axis 2) of the largest median |coefficient|.

	Same reduction as ``coeff_to_causality``, but returns the ``argmax``
	position instead of the maximum value. Ties resolve to the first index,
	as ``np.argmax`` does.
	"""
	median_magnitude = np.median(np.abs(coeffs), axis=1)
	return np.argmax(median_magnitude, axis=1)

# def coeff_to_causality(coeffs):
# 	return np.max(np.max(coeffs, axis=2), axis=1)

def train_test_split(idx_length,train_prop=0.8,seed=1):
	"""Randomly split ``0 .. idx_length-1`` into train and test index arrays.

	NOTE: this reseeds NumPy's *global* RNG via ``np.random.seed(seed)``, which
	affects every later ``np.random`` call in the process.

	Args:
		idx_length: number of indices to split.
		train_prop: fraction assigned to the train split. The cut point is
			``int(train_prop * idx_length)``, i.e. truncated.
		seed: value passed to ``np.random.seed``.

	Returns:
		``(train_idx, test_idx)``: disjoint slices of one shuffled permutation.
	"""
	np.random.seed(seed)

	permutation = np.arange(idx_length)
	np.random.shuffle(permutation)

	cut = int(train_prop * idx_length)
	return permutation[:cut], permutation[cut:]

def welch_ttest(x1, x2,alternative):
	"""Welch's unequal-variance t-test comparing the means of ``x1`` and ``x2``.

	Args:
		x1, x2: samples (array-like). Each needs at least two elements for the
			``ddof=1`` variance to be finite.
		alternative: one of
			``"equal"`` / ``"two-sided"``: two-sided test;
			``"lesser"`` / ``"less"``: one-sided, mean(x1) < mean(x2);
			``"greater"``: one-sided, mean(x1) > mean(x2).

	Returns:
		``(tstat, df, p)``: t statistic, Welch-Satterthwaite degrees of
		freedom, and p-value.

	Raises:
		ValueError: if ``alternative`` is not one of the values above.
	"""
	x1 = np.asarray(x1)
	x2 = np.asarray(x2)

	n1 = x1.size
	n2 = x2.size
	m1 = np.mean(x1)
	m2 = np.mean(x2)
	v1 = np.var(x1, ddof=1)
	v2 = np.var(x2, ddof=1)
	tstat = (m1 - m2) / np.sqrt(v1 / n1 + v2 / n2)
	df = (v1 / n1 + v2 / n2)**2 / (v1**2 / (n1**2 * (n1 - 1)) + v2**2 / (n2**2 * (n2 - 1)))

	if alternative in ("equal", "two-sided"):
		p = 2 * t.cdf(-abs(tstat), df)
	elif alternative in ("lesser", "less"):
		p = t.cdf(tstat, df)
	elif alternative == "greater":
		# sf avoids the cancellation of 1 - cdf for large positive tstat
		p = t.sf(tstat, df)
	else:
		raise ValueError(
			f"alternative must be 'equal', 'two-sided', 'lesser', 'less' or 'greater', got {alternative!r}"
		)
	return tstat, df, p

def coo_to_sparse_tensor(coo):
	"""Convert a SciPy COO matrix to a float32 torch sparse COO tensor.

	Indices are cast to int64 and values to float32. Both are copied, so the
	result does not share memory with ``coo``. No ``coalesce()`` is called,
	so duplicate entries stay as stored in ``coo``.

	Args:
		coo: object exposing ``row``, ``col``, ``data`` and ``shape``
			(e.g. ``scipy.sparse.coo_matrix``).

	Returns:
		``torch`` sparse COO tensor of shape ``coo.shape``.
	"""
	# torch.sparse.FloatTensor / torch.LongTensor are deprecated legacy
	# constructors; sparse_coo_tensor is the supported replacement.
	indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
	values = torch.tensor(coo.data, dtype=torch.float32)

	return torch.sparse_coo_tensor(indices, values, size=tuple(coo.shape))

def csr_to_sparse_tensor(csr):
	"""Convert a SciPy CSR matrix to a float32 torch sparse CSR tensor.

	The shape is passed explicitly. Letting torch infer it from the stored
	column indices under-counts columns whenever the trailing columns hold no
	stored entry.

	Args:
		csr: object exposing ``indptr``, ``indices``, ``data`` and ``shape``
			(e.g. ``scipy.sparse.csr_matrix``).

	Returns:
		``torch`` tensor with ``torch.sparse_csr`` layout and shape
		``csr.shape``.
	"""
	# torch.tensor copies, so the result never aliases the SciPy buffers
	col_indices = torch.tensor(csr.indices)
	crow_indices = torch.tensor(csr.indptr)
	values = torch.tensor(csr.data)

	# Prefer the public constructor; only very old torch builds lack it and
	# expose just the private one.
	build = getattr(torch, "sparse_csr_tensor", None)
	if build is None:
		build = torch._sparse_csr_tensor

	return build(crow_indices, col_indices, values, tuple(csr.shape), dtype=torch.float)

def construct_transition_matrices(adata,n_neighbors,matrix_power=1,backward=True,dpt=True,use_rep='X_schema'):
	"""Build pseudotime-directed transition matrices from a KNN graph.

	Runs ``sc.pp.neighbors`` on ``adata`` (the graph is read back from
	``adata.obsp['distances']``). If ``dpt`` is set, it also runs
	``sc.tl.dpt``; pseudotime is then read from
	``adata.obs['dpt_pseudotime']``. ``adata`` is therefore modified in place
	by scanpy. ``sc.tl.dpt`` presumably needs a root cell (e.g.
	``adata.uns['iroot']``) set by the caller; confirm.

	Directed adjacency ``D[i, j] = 1`` when ``j`` is a KNN neighbour of ``i``
	and, if ``backward`` is False, ``pt[j] > pt[i]`` (if ``backward`` is True,
	``pt[j] < pt[i]``). Equal, NaN or both-infinite pseudotimes never produce
	an edge.

	Args:
		adata: scanpy/AnnData object holding the representation ``use_rep``.
		n_neighbors: forwarded to ``sc.pp.neighbors``.
		matrix_power: both matrices are raised to this integer power.
		backward: direct edges towards smaller pseudotime instead of larger.
		dpt: run ``sc.tl.dpt``. If False, ``adata.obs['dpt_pseudotime']``
			must already exist.
		use_rep: representation passed to ``sc.pp.neighbors`` (previously
			hard-coded to ``'X_schema'``, which remains the default).

	Returns:
		``(S_0, S_1)``: dense ``n x n`` arrays. Column ``l`` of the
		un-powered ``S_0`` is row ``l`` of ``D`` divided by its row sum (rows
		with no edges stay all-zero). ``S_1`` is the same with a self-loop
		added to every node first, so every un-powered column sums to 1.
		Memory is O(n^2): the KNN graph is densified.

	Raises:
		KeyError: if ``dpt`` is False and no ``dpt_pseudotime`` column exists.
	"""
	# create KNN graph + infer pseudotime
	sc.pp.neighbors(adata,use_rep=use_rep,n_neighbors=n_neighbors)

	if dpt:
		sc.tl.dpt(adata)
	elif 'dpt_pseudotime' not in adata.obs:
		raise KeyError("adata.obs['dpt_pseudotime'] is required when dpt=False")

	# binary KNN adjacency: True wherever the neighbour graph stores a non-zero edge
	knn = adata.obsp['distances'].astype(bool).toarray()

	pseudotime = adata.obs['dpt_pseudotime'].values
	# direction[i, j] compares pt[j] (column) against pt[i] (row)
	if backward:
		direction = pseudotime[None, :] < pseudotime[:, None]
	else:
		direction = pseudotime[None, :] > pseudotime[:, None]

	adjacency = (knn & direction).astype(float)

	out_degree = adjacency.sum(1)
	out_degree[out_degree == 0] = 1  # avoid 0/0 for nodes without outgoing edges
	S_0 = adjacency.T / out_degree

	# the diagonal of `adjacency` is always 0 (no strict self-comparison), so
	# adding the identity gives every node exactly one self-loop
	with_self = adjacency + np.eye(adjacency.shape[0])
	S_1 = with_self.T / with_self.sum(1)

	S_0 = np.linalg.matrix_power(S_0,matrix_power)
	S_1 = np.linalg.matrix_power(S_1,matrix_power)

	return S_0,S_1