import torch
import os
import pandas as pd

# Root directory that is scanned below; each entry in it (other than ones whose
# name contains ".py") is treated as one emotion's result folder.
dir_path = "../ml-aura/auroc_results/"

# Name of the per-model subfolder read inside every emotion folder.
model = "Llama-3.1-8B"

# emotions = [i for i in os.listdir(dir_path) if ".py" not in i]
# for emotion in emotions:
#     emotion_layers = sorted(os.listdir(dir_path + emotion + "/Qwen2.5-7B-Instruct"), key=lambda x: int(x.split(".")[2]))
#     for i in range(len(emotion_layers)):
#         auroc_results = torch.load(dir_path + emotion + "/Qwen2.5-7B-Instruct/" + emotion_layers[i])
#         print(f"{emotion.split('-')[1].capitalize()}: "
#               f"Layer #{emotion_layers[i].split('.')[2]} - {emotion_layers[i].split('.')[4]}: "
#               f"{torch.sum(auroc_results['alpha'] > 0.9) / auroc_results['alpha'].numel()}")
# import IPython; IPython.embed()


# Rank of each projection sublayer inside a layer. It drives both the order in
# which result files are read and the final row order of the table; names that
# are not listed here sort after all listed ones.
order = {k: i for i, k in enumerate(['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'gate_proj', 'down_proj'])}


def _layer_key(filename):
    """Return ``(layer_index, sublayer_name)`` parsed from a result file name.

    The name is split on ``.``; field 2 is converted to the integer layer index
    and field 4 is taken as the sublayer name (presumably names look like
    ``model.layers.<N>.<block>.<sublayer>.<ext>`` -- TODO confirm).

    Raises:
        IndexError: if the name has fewer than five dot-separated fields.
        ValueError: if field 2 is not an integer.
    """
    fields = filename.split('.')
    return int(fields[2]), fields[4]


def _file_sort_key(filename):
    """Sort key for result files: layer index, sublayer rank, then sublayer name.

    The trailing name makes the order of sublayers missing from ``order``
    deterministic instead of depending on ``os.listdir`` order.
    """
    layer, sublayer = _layer_key(filename)
    return layer, order.get(sublayer, 999), sublayer


def _fraction_above(path, threshold=0.9):
    """Return the fraction of entries of the stored ``'alpha'`` tensor above *threshold*.

    Raises:
        ZeroDivisionError: if the tensor has no elements.
    """
    # NOTE(security): torch.load unpickles the file; only run this on result
    # files you produced yourself or otherwise trust.
    # map_location="cpu" lets results saved from a GPU run load on machines
    # without that device.
    alpha = torch.load(path, map_location="cpu")['alpha']
    return torch.sum(alpha > threshold).item() / alpha.numel()


# Only sub-directories can hold results, so stray files (e.g. ".DS_Store") are
# skipped; sorting makes the CSV column order independent of os.listdir order.
emotions = sorted(
    name for name in os.listdir(dir_path)
    if ".py" not in name and os.path.isdir(os.path.join(dir_path, name))
)
data, layers = {}, []

for emotion in emotions:
    model_dir = os.path.join(dir_path, emotion, model)
    files = sorted(os.listdir(model_dir), key=_file_sort_key)
    file_layers = [_layer_key(f) for f in files]

    if not layers:
        layers = file_layers
    elif file_layers != layers:
        # Every column shares one (Layer, Sublayer) index, so a differing file
        # set would silently attach values to the wrong rows.
        raise ValueError(
            f"Result files in {model_dir!r} do not match the (layer, sublayer) "
            f"set of the previously read emotions: expected {len(layers)} "
            f"entries, found {len(file_layers)}."
        )

    # Column label is the second '-'-separated part of the folder name, capitalized.
    data[emotion.split('-')[1].capitalize()] = [
        _fraction_above(os.path.join(model_dir, file)) for file in files
    ]

# One column per emotion, one row per (layer, sublayer) result file.
df = pd.DataFrame(data)
df.index = pd.MultiIndex.from_tuples(layers, names=["Layer", "Sublayer"])

# Sort rows by layer, then by the custom sublayer rank (the key is applied per index level).
df = df.sort_index(level=["Layer", "Sublayer"], key=lambda x: x.map(order) if x.name == 'Sublayer' else x)

# NOTE: the output name is fixed and does not follow the `model` setting.
df.to_csv("auroc_results_classification_llama3.1_base.csv")

# import IPython; IPython.embed()