import pandas as pd
import numpy as np
import csv
from collections import defaultdict

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import precision_score, recall_score, \
    roc_auc_score, accuracy_score, f1_score, average_precision_score

def get_ec_pos_dict(mlb, true_label, pred_label):
    """Map every known label seen in ``true_label``/``pred_label`` to its column.

    ``mlb`` must already be fitted. A label's column is its position in
    ``mlb.classes_``, which is the column it occupies in ``mlb.transform``
    output. Labels that ``mlb`` does not know are omitted.

    Args:
        mlb: a fitted ``MultiLabelBinarizer`` (only ``classes_`` is read).
        true_label: sequence of label collections, one per sample.
        pred_label: sequence of label collections, one per sample.

    Returns:
        dict mapping label -> column index. Keys are in first-seen order,
        scanning ``true_label`` and then ``pred_label``.
    """
    # Build the lookup once instead of calling transform/inverse_transform
    # for every row just to recover each label's column.
    col_of = {label: col for col, label in enumerate(mlb.classes_)}
    label_pos_dict = {}
    for rows in (true_label, pred_label):
        for labels in rows:
            for label in labels:
                if label in col_of:
                    label_pos_dict[label] = col_of[label]
    return label_pos_dict

def get_eval_metrics(pred_label, pred_probs, true_label, all_label):
    """Compute weighted multi-label metrics for a set of predictions.

    Args:
        pred_label: sequence of predicted-label collections, one per sample.
        pred_probs: per-sample scores aligned element-wise with
            ``pred_label``. They are only used to build the score matrix
            for ROC AUC.
        true_label: sequence of true-label collections. Row ``i`` is compared
            with ``pred_label[i]``, and only the first ``len(pred_label)``
            rows are used.
        all_label: iterable of every label in the label space. Predicted
            labels outside it are ignored.

    Returns:
        Tuple ``(precision, recall, f1, roc_auc, accuracy)``.
        Precision, recall, F1 and ROC AUC use ``average='weighted'``.
        Accuracy is sklearn's subset (exact-match) accuracy on the
        indicator matrices.

    Raises:
        ValueError: if ``true_label`` has fewer rows than ``pred_label``.
            sklearn may also raise, e.g. ``roc_auc_score`` when a label
            column of the truth matrix contains only one value.
    """
    mlb = MultiLabelBinarizer()
    mlb.fit([list(all_label)])
    n_test = len(pred_label)
    if len(true_label) < n_test:
        raise ValueError(
            f"true_label has {len(true_label)} rows but pred_label has {n_test}"
        )
    # One batched transform per side instead of one sklearn call per sample.
    # transform() drops (and warns about) labels outside all_label.
    pred_m = mlb.transform(list(pred_label)).astype(float)
    true_m = mlb.transform(list(true_label[:n_test])).astype(float)

    # Score matrix for ROC AUC: each predicted label's score goes into that
    # label's column; labels not predicted for a sample keep score 0.
    col_of = {label: col for col, label in enumerate(mlb.classes_)}
    pred_m_auc = np.zeros((n_test, len(mlb.classes_)))
    for row in range(n_test):
        for label, prob in zip(pred_label[row], pred_probs[row]):
            col = col_of.get(label)
            if col is not None:
                # NOTE(review): if a label appears several times in one
                # sample's predictions, only its last score is kept.
                pred_m_auc[row, col] = prob

    # zero_division=0 everywhere for consistency. It only changes warnings,
    # because zero-support classes get zero weight under 'weighted'.
    pre = precision_score(true_m, pred_m, average='weighted', zero_division=0)
    rec = recall_score(true_m, pred_m, average='weighted', zero_division=0)
    f1 = f1_score(true_m, pred_m, average='weighted', zero_division=0)
    roc = roc_auc_score(true_m, pred_m_auc, average='weighted')
    acc = accuracy_score(true_m, pred_m)
    return pre, rec, f1, roc, acc

def get_true_labels(file_name):
    """Read ground-truth EC annotations from ``<file_name>.csv``.

    The file is tab-separated with one header row. Column 0 holds the entry
    ID and column 1 holds one or more EC numbers separated by ``;``. Blank
    lines are skipped.

    Returns:
        Tuple ``(true_label, all_label)``:
        - ``true_label`` is a list with one list of EC strings per distinct
          entry ID, in first-seen order. A repeated ID keeps its first
          position but takes the labels of its last row.
        - ``all_label`` is the set of every EC string in column 1 of every
          data row.
    """
    true_label_dict = {}
    all_label = set()
    # newline='' is what the csv module expects for files it parses; the
    # context manager guarantees the handle is closed.
    with open(file_name + '.csv', 'r', newline='') as handle:
        reader = csv.reader(handle, delimiter='\t')
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if not row:
                continue  # blank line
            true_ec_lst = row[1].split(';')
            true_label_dict[row[0]] = true_ec_lst
            all_label.update(true_ec_lst)
    true_label = list(true_label_dict.values())
    return true_label, all_label


# Script: score BLASTp hits on the "new" test set.
# Each hit's subject accession is looked up in data/split30.csv. Every EC
# number annotated on a hit becomes one predicted label of that hit's query.
results = pd.read_csv("new_split30.tsv", sep='\t')
print(results)
results_dict = defaultdict(list)  # query -> predicted ECs (one entry per hit per EC)
results_prob = defaultdict(list)  # query -> score for each entry of results_dict[query]
annotations = pd.read_csv("data/split30.csv", sep='\t')
print(annotations)
# Entry -> raw "EC number" string, possibly several ECs joined by ';'.
# Iterating whole columns avoids building a row Series per iloc lookup.
annotation = dict(zip(annotations['Entry'], annotations['EC number']))
for query, hit, expect in zip(
        results['Query'], results['Hit accession'], results['Expect']):
    # The subject accession is the second '|'-separated field of the hit ID.
    # Raises KeyError if that accession is missing from data/split30.csv.
    label = hit.split("|")[1]
    anno = annotation[label].split(";")
    results_dict[query].extend(anno)
    results_prob[query].extend([expect] * len(anno))

# NOTE(review): 'Expect' is presumably the BLAST E-value, where smaller means a
# better hit. Normalizing it directly gives weaker hits larger scores, which
# inverts the ranking ROC AUC sees. Confirm, and consider transforming it
# (e.g. -log10(E)) before normalizing.
for key in results_prob:
    scores = np.array(results_prob[key])
    results_prob[key] = scores / (scores.sum() + 1e-9)

true_label, all_label = get_true_labels('./data/new')
d = pd.read_csv("./data/new.csv", sep='\t')
entries = d['Entry']
print(results_prob)
print(results_dict)
# Queries with no hits get an empty prediction via the defaultdicts.
# NOTE(review): true_label follows get_true_labels' order for the same file.
# The two only line up if every Entry in data/new.csv is unique; confirm.
pre, rec, f1, roc, acc = get_eval_metrics(
    [results_dict[key] for key in entries],
    [results_prob[key] for key in entries],
    true_label,
    all_label,
)
print(rec)
print(pre)
print(f1)
print(acc)
print(roc)