import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_auc_score, silhouette_samples
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
import pdb
import pickle
import numpy as np
import torch
import argparse

from src.ewot import EWOT
from src.lwot import LWOT

# Command-line options of this script. They are parsed once, at import time,
# into the module-level `args` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
	"--outpath", default="experiment_data/shape_correspondence/",
	help="outpath of experiment results, both data and figures"
)
parser.add_argument(
	"--datapath", default="data/shape_data/test-sets",
	help="datapath of shape data"
)
# BooleanOptionalAction also generates the matching --no-lwot / --no-plot flags.
parser.add_argument(
	'--lwot', default=False, action=argparse.BooleanOptionalAction,
	help="run lwot if True or run ewot if False"
)
parser.add_argument(
	"--wavelet_kernel", default="heat_kernel",
	help="wavelet kernel to use"
)
# type=int: without it a value given on the command line would be stored as a
# str while the default stays an int.
parser.add_argument(
	"--n_samples", default=1000, type=int,
	help="number of points to sample from each shape"
)
parser.add_argument('--plot', default=True, action=argparse.BooleanOptionalAction, help="plot results if True")
args = parser.parse_args()

"""
Evaluation code below is largely adapted from SCOT (Ritambhara Singh, Pinar Demetci, Rebecca Santorella)
"""
def calc_frac_idx(x1_mat,x2_mat, true_dict):
	"""
	Returns fraction closer than true match for each sample (as an array)

	For every row i of x1_mat the Euclidean distance to every row of x2_mat
	is computed, and row i of x2_mat is taken as the true match. The fraction
	is the number of x2_mat rows strictly closer than that true match,
	divided by (number of rows of x1_mat - 1).

	Parameters:
		x1_mat: 2-D numpy array or torch tensor, one sample per row.
		x2_mat: 2-D numpy array or torch tensor with the same number of
			columns and at least as many rows as x1_mat (row i is indexed
			as the match of x1_mat row i).
		true_dict: not used by this function; kept so existing callers
			keep working.

	Returns:
		(fracs, x): fracs is a list of floats, one per row of x1_mat;
		x is the list of 1-based row numbers [1, ..., nsamp].
	"""
	def _as_numpy(mat):
		# Tensors are brought to host memory so a single numpy code path
		# serves both supported input kinds.
		if isinstance(mat, torch.Tensor):
			return mat.detach().cpu().numpy()
		return np.asarray(mat)

	x1_mat = _as_numpy(x1_mat)
	x2_mat = _as_numpy(x2_mat)
	nsamp = x1_mat.shape[0]
	# With a single sample nothing can be closer than the true match, so the
	# fraction is 0 instead of a division by zero.
	denom = max(nsamp - 1, 1)

	fracs = []
	x = []
	for row_idx in range(nsamp):
		euc_dist = np.sqrt(np.sum(np.square(x1_mat[row_idx, :] - x2_mat), axis=1))
		# The count of strictly smaller distances equals the position of the
		# true match in the sorted distance list (first position on ties).
		rank = int((euc_dist < euc_dist[row_idx]).sum())
		fracs.append(rank / denom)
		x.append(row_idx + 1)

	return fracs, x

def calc_domainAveraged_FOSCTTM(x1_mat, x2_mat, true_dict):
	"""
	Outputs average FOSCTTM measure (averaged over both domains)

	calc_frac_idx is evaluated once from x1_mat to x2_mat and once from
	x2_mat to x1_mat (with the keys and values of true_dict swapped). The
	returned list holds, per sample, the mean of the two fractions.
	"""
	forward, _ = calc_frac_idx(x1_mat, x2_mat, true_dict)
	# Reverse direction uses the inverted correspondence dictionary.
	inverted = dict((value, key) for key, value in true_dict.items())
	backward, _ = calc_frac_idx(x2_mat, x1_mat, inverted)
	return [(fwd + backward[pos]) / 2 for pos, fwd in enumerate(forward)]

def calc_sil(x1_mat,x2_mat,x1_lab,x2_lab):
	"""
	Returns silhouette score for datasets with cell clusters

	The two domains (and their labels) are stacked, silhouette_samples is
	evaluated on the stacked data, and the per-sample scores are averaged
	overall and separately for the labels 1, 2, 3, 4 and 5.

	Returns:
		(avg, d0, d3, d7, d11, npc): overall mean followed by the means for
		labels 1 to 5, in that order.

	Raises:
		ZeroDivisionError: if one of the labels 1 to 5 has no sample.
	"""
	stacked = np.concatenate((x1_mat, x2_mat))
	stacked_lab = np.concatenate((x1_lab, x2_lab))
	per_sample = silhouette_samples(stacked, stacked_lab)

	# (label value, scores collected for it), in output order d0, d3, d7, d11, npc
	groups = [(1, []), (2, []), (3, []), (4, []), (5, [])]
	for pos in range(stacked.shape[0]):
		for value, bucket in groups:
			if stacked_lab[pos] == value:
				bucket.append(per_sample[pos])
				break

	overall = np.mean(per_sample)
	group_means = [sum(bucket) / len(bucket) for _, bucket in groups]
	return (overall, *group_means)

def binarize_labels(label,x):
	"""
	Helper function for calc_auc

	Builds a binary vector from the labels in x: entries equal to `label`
	become 0, every other entry becomes 1.

	Parameters:
		label: the label value to compare against.
		x: array-like of labels (numpy array, list or tuple).

	Returns:
		Integer numpy array with the same shape as x.
	"""
	# np.asarray keeps the comparison element-wise for lists and tuples too;
	# comparing a plain list with a scalar would produce one single bool.
	labels = np.asarray(x)
	return np.where(labels == label, 0, 1)
	
def calc_auc(x1_mat, x2_mat, x1_lab, x2_lab):
	"""
	calculate avg. ROC AUC scores for transformed data when there are >=2 number of clusters.

	For each row of x1_mat, the Euclidean distances to all rows of x2_mat are
	used as scores and the output of binarize_labels (for that row's label
	against x2_lab) as targets of roc_auc_score. The scores are averaged
	overall and separately for the x1 labels 0, 1, 2, 3 and 4.

	Returns:
		(avg, d0, d3, d7, d11, npc): overall mean followed by the means for
		labels 0 to 4, in that order.

	Raises:
		ZeroDivisionError: if x1_mat has no rows or one of the labels 0 to 4
		has no sample in x1_lab.
	"""
	all_scores = []
	# (label value, scores collected for it), in output order d0, d3, d7, d11, npc
	groups = [(0, []), (1, []), (2, []), (3, []), (4, [])]

	for row_idx in range(x1_mat.shape[0]):
		diff = np.subtract(x1_mat[row_idx, :], x2_mat)
		distances = np.sqrt(np.sum(np.square(diff), axis=1))
		row_label = x1_lab[row_idx]
		targets = binarize_labels(row_label, x2_lab)

		score = roc_auc_score(targets, distances)
		all_scores.append(score)

		for value, bucket in groups:
			if row_label == value:
				bucket.append(score)
				break

	avg = sum(all_scores) / len(all_scores)
	d0, d3, d7, d11, npc = [sum(bucket) / len(bucket) for _, bucket in groups]
	return avg, d0, d3, d7, d11, npc

def transfer_accuracy(domain1, domain2, type1, type2, n=5):
	"""
	Metric from UnionCom: "Label Transfer Accuracy"

	A KNeighborsClassifier with n neighbours is fitted on domain2 with the
	labels type2 and used to predict labels for domain1. The return value is
	the number of predictions equal to the paired entry of type1, divided by
	len(type1).
	"""
	classifier = KNeighborsClassifier(n_neighbors=n)
	classifier.fit(domain2, type2)
	predicted = classifier.predict(domain1)
	hits = sum(1 for guess, truth in zip(predicted, type1) if guess == truth)
	return hits / len(type1)

# X1 = np.load("data/SNARE/SNAREseq_atac_feat.npy") 
# X2 = np.load("data/SNARE/SNAREseq_rna_feat.npy")

# cellTypes_atac=np.loadtxt("data/SNARE/SNAREseq_atac_types.txt")
# cellTypes_rna=np.loadtxt("data/SNARE/SNAREseq_rna_types.txt")
# # np.random.seed(0)
# # X_2_permutation_indices = np.random.permutation(X2.shape[0])
# # # X_2_permutation_matrix = np.eye(X2.shape[0])[X_2_permutation_indices].T
# # X_2_permutation_matrix = np.eye(X2.shape[0])
# # X2_permuted = (X_2_permutation_matrix @ X2)
# # cellTypes_methyl = X_2_permutation_matrix @ cellTypes_methyl

# # true_dict = {X_1_permutation_indices[i]: X_2_permutation_indices[i] for i in range(X_1_permutation_indices.shape[0])}
# n_scales = 20
# wot = WaveletOT(X1, X2, n_scales=n_scales)
# # raise
# wot.solve()

# aligned_point_X1 = wot.project(to_X2=False)
# aligned_point_X2 = wot.project()

# lta_x1_x2 = transfer_accuracy(aligned_point_X2, X2, cellTypes_rna, cellTypes_atac)
# lta_x2_x1 = transfer_accuracy(aligned_point_X1, X1, cellTypes_atac, cellTypes_rna)
# print(np.mean([lta_x1_x2, lta_x2_x1]))
# FOSCTTM_X1 = np.mean(calc_domainAveraged_FOSCTTM(X1, aligned_point_X1, {}))
# FOSCTTM_X2 = np.mean(calc_domainAveraged_FOSCTTM(X2, aligned_point_X2, {}))
# print(np.mean([FOSCTTM_X1, FOSCTTM_X2]))

# # Reduce the dimensionality of the aligned domains to two (2D) via PCA for the sake of visualization:
# pca=PCA(n_components=2)
# Xy_pca=pca.fit_transform(np.concatenate((aligned_point_X2, X2), axis=0))
# X_pca=Xy_pca[0: 1047,]
# y_pca=Xy_pca[1047:,]

# # #Plot aligned domains, samples colored by domain identity:
# plt.scatter(X_pca[:,0], X_pca[:,1], c="k", s=15, label="Chromatin Accessibility", alpha=0.4)
# plt.scatter(y_pca[:,0], y_pca[:,1], c="r", s=15, label="Gene Expression", alpha=0.4)
# plt.legend()
# plt.title("Colored based on domains")
# plt.show()
# plt.savefig("pca_SNARE.png")


# cellTypes_atac=np.loadtxt("data/SNARE/SNAREseq_atac_types.txt")
# cellTypes_rna=np.loadtxt("data/SNARE/SNAREseq_rna_types.txt")

# colormap = plt.get_cmap('rainbow', 4) 
# plt.scatter(X_pca[:,0], X_pca[:,1], c=cellTypes_atac, s=15, cmap=colormap)
# plt.scatter(y_pca[:,0], y_pca[:,1], c=cellTypes_rna, s=15, cmap=colormap)
# # plt.colorbar()
# cbar=plt.colorbar()

# # approximately center the colors on the colorbar when adding cell type labels
# tick_locs = (np.arange(1,5)+0.75) *3/4 
# cbar.set_ticks(tick_locs)
# cbar.set_ticklabels(["H1", "GM", "BJ", "K562"]) #cell-type labels
# plt.title("Colored based on cell type identity")
# plt.show()

# ---- scGEM alignment experiment ----
# Load the two scGEM feature matrices and their cell-type label vectors.
X1 = np.genfromtxt("data/scGEM/scGEM_expression.csv", delimiter=",")
X2 =  np.genfromtxt("data/scGEM/scGEM_methylation.csv", delimiter=",")

cellTypes_rna=np.loadtxt("data/scGEM/scGEM_typeExpression.txt")
cellTypes_methyl=np.loadtxt("data/scGEM/scGEM_typeMethylation.txt")

# Random row orders are drawn for both domains, but identity matrices are
# what is actually applied below, so the row order of X1 and X2 is unchanged.
X_1_permutation_indices = np.random.permutation(X1.shape[0])
X_1_permutation_matrix = np.eye(X1.shape[0])
X1 = X_1_permutation_matrix @ X1

X_2_permutation_indices = np.random.permutation(X2.shape[0])
X_2_permutation_matrix = np.eye(X2.shape[0])

n_scales = 20
# NOTE(review): WaveletOT is neither imported nor defined in the visible part
# of this file (only EWOT and LWOT are imported) - confirm where it comes from.
wot = WaveletOT(X1, X_2_permutation_matrix @ X2, n_scales=n_scales, coupling=X_2_permutation_matrix)
wot.solve()

aligned_point_X1 = wot.project(to_X2=False)
aligned_point_X2 = wot.project()

# Label transfer accuracy of the projection returned by wot.project().
lta_x1_x2 = transfer_accuracy(aligned_point_X2, X2, cellTypes_methyl, cellTypes_rna)
print(lta_x1_x2)

# Joint 2-D PCA of the projected points stacked on top of X1.
pca=PCA(n_components=2)
Xy_pca=pca.fit_transform(np.concatenate((aligned_point_X1, X1), axis=0))
# Split at the actual number of projected rows instead of a hard-coded count,
# so the split stays correct if the dataset size changes.
n_aligned = len(aligned_point_X1)
X_pca=Xy_pca[:n_aligned]
y_pca=Xy_pca[n_aligned:]

colormap = plt.get_cmap('rainbow', 5)
plt.scatter(X_pca[:,0], X_pca[:,1], c=cellTypes_rna, s=15, cmap=colormap, alpha=0.8)
plt.scatter(y_pca[:,0], y_pca[:,1], c=cellTypes_methyl, s=15, cmap=colormap, alpha=0.8)
cbar=plt.colorbar()

# approximately center the colors on the colorbar when adding cell type labels
tick_locs = (np.arange(0,5)+1.3) *4/5
cbar.set_ticks(tick_locs)
cbar.set_ticklabels(["BJ", "d8", "d16T+", "d24T+", "iPS"]) #cell-type labels
plt.title("Colored based on cell type identity")
# Save before show(): once the window opened by show() is closed the figure
# may be released, and a later savefig would write an empty image.
plt.savefig("cell_iden_scGEM.png")
plt.show()