#%%
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from reportlab.graphics import renderPDF
from sklearn import metrics
from svglib.svglib import svg2rlg

# x-axis of the perturbation curve integrated in get_perturb_results: 0 for the
# unperturbed model, then 0.1 ... 0.9 (presumably the perturbed fraction — confirm).
steps = [0] + [k / 10 for k in range(1, 10)]

# %%
# Parent of the directory containing this script; the experiment and figure
# paths used further down are resolved relative to it.
base_dir = Path(__file__).parent.parent

# Name prefix of the experiment directories to plot (see get_perturb_results).
method = "predmap31_softmax_classes"
auc_lst = []  # NOTE(review): not referenced anywhere else in this file

from matplotlib import pyplot as plt

# Global figure styling: text is typeset by LaTeX, with amsmath available.
# NOTE(review): with "text.usetex" enabled, the "mathtext.fontset" entry is
# presumably unused and the non-generic "cmu-serif" family may fall back to
# LaTeX's default serif font — confirm the rendered font is as intended.
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "cmu-serif",
        "mathtext.fontset": "cm",
        "font.size": 21,
        "text.latex.preamble": r"\usepackage{amsmath}",
    }
)

# Colour-blind-friendly colour cycle for every subsequent plot.
plt.style.use("tableau-colorblind10")

def get_experiment_dir_path(base_path: Path):
    """Return the most recent non-empty ``experiment*`` directory below *base_path*.

    *base_path* is searched recursively for entries whose name starts with
    ``experiment``. Only directories with at least one entry are kept; plain
    files with a matching name are ignored. The winner is the last candidate in
    natural-sort order of the full path, so ``experiment_10`` beats
    ``experiment_9``. The order that ``rglob`` yields is filesystem-dependent
    and therefore not usable for this.

    Returns:
        The selected ``Path``, or ``None`` when there is no candidate. That
        includes the case where *base_path* does not exist.
    """
    candidates = [
        path
        for path in base_path.rglob("experiment*")
        # Guard is_dir() first: iterdir() on a file raises NotADirectoryError.
        if path.is_dir() and not is_dir_empty(path)
    ]
    if not candidates:
        return None
    return max(candidates, key=_natural_sort_key)


def _natural_sort_key(path: Path):
    """Sort key that compares the digit runs in *path* numerically."""
    # re.split with a capturing group alternates text / digit-run tokens.
    return [
        int(token) if token.isdecimal() else token
        for token in re.split(r"(\d+)", str(path))
    ]


def is_dir_empty(path: Path):
    """Return True if directory *path* has no entries."""
    is_empty = not any(path.iterdir())
    return is_empty

# Root of the perturbation experiments. get_perturb_results looks for one
# "<method>_batched_layer<i>_<pos_neg>/<target_top>/" sub-directory per setting.
exp_dir = base_dir.joinpath(
    "baselines/ViT/optimized/sipl-regev-imagenet_subset_1000/vit-b/experiments/perturbations/"
)

def get_perturb_results(method, pos_neg, target_top, *, num_layers: int = 12):
    """Compute a perturbation AUC for each layer of one experiment setting.

    For each layer index ``i`` in ``range(num_layers)``, the function looks
    under the module-level ``exp_dir`` for
    ``{method}_batched_layer{i}_{pos_neg}/{target_top}/``. It selects an
    experiment directory there with ``get_experiment_dir_path`` and loads
    ``perturbations_hits.npy`` and ``model_hits.npy`` from it.

    The curve has two parts:
    - the mean of ``model_hits`` (the first point);
    - the means of ``perturbations_hits`` over its last axis (the remaining points).

    This curve is integrated against the module-level ``steps`` with
    ``sklearn.metrics.auc`` and multiplied by 100.
    NOTE(review): this assumes the curve has ``len(steps)`` points, i.e.
    ``perturbations_hits`` reduces to ``len(steps) - 1`` values. ``metrics.auc``
    raises ``ValueError`` if the lengths differ.

    Args:
        method: Prefix of the experiment directory names.
        pos_neg: Setting name inserted into the directory name (this script
            passes "pos" or "neg").
        target_top: Sub-directory name (this script passes "target" or "top").
        num_layers: Number of layer indices to scan. The default of 12 matches
            the previously hard-coded value.

    Returns:
        Dict mapping the 1-based layer number to its scaled AUC. A layer is
        skipped, with a "Skipping ..." message printed, when no non-empty
        experiment directory is found for it.
    """
    layer2score = {}
    for i in range(num_layers):
        tmp_dir = exp_dir / f"{method}_batched_layer{i}_{pos_neg}/{target_top}/"
        experiment_dir = get_experiment_dir_path(tmp_dir)
        if experiment_dir is None:
            print(f"Skipping {tmp_dir}")
            continue
        perturbations_hits = np.load(experiment_dir / "perturbations_hits.npy")
        model_hits = np.load(experiment_dir / "model_hits.npy")
        # First point: unperturbed model; then one point per perturbation step.
        res = [model_hits.mean()] + perturbations_hits.mean(axis=-1).tolist()
        layer2score[i + 1] = metrics.auc(steps, res) * 100

    return layer2score

output_dir = base_dir / "figures/artifacts/supp/perturbation_predmap_per_layer"
output_dir.mkdir(exist_ok=True, parents=True)

# Per-layer AUCs for every (pos/neg, target/top) setting. They are loaded once
# and reused for both the shared y-axis bounds and the plots, instead of
# re-reading every experiment's .npy files a second time.
results = {
    (pos_neg, target_top): get_perturb_results(method, pos_neg, target_top)
    for pos_neg in ["pos", "neg"]
    for target_top in ["target", "top"]
}


def _collect_scores(pos_neg):
    """Return all per-layer AUCs for *pos_neg* across the "target" and "top" settings.

    Raises:
        ValueError: If no experiment results were found. Without this check the
            failure would surface as an opaque error from ``min()``.
    """
    collected = [
        auc
        for target_top in ["target", "top"]
        for auc in results[(pos_neg, target_top)].values()
    ]
    if not collected:
        raise ValueError(f"No perturbation results found for '{pos_neg}' under {exp_dir}")
    return collected


# Shared y-axis bounds per pos/neg, so the "target" and "top" plots are comparable.
scores = _collect_scores("pos")
pos_min, pos_max = min(scores), max(scores)

scores = _collect_scores("neg")
neg_min, neg_max = min(scores), max(scores)

pos_neg_min = {"pos": pos_min, "neg": neg_min}
pos_neg_max = {"pos": pos_max, "neg": neg_max}

plt.close("all")
for pos_neg in ["pos", "neg"]:
    for target_top in ["target", "top"]:
        fig, ax = plt.subplots(figsize=(4, 4.5))
        layer2score = results[(pos_neg, target_top)]
        # Pass plain lists rather than dict views so bar() sees 1-D sequences.
        ax.bar(
            list(layer2score.keys()),
            list(layer2score.values()),
            label=f"{pos_neg} - {target_top}",
        )
        ax.set_xlabel("Layer")
        # Layers are numbered 1..12 by get_perturb_results; label every other one.
        ax.set_xticks(range(2, 12 + 1, 2))
        # Ticks every 5 AUC points, starting at the multiple of 5 at/below the minimum.
        y_ticks = np.arange(
            pos_neg_min[pos_neg] // 5 * 5, pos_neg_max[pos_neg] * 1.5 // 5 * 5, 5
        )
        ax.set_yticks(y_ticks)
        ax.set_ylabel("AUC")
        pos_neg_str = r"Positive $\downarrow$" if pos_neg == "pos" else r"Negative $\uparrow$"
        top_target_str = "Target" if target_top == "target" else "Predicted"
        title = f"{top_target_str} \n {pos_neg_str}"
        ax.set_title(title)
        # Must follow set_yticks: that call may widen the view limits to fit every
        # tick, and this clips the axis back to the shared bounds.
        ax.set_ylim([pos_neg_min[pos_neg] * 0.9, pos_neg_max[pos_neg] * 1.01])
        fig.tight_layout()
        path = output_dir / f"{method}_{pos_neg}_{target_top}.svg"
        print(path)
        fig.savefig(path, bbox_inches="tight", pad_inches=0.05)

        # Convert the saved SVG into a PDF next to it via svglib/reportlab.
        drawing = svg2rlg(str(path))
        if drawing is None:
            # svg2rlg yields None when the SVG could not be loaded.
            print(f"Could not convert {path} to PDF")
            continue
        renderPDF.drawToFile(drawing, str(path.with_suffix(".pdf")))

# %%
