import numpy as np
from os.path import expanduser
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from os.path import expanduser
from os import listdir
import os
import MDAnalysis as mda
import matplotlib.pyplot as plt
from MDAnalysis.topology.guessers import guess_masses
import multipers as mp
# from numba import njit
from tqdm import tqdm



DATASET_PATH = expanduser("~/Datasets/")
JC_path = DATASET_PATH + "Cleves-Jain/"
DUDE_path = DATASET_PATH + "DUD-E/"


#pathes = get_data_path()
#imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
#distances_to_letter, ytest = img_distances(imgs)


def _get_mols_in_path(folder):
	with open(folder+"/TargetList", "r") as f:
		train_data =  [folder + "/" + mol.strip() for mol in f.readlines()]
	criterion = lambda dataset : dataset.endswith(".mol2") and not dataset.startswith("final") and dataset not in train_data
	test_data = [folder + "/" + dataset for dataset in listdir(folder) if criterion(folder + "/" + dataset)]
	return train_data, test_data
def get_data_path_JC(type="dict"):
	"""
	Collects the molecule paths of the Cleves-Jain dataset located in `JC_path`.

	 - type : "dict" or "list", the kind of container to return.

	Each `target_*` folder contributes the `(train, test)` pair returned by `_get_mols_in_path`,
	and the `.mol2` files of the `RognanRing850` folder are added last, as a list of decoys.
	 - "dict" : targets are keyed by the last character of their folder name, decoys by "decoy".
	 - "list" : entries are appended in `listdir` order, decoys last.

	Raises TypeError if `type` is neither "dict" nor "list".
	"""
	if type not in ("dict", "list"):
		# The message reports `type`: the output container does not exist at this point.
		raise TypeError(f"Type {type} not supported")
	as_dict = type == "dict"
	out = {} if as_dict else []
	for stuff in listdir(JC_path):
		if not stuff.startswith("target_"):
			continue
		to_add = _get_mols_in_path(JC_path + stuff)
		if as_dict:
			out[stuff[-1]] = to_add
		else:
			out.append(to_add)
	decoy_folder = JC_path + "RognanRing850/"
	decoys = [decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")]
	if as_dict:
		out["decoy"] = decoys
	else:
		out.append(decoys)
	return out
def get_all_JC_path():
	"""
	Flat list of all the molecule paths of the Cleves-Jain dataset located in `JC_path`.

	For each `target_*` folder, in `listdir` order, the train paths then the test paths
	returned by `_get_mols_in_path` are added; the `.mol2` files of the `RognanRing850`
	folder come last.
	"""
	all_paths = []
	target_folders = [name for name in listdir(JC_path) if name.startswith("target_")]
	for folder_name in target_folders:
		train_paths, test_paths = _get_mols_in_path(JC_path + folder_name)
		all_paths.extend(train_paths)
		all_paths.extend(test_paths)
	decoy_dir = JC_path + "RognanRing850/"
	for file_name in listdir(decoy_dir):
		if file_name.endswith(".mol2"):
			all_paths.append(decoy_dir + file_name)
	return all_paths
		

def split_multimol(path:str, mol_name:str, out_folder_name:str = "splitted", enforce_charges:bool=False):
	"""
	Splits a multi-molecule mol2 file into one file per record.

	 - path : folder of the input file. It is concatenated as is, so it has to end with a path separator.
	 - mol_name : name of the multi-molecule file, inside `path`.
	 - out_folder_name : sub-folder of `path` receiving the files `0.mol2`, `1.mol2`, ...
	   It is created, along with its missing parents, if it does not exist.
	 - enforce_charges : if True, every "NO_CHARGES" is replaced by "USER_CHARGES" in the written records.

	A new record starts at each line equal, up to surrounding whitespace, to "@<TRIPOS>MOLECULE".
	Lines located before the first such line, if any, form a record of their own, and a file
	of at most one line yields no record.

	Returns the list of the written paths, in the order of the input file.
	"""
	marker = "@<TRIPOS>MOLECULE"
	with open(path + mol_name, "r") as f:
		lines = f.readlines()

	molecules = []
	if len(lines) > 1:
		# Cut before every marker line, except on the first line (nothing precedes it)
		# and on the last line (which stays with the record it ends).
		cuts = [i for i, line in enumerate(lines[:-1]) if i > 0 and line.strip() == marker]
		bounds = [0] + cuts + [len(lines)]
		molecules = ["".join(lines[start:stop]) for start, stop in zip(bounds, bounds[1:])]
	if enforce_charges:
		molecules = [molecule.replace("NO_CHARGES", "USER_CHARGES") for molecule in molecules]

	out_folder = path + out_folder_name
	# exist_ok avoids the check-then-create race of `exists` + `mkdir`.
	os.makedirs(out_folder, exist_ok=True)
	out_paths = [out_folder + f"/{i}.mol2" for i in range(len(molecules))]
	for out_path, molecule in zip(out_paths, molecules):
		with open(out_path, "w") as f:
			f.write(molecule)
	return out_paths

# @njit(parallel=True)
def apply_pipeline(pathes:dict, pipeline):
	"""
	Applies `pipeline.transform` to every group of paths of `pathes`.

	 - pathes : dict whose single-character keys map to a (train paths, test paths) pair.
	   The only other accepted key is "decoy", whose value is transformed as a whole.
	 - pipeline : object exposing a `transform` method.

	Returns a dict with the same keys, holding the (train, test) pair of transformed data
	for single-character keys, and the transformed value for "decoy".
	"""
	transformed = {}
	progress = tqdm(pathes.items(), desc="Applying pipeline")
	for name, paths in progress:
		if len(name) != 1:
			# Any key that is not a single character has to be the decoys.
			assert name == "decoy"
			transformed[name] = pipeline.transform(paths)
			continue
		train_paths, test_paths = paths
		train_imgs = pipeline.transform(train_paths)
		test_imgs = pipeline.transform(test_paths)
		transformed[name] = (train_imgs, test_imgs)
	return transformed

from sklearn.metrics import pairwise_distances
def img_distances(img_dict:dict):
	"""
	Distances from the train data of each single-character key to its test data and the decoys.

	 - img_dict : dict whose single-character keys map to a (train, test) pair, and whose
	   "decoy" key maps to the decoys. Other keys are ignored.

	For each single-character key, the test data and the decoys are concatenated, in this order,
	and `pairwise_distances` (default metric) is computed from the train data to this concatenation.

	Returns two lists, with one entry per single-character key:
	 - the distance matrices,
	 - the label arrays of their second axis: the key for the test data, '0' for the decoys.
	"""
	decoys = img_dict["decoy"]
	all_distances, all_labels = [], []
	for letter, imgs in img_dict.items():
		if len(letter) != 1:
			continue # decoy
		anchors, actives = imgs
		assert len(actives)>0
		candidates = np.concatenate([actives, decoys])
		all_distances.append(pairwise_distances(anchors, candidates))
		num_actives, num_decoys = len(actives), len(decoys)
		labels = np.empty(num_actives + num_decoys, dtype="<U1")
		labels[:num_actives] = letter
		labels[num_actives:] = '0'
		all_labels.append(labels)
	return all_distances, all_labels
	
def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
	"""
	Mean Enrichment Factor over several (distance matrix, labels) pairs.

	 - distances : iterable of 2d distance matrices; rows are anchors, columns are test points.
	 - ytest : iterable of label arrays, one per matrix, labelling its columns.
	   The first label of each array is assumed to be a positive one.
	 - alpha : proportion of the nearest test points to keep.

	For each anchor, the EF is the proportion of positives found among its `int(alpha * n)`
	nearest test points, divided by alpha. It is averaged over the anchors, then over the pairs.
	"""
	per_target = []
	for dist, labels in zip(distances, ytest):
		ranking = np.argsort(dist, axis=1)
		top_k = int(alpha * ranking.shape[1])
		positive = labels[0]
		hits = (labels[ranking[:, :top_k]] == positive).sum(axis=1)
		num_positives = (labels == positive).sum()
		enrichment = hits / num_positives
		enrichment /= alpha
		per_target.append(enrichment.mean())
	return np.mean(per_target)

def EF_from_distance_matrix(distances:np.ndarray, labels:list|np.ndarray, alpha:float, anchors_in_test=True):
	"""
	Computes the Enrichment Factor from a distance matrix, and its labels.
	 - First axis of the distance matrix is the anchors on which to compute the EF
	 - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
	 - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
	 - alpha : the EF alpha parameter.
	"""
	n = len(labels)
	n_max = int(alpha*n)
	indices = np.argsort(distances, axis=1)
	EF_ = [((labels[idx[:n_max]]).sum()-anchors_in_test)/(labels.sum()-anchors_in_test) for idx in indices]
	return np.mean(EF_)/alpha

def EF_AUC(distances:np.ndarray, labels:np.ndarray, anchors_in_test=0):
	"""
	Averages, over the cutoffs i = 1, ..., n-1, the proportion of positives found among the i nearest test points.

	 - distances : distances from the anchors (first axis) to the n test points (second axis).
	   A 1d array is treated as a single anchor.
	 - labels : labels of the test axis; their sum is the number of positives.
	 - anchors_in_test : number of positives to discard from both the hits and the positives,
	   e.g. 1 if each anchor is also in the test.

	For each cutoff i, the number of hits is averaged over the anchors, and divided by the
	best achievable number of hits, `min(i, number of positives - anchors_in_test)`.
	"""
	labels = np.asarray(labels)
	if distances.ndim == 1:
		distances = distances[None,:]
	assert distances.ndim == 2
	# The cutoffs range over the test axis only: `distances.size` also counts the anchors.
	num_tests = distances.shape[1]
	indices = np.argsort(distances, axis=1)
	# hits[a, i-1] : number of positives among the i nearest test points of anchor a
	hits = np.cumsum(labels[indices], axis=1)
	mean_hits = hits.mean(axis=0)[:max(num_tests - 1, 0)]
	cutoffs = np.arange(1, num_tests)
	best_hits = np.minimum(cutoffs, labels.sum() - anchors_in_test)
	return np.mean((mean_hits - anchors_in_test) / best_hits)


def theorical_max_EF(distances,labels, alpha):
	"""
	Highest Enrichment Factor reachable for the given labels, i.e. the EF of a perfect ranking.

	 - distances : unused, kept for signature compatibility.
	 - labels : labels of the test points. If labels are not True / False, the first one is assumed to be a positive one.
	 - alpha : the EF alpha parameter.

	At most `min(int(alpha * n), number of positives)` positives fit in the selection. As in
	`get_EF_vector_from_distances`, this count is divided by the number of positives, then by alpha.
	"""
	# With a plain list, `labels == labels[0]` would be a single bool instead of a mask.
	labels = np.asarray(labels)
	n_max = int(alpha * len(labels))
	num_true_labels = np.sum(labels == labels[0])
	return min(n_max, num_true_labels) / num_true_labels / alpha


def theorical_max_EF_from_distances(list_of_distances,list_of_labels, alpha):
	"""
	Mean of `theorical_max_EF` over the paired entries of `list_of_distances` and `list_of_labels`.
	"""
	per_target = []
	for distances, labels in zip(list_of_distances, list_of_labels):
		per_target.append(theorical_max_EF(distances, labels, alpha))
	return np.mean(per_target)

def plot_EF_from_distances(alphas = [0.01, 0.02, 0.05, 0.1], EF = EF_from_distance_matrix, plot:bool=True):
	"""
	Evaluates an Enrichment Factor function at the given alphas, and optionally plots it.

	 - alphas : values of alpha at which the EF is reported.
	 - EF : callable, called with the single keyword argument `alpha`. The default
	   `EF_from_distance_matrix` also requires `distances` and `labels`, so a callable
	   with these already bound has to be given.
	 - plot : if True, draws on a new figure the EF for 100 values of alpha in [0.01, 1],
	   with the reported values as red dots.

	Returns the EF at `alphas`, rounded to 2 decimals.
	"""
	y = np.round([EF(alpha=a) for a in alphas], decimals=2)
	if not plot:
		return y
	grid = np.linspace(0.01, 1., 100)
	plt.figure()
	curve = [EF(alpha=a) for a in grid]
	plt.plot(grid, curve)
	plt.scatter(alphas, y, c='r')
	plt.title("Enrichment Factor")
	plt.xlabel(r"$\alpha$" + f" = {alphas}")
	plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
	return y

