#%%
from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd

# x-coordinates of the perturbation curve integrated by ``metrics.auc``:
# ten evenly spaced values 0.0, 0.1, ..., 0.9 (presumably the fraction of
# the input perturbed at each step — confirm against the experiment code).
steps = [k / 10 for k in range(10)]

# %%
# Repository root: two directories above this script. This relies on
# ``__file__``. NOTE(review): some interactive runners do not define it;
# confirm for your setup.
base_dir = Path(__file__).parent.parent

# Figure styling. ``plt`` (matplotlib.pyplot) is imported at the top of the
# file, so it is not re-imported here.
plt.rcParams.update({
    "text.usetex": True,  # Render all text through LaTeX (requires a LaTeX install)
    # Under usetex, matplotlib only honours generic families (serif,
    # sans-serif, cursive, monospace) or a few known LaTeX font names.
    # Other values fall back to serif, so this entry most likely only
    # matters when usetex is False.
    "font.family": "cmu-serif",
    "mathtext.fontset": "cm",  # Used by mathtext only, i.e. has no effect while usetex is True
    "font.size": 21,
    "text.latex.preamble": r"\usepackage{amsmath}",  # Makes amsmath macros/environments available
})
plt.style.use('tableau-colorblind10')  # Colorblind-friendly color cycle


# Ablation
ablation_method_lst = [
    "predmap15_ablation_mean_attnmap_only11",
    "predmap15_ablation_weighted_attnmap11",
    "predmap15_ablation_predmap11",
    "predmap15_softmax_classes_batched_layer11",
    "predmap15_softmax_classes_batched_layer",
]


# Methods
comparison_method_lst = [
    "attn_last_layer",
    "rollout",
    "full_lrp",
    "lrp_last_layer",
    "attn_gradcam",
    "transformer_attribution",
    "predmap15_softmax_classes_batched_layer10",
    # "predmap15_softmax_classes_batched_layer11",
    "predmap15_softmax_classes_batched_layer",
]

# Methods whose results are collected and printed below. Point this at
# ``ablation_method_lst`` to produce the ablation table instead.
method_lst = comparison_method_lst


def get_experiment_dir_path(base_path: Path):
    experiment_path_lst = base_path.rglob("**/experiment*")
    experiment_path_lst = [x for x in experiment_path_lst if not is_dir_empty(x)]
    if not experiment_path_lst:
        return None
    # experiment_path_lst = sorted(experiment_path_lst, key=lambda x: int(str(x).split('_')[-1]))    
    return experiment_path_lst[-1]

def is_dir_empty(path: Path):
    """Return True when directory *path* has no entries, False otherwise.

    Only the first entry is inspected. ``path`` must be a directory:
    ``iterdir`` raises for regular files.
    """
    first_entry = next(path.iterdir(), None)
    return first_entry is None

# Model sub-directory whose perturbation results are read ("vit-b" or "vit-l").
vit_variant = "vit-l"
# Root of the perturbation-experiment outputs for the selected model.
exp_dir = base_dir / f"baselines/ViT/optimized/imagenet/{vit_variant}/experiments/perturbations/"

def get_perturb_results(method, pos_neg, target_top):
    """Compute the perturbation AUC, scaled by 100, for one run.

    The run directory is ``exp_dir / "<method>_<pos_neg>/<target_top>/"``.
    ``get_experiment_dir_path`` picks an experiment directory inside it, and
    ``perturbations_hits.npy`` and ``model_hits.npy`` are loaded from that
    directory.

    The curve integrated against ``steps`` has these points:
      * the first is the mean of ``model_hits``;
      * the rest are ``perturbations_hits`` averaged over its last axis.
    Presumably ``perturbations_hits`` has one row per perturbation step with
    samples along the last axis (confirm). ``metrics.auc`` requires the curve
    to have ``len(steps)`` points.

    Returns:
        ``100 * AUC``, or ``None`` after printing a "Skipping ..." message
        when no usable experiment directory exists.
    """
    run_dir = exp_dir / f"{method}_{pos_neg}/{target_top}/"
    latest = get_experiment_dir_path(run_dir)
    if latest is None:
        print(f"Skipping {run_dir}")
        return None

    perturbed_hits = np.load(latest / "perturbations_hits.npy")
    baseline_hits = np.load(latest / "model_hits.npy")
    curve = [baseline_hits.mean(), *perturbed_hits.mean(axis=-1).tolist()]
    return metrics.auc(steps, curve) * 100



# Collect one perturbation AUC per (method, pos/neg, target/top) combination.
# get_perturb_results returns None for runs that have no results.
records = []
for method in method_lst:
    for pos_neg in ["pos", "neg"]:
        for target_top in ["target", "top"]:
            score = get_perturb_results(method, pos_neg, target_top)
            records.append({
                "method": method,
                "pos_neg": pos_neg,
                "target_top": target_top,
                "score": score,
            })
df = pd.DataFrame(records)
# Force a float column so missing (None) scores become NaN. A column holding
# only None would otherwise stay object-dtype and could not be averaged.
df["score"] = df["score"].astype(float)
# dropna=False keeps methods whose scores are all missing. They then appear
# as "nan" rows in the table instead of silently disappearing from it.
df = df.pivot_table(
    index="method",
    columns=["target_top", "pos_neg"],
    values="score",
    sort=False,
    dropna=False,
)

# Print one LaTeX tabular row per method.
# Column order: top/neg, target/neg, top/pos, target/pos.
for method, scores in df.iterrows():
    row_str = (
        f'{method: <50} & {scores["top", "neg"]:.2f} & {scores["target", "neg"]:.2f}'
        f' & {scores["top", "pos"]:.2f} & {scores["target", "pos"]:.2f} \\\\'
    )
    print(row_str)
# %%

