import pandas as pd
from tqdm import tqdm
import numpy as np
import argparse

from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import f1_score, roc_auc_score

parser = argparse.ArgumentParser(
    description="Aggregate pairwise-classifier predictions over 5 folds and "
                "report accuracy / macro-F1 as mean ± std.")
parser.add_argument('--version', type=str, required=True,
                    help="experiment name; results are read from result/<version>_agent/")
# Parse sys.argv only when executed as a script. Parsing at import time would
# make the helpers in this file unusable from other modules or test runners:
# argparse would see their command-line arguments and exit.
args = parser.parse_args() if __name__ == "__main__" else None

def argmax(lst):
    """Return the position of the largest item in the list *lst*.

    When several items share the maximum, the earliest position wins.
    An empty list raises ``ValueError``.
    """
    return max(range(len(lst)), key=lst.__getitem__)

def secmax(lst):
    """Return the index of the second-largest element of *lst*.

    The index comes from NumPy's ``argsort`` ordering, taken from the
    descending end.
    """
    descending = np.argsort(np.array(lst))[::-1]
    return descending[1]


def eval_third_agent(pred_df, gt_label, weight, n_classes=8):
    """Combine one-vs-one classifier votes into a multi-class prediction per sample.

    ``pred_df`` holds one column per sample, named by the sample's position
    (``'0'``, ``'1'``, ...). It holds one row per pairwise classifier, labelled
    by the pair id. Pairs ``(a, b)`` with ``a < b`` are numbered in
    lexicographic order, i.e. ``itertools.combinations(range(n_classes), 2)``
    order (ids 0..27 for the default 8 classes).

    As implemented, a prediction value is interpreted as follows:
    ``1`` counts as a win for ``b``, ``0`` counts as a win for ``a``, and
    ``2`` is an abstention ("mistake") that counts for neither class.

    Each class is scored as ``#wins - weight * #abstentions`` over the pairs
    it takes part in. The predicted label is the highest-scoring class; ties
    go to the lowest class index.

    Args:
        pred_df: pairwise predictions laid out as described above.
        gt_label: ground-truth labels. Only its length is used, to decide how
            many sample columns to read.
        weight: real-valued penalty per abstention.
        n_classes: number of classes. The default of 8 reproduces the
            original hard-coded pair table.

    Returns:
        list of int: predicted class index for each sample.

    Raises:
        TypeError: if ``weight`` is not a real number.
    """
    import itertools
    import numbers

    if not isinstance(weight, numbers.Real):
        # A list here would make `mistake * weight` list repetition and
        # silently turn every score into an empty array.
        raise TypeError("weight must be a real number, got %r" % (weight,))

    pairs = list(itertools.combinations(range(n_classes), 2))
    # class_pairs[c] holds the ids of the classifiers involving class c.
    # Because of the lexicographic numbering, the first c ids are pairs
    # (a, c), where c is the *second* class. The remaining ids are pairs
    # (c, b), where c is the *first* class.
    class_pairs = [[k for k, pair in enumerate(pairs) if c in pair]
                   for c in range(n_classes)]

    pred_label = []
    for sample_id in tqdm(range(len(gt_label))):
        column = pred_df['%d' % sample_id]
        scores = []
        for c, ids in enumerate(class_pairs):
            votes = np.array(column[ids])  # copy, so pred_df is never modified
            valid = votes != 2
            mistake = int(np.count_nonzero(~valid))
            # Flip the pairs where c is the first class, so that 1 always
            # means "c won". Abstentions are masked out by `valid` below.
            votes[c:] = np.abs(votes[c:] - 1)
            scores.append((votes * valid).sum() - mistake * weight)
        pred_label.append(argmax(scores))
    return pred_label


def get_mean_std(arr):
    """Return ``(mean, std)`` of *arr*.

    ``std`` is the sample standard deviation, computed with ``ddof=1``.
    """
    return np.mean(arr), np.std(arr, ddof=1)

def roc_auc_score_multiclass(actual_class, pred_class, average = "macro"):
    """Compute a one-vs-rest ROC AUC for each class from hard label predictions.

    For every distinct label in ``actual_class``, taken in sorted order, both
    sequences are binarised: 1 for that label, 0 for anything else. The
    binarised sequences are then scored with ``sklearn.metrics.roc_auc_score``.

    Args:
        actual_class: sequence of ground-truth labels.
        pred_class: sequence of predicted labels, aligned with ``actual_class``.
        average: forwarded unchanged to ``roc_auc_score``.

    Returns:
        list of float: one AUC per class, ordered like ``sorted(set(actual_class))``.
        The macro average is ``sum(result) / len(result)``.
    """
    # Sort so that the output order is deterministic and matches class order.
    # Plain set iteration order depends on hashing.
    classes = sorted(set(actual_class))
    aucs = []
    for per_class in classes:
        # Compare with the current class directly. The former test,
        # "x not in the other classes", marked predicted labels that never
        # occur in actual_class as positives for *every* class.
        new_actual_class = [1 if x == per_class else 0 for x in actual_class]
        new_pred_class = [1 if x == per_class else 0 for x in pred_class]
        aucs.append(roc_auc_score(new_actual_class, new_pred_class, average=average))
    return aucs


if __name__ == "__main__":

    accs = []
    f1_macs = []
    # Penalty subtracted from a class's vote count for each abstention ("2")
    # among its pairwise classifiers. This used to be `[]`, which made
    # `mistake * weight` an empty list and broke every score.
    # NOTE(review): 0.0 (no penalty) is a placeholder; confirm the intended value.
    weight = 0.0

    for k_fold in range(5):

        pred_df = pd.read_csv("result/%s_agent/%d_fold_predicted_label_epoch_1500.csv" % (args.version, k_fold))
        gt_df = pd.read_csv("result/%s_agent/%d_fold_true_label_epoch_1500.csv" % (args.version, k_fold))

        # Assumes the true-label CSV has a single column of labels.
        gt_label = gt_df.squeeze().values.tolist()

        pred_label = eval_third_agent(pred_df, gt_label, weight)
        pred_label, gt_label = np.array(pred_label), np.array(gt_label)

        # sklearn metrics take (y_true, y_pred), in that order.
        acc = accuracy_score(gt_label, pred_label)
        f1_mac = f1_score(gt_label, pred_label, average='macro')
        accs.append(acc)
        f1_macs.append(f1_mac)

    acc_mean, acc_std = get_mean_std(accs)
    f1_mean, f1_std = get_mean_std(f1_macs)
    print("%.3f ± %.3f" % (acc_mean, acc_std))
    print("%.3f ± %.3f" % (f1_mean, f1_std))

