import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import allel
import numpy as np
from sklearn.decomposition import PCA
from sklearn import tree
from sklearn import ensemble
from sklearn import tree
from sklearn.model_selection import cross_validate
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score

'''
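# Commands used to generate the VCF files from the .txt SNP lists, given the ADNI WGS_Omni2.5M_20140220 data and the PLINK program.
# Each command writes <name>.vcf to the current directory; main() reads the files from SNP_selections_vcf/, so move them there: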
./plink --bfile WGS_Omni25_BIN_wo_ConsentsIssues --extract SNP_selections/wholeattempt.txt --recode vcf-iid --out ./wholeattempt
./plink --bfile WGS_Omni25_BIN_wo_ConsentsIssues --extract SNP_selections/nballsnps.txt --recode vcf-iid --out ./nballsnps
./plink --bfile WGS_Omni25_BIN_wo_ConsentsIssues --extract SNP_selections/codingsnps.txt --recode vcf-iid --out ./codingsnps
./plink --bfile WGS_Omni25_BIN_wo_ConsentsIssues --extract SNP_selections/chr19nb.txt --recode vcf-iid --out ./chr19nb
./plink --bfile WGS_Omni25_BIN_wo_ConsentsIssues --extract SNP_selections/chr19coding.txt --recode vcf-iid --out ./chr19coding
'''

def _match_labelled_samples(snp_patients, patients, subtypes):
	"""Pair VCF samples with their subtype label from the labels table.

	Parameters
	----------
	snp_patients : list of patient ids, in VCF sample order.
	patients : list of patient ids read from the labels CSV.
	subtypes : list of subtype labels, aligned element-wise with ``patients``.

	Returns
	-------
	(indices, labels) : the positions in ``snp_patients`` that have a label,
	and the matching subtype for each one. Samples without a label are
	dropped. If a patient appears more than once in ``patients``, its first
	row wins (the same result ``list.index`` gave).
	"""
	# Build the lookup once. A list membership test plus list.index() costs
	# O(len(patients)) per sample, which made the old loop quadratic.
	subtype_of = {}
	for patient, subtype in zip(patients, subtypes):
		subtype_of.setdefault(patient, subtype)

	indices = []
	labels = []
	for i, patient in enumerate(snp_patients):
		if patient in subtype_of:
			indices.append(i)
			labels.append(subtype_of[patient])
	return indices, labels


def _stratified_holdout_eval(features, labels, heatmap_path):
	"""Fit a random forest on a stratified 80/20 split and report the results.

	Prints the split shapes, the per-class counts of each split, the train
	and test accuracy and the test confusion matrix. Then saves the confusion
	matrix as an annotated heatmap to ``heatmap_path``.

	Parameters
	----------
	features : 2-D array with one row per sample.
	labels : 1-D array of class labels aligned with the rows of ``features``.
	heatmap_path : output image path for the confusion-matrix heatmap.
	"""
	# Split for train-validation (80-20), preserving class proportions.
	x_train, x_test, y_train, y_test = train_test_split(
		features, labels, test_size=0.20, random_state=5, stratify=labels)
	for title, arr in (("X train shape", x_train), ("X test shape", x_test),
			("Y train shape", y_train), ("Y test shape", y_test)):
		print(title)
		print(arr.shape)

	# Check stratification.
	print("Train Dataset Distribution")
	print(np.unique(y_train, return_counts=True))
	print("Test Dataset Distribution")
	print(np.unique(y_test, return_counts=True))

	# Fit on the training split (100 is also sklearn's default n_estimators).
	clf = ensemble.RandomForestClassifier(random_state=5555, n_estimators=100)
	clf.fit(x_train, y_train)

	# Training accuracy: how well the forest fits the data it was trained on.
	print("Train Accuracy:")
	print(accuracy_score(y_train, clf.predict(x_train)))

	# Held-out test accuracy.
	y_pred = clf.predict(x_test)
	print("Test Accuracy:")
	print(accuracy_score(y_test, y_pred))

	# Fix the class order explicitly. The matrix then always has one row and
	# column per class, and the heatmap can show the real label values
	# instead of 0..k-1 matrix positions.
	classes = np.unique(labels)
	print("Confusion Matrix")
	cfm = confusion_matrix(y_test, y_pred, labels=classes)
	print(cfm)

	tick_labels = classes.tolist()
	# fmt="d" prints whole counts (the default ".2g" shows 100 as "1e+02").
	ax = sns.heatmap(cfm, annot=True, fmt="d",
			xticklabels=tick_labels, yticklabels=tick_labels)
	# confusion_matrix puts true labels on rows and predictions on columns.
	ax.set_xlabel("Predicted label")
	ax.set_ylabel("True label")
	plt.savefig(heatmap_path)
	# Clear the figure so the next heatmap is not drawn on top of this one.
	plt.clf()


def getData(file, labels, run_name="random_forest"):
	"""Evaluate random-forest subtype classifiers on SNP genotypes from a VCF.

	Reads genotypes from ``file`` and encodes each call as the sum of its
	allele indices. Keeps only the samples that have a row in the ``labels``
	CSV (columns ``Patient`` and ``Subtype``), then:

	1. prints train/test accuracy of 5-fold cross-validation on all samples;
	2. fits on a stratified 80/20 split of all samples, prints the metrics and
	   saves the test confusion-matrix heatmap to ``<run_name>_heatmap.png``;
	3. repeats step 2 using only samples with ``Subtype > 0`` and saves
	   ``<run_name>_heatmap_no_control.png``.

	Parameters
	----------
	file : path of the VCF file, read with ``allel.read_vcf``.
	labels : path of the CSV file with ``Patient`` and ``Subtype`` columns.
	run_name : prefix for the saved heatmap image files.

	Raises
	------
	ValueError
		If no VCF sample matches a ``Patient`` in the labels CSV.
	"""
	callset = allel.read_vcf(file)
	data = allel.GenotypeArray(callset['calldata/GT'])
	# Assumes every VCF sample id ends in the 4-digit patient number used in
	# the labels CSV. TODO confirm against the sample naming scheme.
	snp_patients = [int(sample[-4:]) for sample in callset['samples']]

	print("Input data (from snp vcf) shape:")
	print(data.shape)

	# Collapse the ploidy axis: each call becomes the sum of its allele indices.
	# NOTE(review): scikit-allel marks missing alleles as -1, so a fully
	# missing diploid call becomes -2 here. Consider data.to_n_alt(fill=-1)
	# if that encoding is unintended.
	data_transformed = np.sum(data, axis=2)

	print("Transformed data shape:")
	print(data_transformed.shape)

	patientLabels = pd.read_csv(labels)
	patientsIdx, subtype_list = _match_labelled_samples(
		snp_patients,
		patientLabels["Patient"].values.tolist(),
		patientLabels["Subtype"].values.tolist(),
	)
	if not patientsIdx:
		raise ValueError(
			"No samples in {} match a Patient in {}".format(file, labels))

	# Transpose to one row per sample (as scikit-learn expects) and keep only
	# the labelled samples, in VCF order.
	data_transformed = data_transformed.T[patientsIdx, :]
	subtypeLabel = np.asarray(subtype_list)

	print("\n")
	print("Starting cross-validation.")

	# 5-fold cross-validation; an integer cv on a classifier uses stratified folds.
	clf = ensemble.RandomForestClassifier(random_state=5555, n_estimators=100)
	scores = cross_validate(clf, data_transformed, subtypeLabel, cv=5,
			scoring="accuracy", return_train_score=True)
	print("Train scores (accuracy):")
	print(scores["train_score"])
	print("Test scores (accuracy):")
	print(scores["test_score"])

	print("Done.")

	print("\n")
	print("Starting stratified stuff.")

	_stratified_holdout_eval(data_transformed, subtypeLabel,
			run_name + "_heatmap.png")

	print("")
	print("Done.")

	print("\n")
	print("Starting non-control-only (alzheimers only) stratified stuff.")

	print("Original X shape:")
	print(data_transformed.shape)
	print("Original y shape:")
	print(subtypeLabel.shape)

	# Keep samples with Subtype > 0. Presumably 0 marks controls (per the
	# section title). Confirm against the labels CSV.
	alz_idxs = subtypeLabel > 0
	subtypeLabel_alz_only = subtypeLabel[alz_idxs]
	data_transformed_alz_only = data_transformed[alz_idxs, :]

	print("alz-only X shape:")
	print(data_transformed_alz_only.shape)
	print("alz-only y shape:")
	print(subtypeLabel_alz_only.shape)

	_stratified_holdout_eval(data_transformed_alz_only, subtypeLabel_alz_only,
			run_name + "_heatmap_no_control.png")

	print("")
	print("Done.")


def main():
	"""Run getData once per SNP selection VCF, framing each run with banner lines.

	Each VCF is read from SNP_selections_vcf/<name>.vcf and labelled with
	patientLabels.csv. Output images are prefixed with <name>_random_forest.
	"""
	labels_csv = "patientLabels.csv"
	rule = "--------------------"
	print("")
	for name in ("wholeattempt", "codingsnps"):
		vcf_path = "SNP_selections_vcf/" + name + ".vcf"
		opening = (rule, "", "STARTING RANDOM FOREST RUN FOR SNPS FROM: " + vcf_path, "")
		for line in opening:
			print(line)
		getData(vcf_path, labels_csv, run_name=name + "_random_forest")
		closing = ("", "DONE WITH RANDOM FOREST RUN FOR SNPS FROM: " + vcf_path, "", rule, "")
		for line in closing:
			print(line)


# Run the evaluations only when this file is executed as a script, not on import.
if __name__ == "__main__":
	main()
