import pandas as pd
import glob
from sklearn.metrics import precision_score, recall_score, accuracy_score
import numpy as np
import os

# Evaluation script: for each (turn, role) configuration of one model, load the
# per-run classification CSVs, map the string labels in the `ground_truth` and
# `answer` columns to integer class ids, and print accuracy plus per-class,
# macro-averaged and micro-averaged precision/recall.
model_name = "gpt-3.5-turbo"
print(model_name)

# String label -> integer class id. Loop-invariant, so it is built once here.
# A value in the CSVs that is not one of these keys raises KeyError (fail fast
# rather than silently scoring malformed rows).
label = {
    "Child": 0,
    "Teen": 1,
    "College Student": 2,
    "Grad Student": 3,
    "Expert": 4
}

# Fixed class order for the per-class report, so position i in the printed
# arrays always corresponds to class id i even when a class is absent from the
# data of a given configuration.
class_ids = list(label.values())

for turn in ["5_turns", "all_turns"]:
    for role in ["ask", "answer"]:
        file_path = f"./results/classification/{model_name}/{turn}/{role}"
        # Files whose name starts with "summary" are excluded from the inputs.
        csv_files = [
            file
            for file in glob.glob(file_path + "/*.csv")
            if not os.path.basename(file).startswith("summary")
        ]
        dataframes = [pd.read_csv(file) for file in csv_files]
        print("*"*10)
        print(turn, role)

        # Flatten all rows of all files into two parallel lists of class ids.
        y_true = [label[value] for df in dataframes for value in df["ground_truth"]]
        y_pred = [label[value] for df in dataframes for value in df["answer"]]
        acc = sum(truth == pred for truth, pred in zip(y_true, y_pred))

        # Without samples the accuracy ratio below would divide by zero and
        # abort the remaining configurations; report and move on instead.
        if not y_true:
            print(f"No samples found under {file_path}; skipping metrics.")
            continue

        print(len(y_true), len(y_pred), acc/len(y_pred))

        # Per-class precision and recall, reported in the fixed class order.
        precision = precision_score(y_true, y_pred, labels=class_ids, average=None)
        recall = recall_score(y_true, y_pred, labels=class_ids, average=None)

        print("Precision per class:", precision)
        print("Recall per class:", recall)

        # Macro/micro averages are computed over the classes present in the
        # data (scikit-learn's default when `labels` is not given).
        precision_macro = precision_score(y_true, y_pred, average='macro')
        recall_macro = recall_score(y_true, y_pred, average='macro')

        precision_micro = precision_score(y_true, y_pred, average='micro')
        recall_micro = recall_score(y_true, y_pred, average='micro')

        print("\nMacro Averaged Precision:", precision_macro)
        print("Macro Averaged Recall:", recall_macro)

        print("\nMicro Averaged Precision:", precision_micro)
        print("Micro Averaged Recall:", recall_micro)

        accuracy = accuracy_score(y_true, y_pred)
        print("\nAccuracy:", accuracy)
