#%%
import torch
import pandas as pd
import numpy as np
import lovely_tensors as lt
lt.monkey_patch()
from pathlib import Path

# %%

def remove_prefix(text:str, prefix:str):
    """Return *text* with a leading *prefix* removed.

    If *text* does not start with *prefix*, it is returned unchanged.
    """
    return text[len(prefix):] if text.startswith(prefix) else text

def dir2metrics(path: Path):
    """Parse the first ``result_mIoU*.txt`` file found anywhere under *path*.

    The file is read as three lines: mIoU, pixel accuracy (a percentage,
    optionally followed by ``%``) and mAP. For each, the value is the last
    whitespace-separated token on its line.

    Args:
        path: Experiment directory, searched recursively.

    Returns:
        ``{"mIoU": float, "pixAcc": float, "mAP": float}``, with pixAcc
        converted from percent to a fraction. Returns ``None`` (after
        printing a message) if no result file exists under *path*.
    """
    # If several result files exist, whichever rglob yields first is used.
    res_path = next(path.rglob("result_mIoU*.txt"), None)
    if res_path is None:
        print(f"File not found in {path}/**/result_mIoU*.txt")
        return None
    # splitlines() also handles "\r\n" endings. split() with no argument
    # ignores trailing whitespace, which split(" ") would turn into "".
    lines = res_path.read_text(encoding="utf-8").splitlines()
    miou = float(lines[0].split()[-1])
    # Strip the % sign only if it is present. The previous [:-1] slice
    # dropped a digit when the sign was missing.
    pix_acc = float(lines[1].split()[-1].rstrip("%")) / 100
    # Avoid shadowing the builtin `map`.
    mean_ap = float(lines[2].split()[-1])
    return {"mIoU": miou, "pixAcc": pix_acc, "mAP": mean_ap}


prefix = "predmap22_vgg_layer"
# prefix = "predmap22_deit_layer"
# prefix = "predmap15_layer_vgg_layer"
# prefix="predmap15_softmax_classes_batched_layer_vgg_layer"
# prefix="predmap15_softmax_classes_tokens_batched_layer_vgg_layer"
# prefix="predmap15_softmax_tokens_batched_layer_vgg_layer"
# prefix="predmap31_softmax_classes_batched_layer_vgg_layer"
# prefix="predmap31_softmax_classes_tokens_batched_layer_vgg_layer"
# prefix="predmap31_softmax_tokens_batched_layer_vgg_layer"
# Collect every run directory named "<prefix><layer>" under ../run/imagenet/,
# parse its metrics and tabulate them by layer.
base_dir = Path(__file__).parent.parent / "run/imagenet/"
# Only directories can hold result files. Sorting gives a stable processing order.
experiments = sorted(
    x for x in base_dir.iterdir() if x.is_dir() and x.name.startswith(prefix)
)
records = []
index = []
from tqdm import tqdm
for exp_path in tqdm(experiments):
    layer_str = remove_prefix(exp_path.name, prefix)
    try:
        layer = float(layer_str)
    except ValueError:
        # A run whose suffix is not a plain number (e.g. "<prefix>10_old")
        # would otherwise abort the whole sweep.
        print(f"Skipping {exp_path}: cannot parse a layer number from {layer_str!r}")
        continue
    metrics = dir2metrics(exp_path)
    if metrics is None:
        print(f"Skipping {exp_path}")
        continue
    records.append(metrics)
    index.append(layer)
# Fail with a clear message instead of a KeyError on the metric columns below.
if not records:
    raise RuntimeError(f"No result files found for prefix {prefix!r} under {base_dir}")
df = pd.DataFrame(records, index=index)
df.sort_index(inplace=True)
df.reset_index(inplace=True, names="layer")
metric_column_names = ["pixAcc", "mAP", "mIoU"]
metric_coulmn_names = metric_column_names  # old misspelled name, kept for interactive sessions
# Scale to percent for display. dir2metrics divides pixAcc by 100; mIoU/mAP
# are presumably fractions in [0, 1] as well -- verify against the result files.
df[metric_column_names] = df[metric_column_names] * 100
# Layers were parsed as floats; directory suffixes are assumed to be whole numbers.
df['layer'] = df['layer'].astype(int)
# Blank separator line after the progress bar. print() also works outside
# IPython, where display() is not defined.
print()
print(f"{prefix}")
df[["layer"] + metric_column_names]
# %%

s = 16
# (metric column, colour of its series, colour of its best-layer "x" marker)
# NOTE(review): the mAP best-layer marker is yellow while the others are black;
# confirm this is intentional.
plot_styles = (("mIoU", "r", "k"), ("pixAcc", "g", "k"), ("mAP", "b", "y"))

# Draw one scatter series per metric, all sharing the same axes.
ax1 = None
for metric, series_color, _ in plot_styles:
    ax1 = df.plot.scatter(x='layer', y=metric, label=metric, color=series_color, ax=ax1, s=s)

# Overlay an "x" on the row where each metric peaks (first row on ties).
for metric, _, marker_color in plot_styles:
    best_row = df.loc[[df[metric].idxmax()]]
    best_row.plot.scatter(x='layer', y=metric, marker="x", color=marker_color, ax=ax1, s=s)

ax1.set_ylabel("Score")
# ax1.set_xscale("log")
ax1.set_title(f"Scores per Layer\n{prefix}")


# Best score per metric and the first layer that reaches it.
# NOTE: despite the "_temp" suffix, the *_temp variables hold layer numbers.
max_mIoU, max_pixAcc, max_mAP = (df[m].max() for m in ("mIoU", "pixAcc", "mAP"))
max_mIoU_temp, max_pixAcc_temp, max_mAP_temp = (
    df.loc[df[m] == best, 'layer'].values[0]
    for m, best in (("mIoU", max_mIoU), ("pixAcc", max_pixAcc), ("mAP", max_mAP))
)

max_layers = {
    'mIoU': (max_mIoU, max_mIoU_temp),
    'pixAcc': (max_pixAcc, max_pixAcc_temp),
    'mAP': (max_mAP, max_mAP_temp),
}

max_layers_df = pd.DataFrame.from_dict(max_layers, orient='index', columns=['Max Score', 'Layer'])
max_layers_df
# %%
