import statistics

import allel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pandas import *
from sklearn import tree
from sklearn import tree
#import graphviz
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split

def transform(li):
	"""Return the sum of every entry in ``li``.

	Applied to a single genotype call (its allele indices), this collapses
	the call to one number.
	"""
	summed = np.sum(a=li)
	return summed

# Element-wise wrapper around transform(). np.vectorize calls transform on each
# scalar element of its input, so for numeric arrays vf returns an array of the
# same shape with each element unchanged; it does not sum along any axis.
# NOTE(review): no caller of vf is visible in this file — confirm it is needed.
vf = np.vectorize(pyfunc=transform)

def _load_genotype_features(file):
	"""Read the VCF at ``file`` and return ``(features, patient_ids)``.

	``features`` has one row per VCF sample and one column per variant; each
	entry is the sum of the allele indices in that sample's genotype call.
	``patient_ids`` holds, in the same sample order, the integer parsed from
	the last four characters of each VCF sample name.

	Raises:
		ValueError: if the VCF contains no variant records.
	"""
	callset = allel.read_vcf(file)
	if callset is None:
		# read_vcf returns None instead of a dict when no variants are loaded.
		raise ValueError("No variants found in {!r}".format(file))
	genotypes = np.asarray(allel.GenotypeArray(callset['calldata/GT']))
	patient_ids = [int(sample[-4:]) for sample in callset['samples']]
	# Sum over the ploidy axis in one vectorised step, then transpose so rows
	# are samples. NOTE(review): scikit-allel encodes missing alleles as -1,
	# so missing or partially missing calls become negative values here —
	# confirm this encoding is intended as a feature.
	features = genotypes.sum(axis=2).T
	return features, patient_ids


def _align_labels(features, patient_ids, labels):
	"""Keep only samples listed in the labels CSV and return ``(x, y)``.

	``labels`` is a CSV path with ``Patient`` and ``Subtype`` columns. Samples
	whose patient id is absent from the CSV are dropped. If a patient id appears
	more than once, the subtype from its first row is used.

	Raises:
		ValueError: if no sample matches a patient in the CSV.
	"""
	label_table = pd.read_csv(labels)
	subtype_by_patient = {}
	for patient, subtype in zip(label_table["Patient"].tolist(), label_table["Subtype"].tolist()):
		# setdefault keeps the first row for a duplicated patient id.
		subtype_by_patient.setdefault(patient, subtype)

	keep = [row for row, pid in enumerate(patient_ids) if pid in subtype_by_patient]
	if not keep:
		raise ValueError("No VCF sample matched a patient in {!r}".format(labels))
	y = np.asarray([subtype_by_patient[patient_ids[row]] for row in keep])
	return features[keep, :], y


def getData(file, labels, clf=None, cv=5):
	"""Train and evaluate a subtype classifier on VCF genotype features.

	Loads genotypes from ``file`` and matches samples to the ``Subtype`` column
	of the ``labels`` CSV. It then makes an 80/20 train/test split and fits
	``clf`` on the training part. It prints the split shapes, train and test
	accuracy, the confusion matrix (also saved to ``Heatmap.png``) and the
	weighted precision. Finally it prints ``cv``-fold cross-validated accuracy:
	the fold scores, their mean and their sample standard deviation.

	Args:
		file: path to the VCF file.
		labels: path to a CSV with ``Patient`` and ``Subtype`` columns.
		clf: unfitted scikit-learn classifier; defaults to
			``tree.DecisionTreeClassifier(random_state=42)``.
		cv: number of cross-validation folds (must be at least 2).

	Raises:
		ValueError: if the VCF has no variants or no sample has a label.
	"""
	features, patient_ids = _load_genotype_features(file)
	x, y = _align_labels(features, patient_ids, labels)

	# Split for train-validation (80-20)
	x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20, random_state=42)
	print(x_train.shape)
	print(x_test.shape)
	print(y_train.shape)
	print(y_test.shape)

	# Fit on the training split only so the test split stays unseen.
	if clf is None:
		clf = tree.DecisionTreeClassifier(random_state=42)
	clf.fit(x_train, y_train)

	# Make predictions on train and test data
	y_pred = clf.predict(x_test)
	y_pred_train = clf.predict(x_train)

	print("Train Accuracy:")
	print(accuracy_score(y_train, y_pred_train))

	print("Test Accuracy:")
	print(accuracy_score(y_test, y_pred))

	# Fix the label order explicitly so matrix rows/columns match tick labels.
	print("Confusion Matrix")
	class_labels = np.unique(np.concatenate([y_test, y_pred]))
	cfm = confusion_matrix(y_test, y_pred, labels=class_labels)
	print(cfm)
	tick_labels = list(class_labels)
	# fmt="d" shows integer counts instead of the default ".2g" formatting.
	sns.heatmap(cfm, annot=True, fmt="d", xticklabels=tick_labels, yticklabels=tick_labels)
	plt.savefig("Heatmap.png")
	plt.clf()

	print("Precision Score")
	print(precision_score(y_test, y_pred, average='weighted'))

	# cross_val_score clones clf, so each fold is fit independently of the
	# model trained above.
	scores = cross_val_score(clf, x, y, cv=cv, scoring='accuracy')
	print("Cross-Val Scores:")
	print(scores)
	print("\n")
	print("Mean:")
	print(scores.mean())
	print("\n")
	print("Std Dev:")
	print(statistics.stdev(scores))
	print("\n")


def main():
	"""Run getData on the default VCF and patient-label files in the working directory."""
	vcf_file = "codingsnps.vcf"
	label_file = "patientLabels.csv"
	getData(vcf_file, label_file)


if __name__ == "__main__":
	main()