#%%
import torch
import pandas as pd
import numpy as np
import lovely_tensors as lt
lt.monkey_patch()
from pathlib import Path
# Project base directory (presumably the repository root): two levels above this
# script. resolve() makes the path absolute first, so `.parent.parent` cannot
# collapse to "." when __file__ is a bare relative name (Path("x.py").parent.parent == Path(".")).
base_dir = Path(__file__).resolve().parent.parent

# %%

def remove_prefix(text: str, prefix: str):
    """Return *text* with a leading *prefix* removed; return *text* unchanged if it lacks it."""
    return text[len(prefix):] if text.startswith(prefix) else text

def dir2metrics(path: Path):
    """Parse a ``result_mIoU*.txt`` file found under *path* into a metrics dict.

    The result file is expected to have (at least) three lines whose last
    whitespace-separated token is, in order: mIoU, pixel accuracy as a
    percentage (a trailing ``%`` is optional), and mAP.

    Args:
        path: Experiment directory, searched recursively.

    Returns:
        ``{"mIoU": float, "pixAcc": float, "mAP": float}`` with pixAcc
        converted from percent to a fraction, or ``None`` (after printing a
        message) when no result file exists under *path*.

    Raises:
        ValueError / IndexError: if the file does not match the format above.
    """
    # Take the lexicographically smallest match so the choice is deterministic
    # when several result files exist (rglob order is filesystem-dependent).
    res_path = min(path.rglob("result_mIoU*.txt"), default=None)
    if res_path is None:
        print(f"File not found in {path}/**/result_mIoU*.txt")
        return None
    lines = res_path.read_text(encoding="utf-8").splitlines()
    # split() (no argument) ignores trailing whitespace; split(" ") would
    # yield an empty last token and make float() fail.
    miou = float(lines[0].split()[-1])
    # Strip only an actual '%' sign instead of blindly dropping the last char.
    pix_acc = float(lines[1].split()[-1].rstrip("%")) / 100
    mean_ap = float(lines[2].split()[-1])
    return {"mIoU": miou, "pixAcc": pix_acc, "mAP": mean_ap}


prefix="predmap15_softmax_classes_batched_layer_vgg_layer"
results_base_dir = base_dir / "run/imagenet/"
# Experiment directories are named `prefix` followed by the layer number.
# Only directories are considered: a stray file sharing the prefix has no results.
experiments = (
    x for x in results_base_dir.iterdir() if x.is_dir() and x.name.startswith(prefix)
)
records = []
index = []
from tqdm import tqdm
# sorted() gives a deterministic processing order (iterdir order is arbitrary).
for exp_path in tqdm(sorted(experiments)):
    layer_suffix = remove_prefix(exp_path.name, prefix)
    try:
        layer = float(layer_suffix)
    except ValueError:
        # e.g. a "..._layer5_old" directory: skip it instead of aborting the whole run.
        print(f"Skipping {exp_path}: cannot parse layer from {layer_suffix!r}")
        continue
    metrics = dir2metrics(exp_path)
    if metrics is None:
        print(f"Skipping {exp_path}")
        continue
    records.append(metrics)
    index.append(layer)


# One row per experiment, ordered by layer, with the layer moved into a "layer" column.
df = (
    pd.DataFrame(records, index=index)
    .sort_index()
    .reset_index(names="layer")
)
metric_coulmn_names = ["pixAcc", "mAP", "mIoU"]  # name kept as-is; may be referenced elsewhere
# Express every metric as a percentage.
df[metric_coulmn_names] = df[metric_coulmn_names] * 100
# Layer values were parsed as floats: truncate to int, then shift by one
# (presumably to 1-based layer numbering for display -- confirm).
df['layer'] = df['layer'].astype(int) + 1
print(f"{prefix}")
df[["layer"] + metric_coulmn_names]  # shown as cell output in interactive runs
# %%

from matplotlib import pyplot as plt
# Global figure styling. With text.usetex enabled, all text is typeset by LaTeX
# and matplotlib only accepts a generic family here ("serif", "sans-serif",
# "cursive", "monospace"); other names (previously "cmu-serif") fall back to
# "serif" with a log message, so "serif" states the effective setting directly.
plt.rcParams.update({
    "text.usetex": True,                # typeset all text with LaTeX
    "font.family": "serif",             # generic family; LaTeX chooses the concrete font
    "mathtext.fontset": "cm",           # only relevant if text.usetex is switched off
    "font.size": 30,                    # default size for text without an explicit fontsize
    "text.latex.preamble": r"\usepackage{amsmath}",  # provides \text{...} for math labels
})
plt.style.use('tableau-colorblind10')  # colorblind-friendly color palette

s = 20  # base marker size; the best-layer highlight rings use s*5
fig, ax1 = plt.subplots(figsize=(6,5))

# (column, marker, color) for each metric series; the column name doubles as legend label.
metric_series_styles = [
    ("mIoU", "1", "r"),
    ("pixAcc", "s", "g"),
    ("mAP", "*", "b"),
]

# One scatter series per metric, all drawn on the same axes.
for series_col, series_marker, series_color in metric_series_styles:
    ax1 = df.plot.scatter(
        x='layer', y=series_col, label=series_col,
        marker=series_marker, color=series_color, ax=ax1, s=s,
    )

# Ring the first row attaining each metric's maximum: no face color,
# edge in the series color.
for series_col, _, series_color in metric_series_styles:
    best_row = df.loc[[df[series_col].idxmax()]]
    best_row.plot.scatter(
        x='layer', y=series_col, marker="o", s=s*5,
        color='none', facecolors='none', edgecolors=series_color, ax=ax1,
    )

ax1.set_ylabel("Score", fontsize=30)
ax1.set_xlabel("Layer $i$", fontsize=30)
ax1.set_xticks(range(2, 12+1,2))
ax1.set_ylim(25,90)
ax1.set_title(r"$\text{PredicAtt}_i$", fontsize=40)
ax1.legend(fontsize=24, borderpad=0.1, labelspacing=0.1, handletextpad=0,borderaxespad=0.1)



# Write the figure as SVG, then render that SVG to a PDF with the same stem.
output_dir: Path = base_dir / "figures/artifacts/segmentation_per_layer"
output_dir.mkdir(parents=True, exist_ok=True)

fig = ax1.figure
fig.tight_layout()
path = output_dir / "segmentation_per_layer.svg"
fig.savefig(path, bbox_inches='tight', pad_inches=0.05)

from svglib.svglib import svg2rlg
from reportlab.graphics import renderPDF

pdf_target = str(path.with_suffix(".pdf"))
drawing = svg2rlg(path)
renderPDF.drawToFile(drawing, pdf_target)

# %%



# Best score of each metric, and the layer of the first row that reaches it.
max_mIoU = df['mIoU'].max()
max_pixAcc = df['pixAcc'].max()
max_mAP = df['mAP'].max()

# The "_temp" suffix is historical: these hold the layer of the best row.
max_mIoU_temp = df.loc[df['mIoU'] == max_mIoU, 'layer'].iloc[0]
max_pixAcc_temp = df.loc[df['pixAcc'] == max_pixAcc, 'layer'].iloc[0]
max_mAP_temp = df.loc[df['mAP'] == max_mAP, 'layer'].iloc[0]

max_layers = dict(
    mIoU=(max_mIoU, max_mIoU_temp),
    pixAcc=(max_pixAcc, max_pixAcc_temp),
    mAP=(max_mAP, max_mAP_temp),
)

# Rows are metric names; columns are the best score and its layer.
max_layers_df = pd.DataFrame.from_dict(max_layers, orient='index', columns=['Max Score', 'Layer'])
max_layers_df
# %%
# NOTE(review): pandas' DataFrame.plot creates its own figure when `ax` is not
# passed (verify for the installed pandas version), in which case this
# plt.figure() leaves an extra empty figure. Consider `fig, ax = plt.subplots()`
# and passing `ax=ax`.
plt.figure()
# mIoU per layer, drawn with no face color and a red edge.
ax1 = df.plot.scatter(x='layer', y='mIoU', s=80, color='none', facecolors='none', edgecolors='r')