import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def load_confusion_matrices(root_dir, keyword):
    """Collect confusion matrices stored as CSV files below ``root_dir``.

    The directory tree is walked recursively with ``os.walk``. A file is
    loaded when its name starts with ``confusion_matrix``, ends with
    ``.csv`` and contains ``keyword``.

    Args:
        root_dir: Directory whose tree is searched.
        keyword: Substring that must occur in the file name.

    Returns:
        A list with one array (``DataFrame.values``) per matching file, in
        the order the files are encountered. Each CSV is read with its first
        column as the index, so that column is not part of the array.
    """
    def _wanted(name):
        # Same three filename conditions, evaluated on the bare file name.
        return (
            keyword in name
            and name.startswith('confusion_matrix')
            and name.endswith('.csv')
        )

    return [
        pd.read_csv(os.path.join(folder, name), index_col=0).values
        for folder, _subfolders, names in os.walk(root_dir)
        for name in names
        if _wanted(name)
    ]

def _f1_from_counts(true_positives, false_positives, false_negatives):
    """Return F1 = 2*TP / (2*TP + FP + FN); 0.0 where the denominator is not positive."""
    true_positives = np.asarray(true_positives, dtype=float)
    denominator = np.asarray(
        2.0 * true_positives + false_positives + false_negatives, dtype=float
    )
    f1 = np.zeros_like(denominator)
    # ``where`` skips the division for empty (or NaN) denominators, leaving
    # the pre-filled 0.0 in place instead of producing NaN/inf.
    np.divide(2.0 * true_positives, denominator, out=f1, where=denominator > 0)
    return f1


def calculate_f1_scores(matrices):
    """Compute the macro and micro F1 score of each confusion matrix.

    For every matrix the diagonal is taken as the true positives, column
    sums minus the diagonal as false positives and row sums minus the
    diagonal as false negatives.

    F1 is evaluated in its exact closed form ``2*TP / (2*TP + FP + FN)``,
    which equals ``2*P*R / (P + R)`` whenever precision and recall are
    defined. No smoothing epsilon is added to the denominators, so a perfect
    matrix scores exactly 1.0. A class (or a whole matrix) with no true
    positives, false positives or false negatives scores 0.0.

    Args:
        matrices: Iterable of square 2-D array-likes.

    Returns:
        A tuple ``(macro_f1_scores, micro_f1_scores)`` of two lists of
        floats, each with one entry per input matrix, in input order.
        The macro score is the unweighted mean of the per-class F1 values;
        the micro score is the F1 of the counts summed over all classes.
    """
    macro_f1_scores = []
    micro_f1_scores = []
    for matrix in matrices:
        matrix = np.asarray(matrix, dtype=float)
        true_positives = np.diag(matrix)
        false_positives = matrix.sum(axis=0) - true_positives
        false_negatives = matrix.sum(axis=1) - true_positives

        per_class_f1 = _f1_from_counts(
            true_positives, false_positives, false_negatives
        )
        macro_f1_scores.append(float(np.mean(per_class_f1)))

        micro_f1 = _f1_from_counts(
            true_positives.sum(), false_positives.sum(), false_negatives.sum()
        )
        micro_f1_scores.append(float(micro_f1))

    return macro_f1_scores, micro_f1_scores

def plot_confusion_matrix(matrix, labels, root_dir, embedding, dataset):
    """Draw ``matrix`` as an annotated heat map and save it as a PNG.

    The image is written to
    ``<root_dir>/final_normalized_mean_conf_matrix<embedding>.png``.

    Args:
        matrix: 2-D array-like indexed as ``matrix[i, j]``; it must cover at
            least ``len(labels)`` rows and columns.
        labels: Tick labels, used for both axes.
        root_dir: Directory the PNG is saved into.
        embedding: Suffix appended to the output file name.
        dataset: Not used by this function; kept so existing callers keep
            working.
    """
    matrix = np.asarray(matrix)

    # Draw on a dedicated figure. Without this, repeated calls would paint
    # on top of the previous heat map and stack an extra colorbar each time.
    fig = plt.figure()
    try:
        plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=16)
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, fontsize=16, rotation=15)
        plt.yticks(tick_marks, labels, fontsize=16)

        # Cells above half the maximum get white text for contrast; the
        # threshold is loop-invariant, so compute it once.
        threshold = matrix.max() / 2.
        for i in range(len(labels)):
            for j in range(len(labels)):
                plt.text(j, i, format(matrix[i, j], '.2f'),
                         horizontalalignment="center",
                         color="white" if matrix[i, j] > threshold else "black",
                         fontsize=20)

        fig.savefig(os.path.join(root_dir, "final_normalized_mean_conf_matrix"+embedding+".png"))
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: load the 'Semantic' and 'Russel' confusion
    # matrices found below root_dir, print the mean/std of their macro and
    # micro F1 scores (in percent), and save a heat map of the mean
    # 'Semantic' matrix into root_dir.

    root_dir = 'SEED_result/PLL_confusion/main/PGNA_PL/scheduler_True/optimizer_sgd/lr_0.01/confidence_False/beta_parameter_3.0'  # Change this to your directory
    semantic_matrices = load_confusion_matrices(root_dir, 'Semantic')
    russel_matrices = load_confusion_matrices(root_dir, 'Russel')

    # The mean Semantic matrix is plotted below. Without any Semantic files
    # (e.g. a wrong root_dir) the statistics would be NaN and plotting would
    # fail with an unrelated-looking error, so stop early with a clear message.
    if not semantic_matrices:
        raise SystemExit(
            "No 'Semantic' confusion matrix CSV files found under: " + root_dir
        )

    semantic_macro_f1, semantic_micro_f1 = calculate_f1_scores(semantic_matrices)
    russel_macro_f1, russel_micro_f1 = calculate_f1_scores(russel_matrices)

    # Printing the results as "mean(std)", both in percent with 2 decimals
    for title, scores in (
        ("Semantic Macro F1", semantic_macro_f1),
        ("Semantic Micro F1", semantic_micro_f1),
        ("Russel Macro F1", russel_macro_f1),
        ("Russel Micro F1", russel_micro_f1),
    ):
        if scores:
            summary = (str(round(np.mean(scores) * 100, 2))
                       + "(" + str(round(np.std(scores) * 100, 2)) + ")")
        else:
            # Avoid NaN output and numpy warnings for an empty score list.
            summary = "n/a (no matrices found)"
        print(title + " mean/std:", summary)

    # Element-wise mean over all runs, used for visualization
    mean_semantic_matrix = np.mean(semantic_matrices, axis=0)
    # Only needed for the optional Russel plot below; None when no Russel
    # matrices were found.
    mean_russel_matrix = np.mean(russel_matrices, axis=0) if russel_matrices else None

    # Choose the class labels for the plot from the dataset named in root_dir
    if "SEED_IV_result" in root_dir:
        labels = ['Neutral', 'Sad', 'Fear', 'Happy']
    elif "SEED_V_result" in root_dir:
        labels = ['Disgust', 'Fear', 'Sad', 'Neutral', 'Happy']
    else:
        labels = ['Negative', 'Neutral', 'Positive']

    if "SEED_V" in root_dir:
        dataset = "SEED-V"
    elif "SEED_IV" in root_dir:
        dataset = "SEED-IV"
    else:
        dataset = "SEED"

    embedding = "_Semantic"
    plot_confusion_matrix(mean_semantic_matrix, labels, root_dir, embedding, dataset)
    # To also plot the Russel matrix (only when mean_russel_matrix is not None):
    # embedding = "_Russel"
    # plot_confusion_matrix(mean_russel_matrix, labels, root_dir, embedding, dataset)

