

# dtrain = xgb.DMatrix('train.csv?format=csv&label_column=0')
# dtest = xgb.DMatrix('test.csv?format=csv&label_column=0')

# param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
# param['nthread'] = 4
# param['eval_metric'] = 'auc'


import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import allel
import numpy as np
from sklearn.decomposition import PCA
from sklearn import tree
from pandas import *
from sklearn import tree
import seaborn as sns
#import graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.multioutput import MultiOutputClassifier

import xgboost as xgb
from xgboost import plot_tree

def transform(li):
    """Return the sum of every element of ``li``.

    ``li`` may be any array_like (list, tuple, ndarray, ...); the reduction is
    delegated to ``np.sum`` over all axes, so the result is a NumPy scalar.
    """
    total = np.sum(li)
    return total


# ``transform`` lifted to operate element-wise over array inputs.
# NOTE(review): no code visible in this file references ``vf``; confirm it is
# unused elsewhere before removing it.
vf = np.vectorize(pyfunc=transform)

def _load_genotypes(file):
    """Read a VCF and return ``(samples x variants matrix, numeric sample IDs)``.

    Each genotype call is collapsed to the sum of its allele values, i.e. the
    value ``np.sum`` gives for one call. This is done in a single vectorized
    reduction over the ploidy axis instead of a Python loop per call. The result
    keeps the dtype of the raw GT array.
    """
    callset = allel.read_vcf(file)
    genotypes = np.asarray(allel.GenotypeArray(callset['calldata/GT']))
    # GenotypeArray is (variants, samples, ploidy); sum away the ploidy axis.
    collapsed = genotypes.sum(axis=2).astype(genotypes.dtype)
    # The numeric patient ID is taken from the last four characters of each
    # sample name.
    sample_ids = [int(name[-4:]) for name in callset['samples']]
    return collapsed.T, sample_ids


def _match_labels(sample_ids, labels):
    """Return ``(row indices, subtype labels)`` of samples listed in ``labels``.

    ``labels`` is a CSV path with ``Patient`` and ``Subtype`` columns. If a
    patient appears more than once, its first row wins.
    """
    table = read_csv(labels)
    label_by_patient = {}
    for patient, subtype in zip(table["Patient"].values.tolist(),
                                table["Subtype"].values.tolist()):
        label_by_patient.setdefault(patient, subtype)

    # Dict lookup replaces repeated O(n) list scans (``in`` / ``list.index``).
    matched = [i for i, sid in enumerate(sample_ids) if sid in label_by_patient]
    subtypes = np.asarray([label_by_patient[sample_ids[i]] for i in matched])
    return matched, subtypes


def _one_hot(y, n_classes):
    """Encode integer labels ``y`` as a ``(len(y), n_classes)`` float 0/1 matrix.

    Raises:
        ValueError: if ``y`` is not integer-typed or a label is outside
            ``[0, n_classes)``. Without this check, negative labels would
            silently wrap to the last columns.
    """
    y = np.asarray(y)
    encoded = np.zeros((len(y), n_classes))
    if y.size == 0:
        return encoded
    if not np.issubdtype(y.dtype, np.integer):
        raise ValueError(f"Subtype labels must be integers, got dtype {y.dtype}")
    if y.min() < 0 or y.max() >= n_classes:
        raise ValueError(
            f"Subtype labels must lie in [0, {n_classes}), "
            f"got range [{y.min()}, {y.max()}]")
    encoded[np.arange(len(y)), y] = 1
    return encoded


def getData(file, labels, run_name="clf", n_classes=5):
    """Train a per-subtype XGBoost model on VCF genotypes and plot its trees.

    Steps:

    1. Read genotype calls from ``file`` and collapse each call to the sum of
       its allele values.
    2. Keep only samples whose numeric ID (last four characters of the sample
       name) appears in the ``Patient`` column of ``labels``. Label each one
       with the matching ``Subtype``.
    3. Make a stratified 80/20 train/test split (``random_state=42``) and
       one-hot encode the subtype labels.
    4. Fit ``MultiOutputClassifier(XGBClassifier(objective="binary:logistic"))``,
       one binary classifier per label column, and print subset accuracy on
       the test split.
    5. For every per-column classifier, save (as ``multilabel_xgboost<i>``) and
       show the plot of tree index 4.

    Args:
        file: Path to the VCF file.
        labels: Path to a CSV with ``Patient`` and ``Subtype`` columns.
        run_name: Currently unused; kept for interface compatibility.
        n_classes: Width of the one-hot label encoding. Subtype labels must be
            integers in ``[0, n_classes)``. Defaults to 5.

    Returns:
        None.

    Raises:
        ValueError: If no VCF sample matches a labelled patient, or a subtype
            label is not an integer in ``[0, n_classes)``.

    Note:
        Tree plots use ``fmap="feature_map.txt"``, which this function does not
        create. It must already exist in the working directory.
    """
    genotype_matrix, sample_ids = _load_genotypes(file)
    patients_idx, subtype_label = _match_labels(sample_ids, labels)
    if not patients_idx:
        raise ValueError(f"No samples in {file!r} match a patient in {labels!r}")

    data_transformed = genotype_matrix[patients_idx, :]
    print(type(data_transformed))
    print(data_transformed.shape)

    # Split for train-validation (80-20), preserving subtype proportions.
    x_train, x_test, y_train, y_test = train_test_split(
        data_transformed, subtype_label, test_size=0.20, random_state=42,
        stratify=subtype_label)
    for caption, arr in (("X train shape", x_train), ("X test shape", x_test),
                         ("Y train shape", y_train), ("Y test shape", y_test)):
        print(caption)
        print(arr.shape)
    print("Train Dataset Distribution")
    print(np.unique(y_train, return_counts=True))
    print("Test Dataset Distribution")
    print(np.unique(y_test, return_counts=True))

    y_train_oh = _one_hot(y_train, n_classes)
    y_test_oh = _one_hot(y_test, n_classes)

    print("-------XGBoost-------")
    print("binary logistic multilabel model")
    model = MultiOutputClassifier(
        xgb.XGBClassifier(objective="binary:logistic",
                          colsample_bytree=0.5,
                          gamma=0.1))
    model.fit(x_train, y_train_oh)

    predicted = model.predict(x_test)
    print("Leftover Test Accuracy:")
    # Subset accuracy: a row counts only if every label column is correct.
    print(accuracy_score(y_test_oh, predicted))

    # One plot per fitted per-column classifier (previously hard-coded to 5).
    for i, estimator in enumerate(model.estimators_):
        ax = xgb.plot_tree(estimator, num_trees=4, fmap="feature_map.txt")
        plt.savefig("multilabel_xgboost" + str(i), dpi=500)
        plt.show()
        # Free the figure; with a non-interactive backend show() leaves it open.
        plt.close(ax.figure)



def main():
    """Run ``getData`` on ``new_vcf.vcf`` with labels from ``patientLabels.csv``.

    Both paths are relative to the current working directory.
    """
    vcf_path, labels_path = "new_vcf.vcf", "patientLabels.csv"
    getData(vcf_path, labels_path, run_name="XGBoost_noncoding")


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()