#%%
import torch
import pandas as pd
import numpy as np
import lovely_tensors as lt
lt.monkey_patch()
from pathlib import Path
# Anchor for every relative path used below (results, figures): the parent of
# the directory containing this script — presumably the project root; confirm
# if this file is moved to a different depth.
base_dir = Path(__file__).parent.parent

# %%

def remove_prefix(text: str, prefix: str) -> str:
    """Strip *prefix* from the start of *text*.

    Returns *text* unchanged when it does not begin with *prefix*.
    """
    return text[len(prefix):] if text.startswith(prefix) else text

def dir2metrics(path: Path):
    """Parse a ``result_mIoU*.txt`` file found below *path* into a metrics dict.

    The first three lines of the file are read; the last whitespace-separated
    token of each is, in order, the mIoU, the pixel accuracy (a percentage,
    optionally suffixed with ``%``) and the mAP.

    Args:
        path: Experiment directory, searched recursively.

    Returns:
        ``{"mIoU": float, "pixAcc": float, "mAP": float}`` with ``pixAcc``
        divided by 100, or ``None`` (after printing a message) when no
        matching file exists below *path*.
    """
    # rglob yields entries in filesystem order; taking the smallest path makes
    # the choice deterministic when several result files match.
    res_path = min(path.rglob("result_mIoU*.txt"), default=None)
    if res_path is None:
        print(f"File not found in {path}/**/result_mIoU*.txt")
        return None
    lines = res_path.read_text(encoding="utf-8").splitlines()
    # split() without an argument ignores trailing whitespace; split(" ")
    # would yield an empty last token and make float() fail.
    miou = float(lines[0].split()[-1])
    # Remove only a real trailing "%" instead of blindly dropping the last
    # character, which would truncate a digit if the sign is absent.
    pix_acc = float(lines[1].split()[-1].rstrip("%")) / 100
    m_ap = float(lines[2].split()[-1])
    return {"mIoU": miou, "pixAcc": pix_acc, "mAP": m_ap}


# Gather metrics for every experiment directory named `prefix` + <layer>.
prefix = "predmap15_softmax_classes_batched_layer_vgg_layer"
results_base_dir = base_dir / "run/imagenet/"
experiments = (
    child
    for child in results_base_dir.iterdir()
    if child.name.startswith(prefix)
)
from tqdm import tqdm

records = []  # one metrics dict per experiment that has a result file
index = []    # layer number (parsed as float) matching each entry of `records`
for exp_path in tqdm([*experiments]):
    # Parse the layer before reading metrics so a malformed name fails fast.
    layer = float(remove_prefix(exp_path.name, prefix))
    metrics = dir2metrics(exp_path)
    if metrics is not None:
        index.append(layer)
        records.append(metrics)
    else:
        print(f"Skipping {exp_path}")

def get_pure_predmap_df():
    """Return the per-layer prediction-map scores stored as literals.

    The numbers are hard-coded here rather than read from result files.

    Returns:
        DataFrame with a RangeIndex and columns ``l`` (integer layer index
        0-11), ``pixAcc``, ``mAP`` and ``mIoU`` (fractions in [0, 1]).
    """
    # (pixAcc, mAP, mIoU), one tuple per layer in layer order.
    scores_per_layer = [
        (0.5309, 0.6818, 0.3325),
        (0.5158, 0.6689, 0.3178),
        (0.5262, 0.6775, 0.3269),
        (0.5476, 0.6958, 0.3473),
        (0.5857, 0.7240, 0.3837),
        (0.6296, 0.7588, 0.4279),
        (0.6706, 0.7878, 0.4707),
        (0.7016, 0.8081, 0.5042),
        (0.7314, 0.8266, 0.5382),
        (0.7466, 0.8414, 0.5586),
        (0.7129, 0.8181, 0.5288),
        (0.5353, 0.5872, 0.3601),
    ]
    rows = [(layer_idx, *scores) for layer_idx, scores in enumerate(scores_per_layer)]
    return pd.DataFrame.from_records(rows, columns=["l", "pixAcc", "mAP", "mIoU"])
# Score table built from the hard-coded per-layer numbers.
df = get_pure_predmap_df()
df
# %%

# To use the metrics parsed from the result files instead:
# df = pd.DataFrame(records, index=index)
df = df.sort_index().reset_index(names="layer")
metric_coulmn_names = ["pixAcc", "mAP", "mIoU"]
# Scale the metric columns by 100 (fractions -> percent).
df[metric_coulmn_names] = df[metric_coulmn_names].mul(100)
# Integer layer numbers, shifted by one so the first layer is 1 rather than 0.
df["layer"] = df["layer"].astype(int) + 1
print(f"{prefix}")
df[["layer", *metric_coulmn_names]]
# %%

from matplotlib import pyplot as plt
# All text goes through LaTeX (amsmath loaded) with a serif font setup.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "cmu-serif",
    "mathtext.fontset": "cm",
    "font.size": 20,
    "text.latex.preamble": r"\usepackage{amsmath}",
})
# Colour-blind-friendly default colour cycle.
plt.style.use('tableau-colorblind10')

s = 20  # base marker size; the best-layer rings are drawn at 5x this
fig, ax1 = plt.subplots(figsize=(6,5))

# (metric column, marker, colour) per series. The list order fixes both the
# drawing order and the order of legend entries.
metric_styles = [
    ("mIoU", "1", "r"),
    ("pixAcc", "s", "g"),
    ("mAP", "*", "b"),
]
# One scatter series per metric against the layer number.
for metric, marker, color in metric_styles:
    ax1 = df.plot.scatter(x='layer', y=metric, label=metric, marker=marker, color=color, ax=ax1, s=s)
# Ring the (first) best-scoring layer of each metric with a hollow circle in
# the series colour; no label is passed for these rings.
for metric, marker, color in metric_styles:
    best_row = df.loc[[df[metric].idxmax()]]
    best_row.plot.scatter(x='layer', y=metric, marker="o", s=s*5, color='none', facecolors='none', edgecolors=color, ax=ax1)

ax1.set_ylabel("Score")
ax1.set_xlabel("Layer")
ax1.set_xticks(range(2, 12 + 1, 2))  # even layer numbers 2..12
ax1.set_title(r"Prediction map")
ax1.legend(fontsize=14)



output_dir: Path = base_dir / "figures/artifacts/supp/segmentation_predmap_per_layer"
output_dir.mkdir(exist_ok=True, parents=True)

fig = ax1.figure
fig.tight_layout()
# Named after the output directory. The previous "ssupp_egmentation_..." name
# was a transposed-"s" typo; update any consumer of the old file name.
path = output_dir / "supp_segmentation_predmap_per_layer.svg"
fig.savefig(path, bbox_inches='tight', pad_inches=0.05)

# Convert the saved SVG into a PDF next to it via svglib + reportlab.
from svglib.svglib import svg2rlg
from reportlab.graphics import renderPDF
drawing = svg2rlg(str(path))
if drawing is None:
    # svg2rlg reports load failures by returning None instead of raising;
    # fail here with a clear message rather than deep inside reportlab.
    raise RuntimeError(f"svglib could not load {path}")
renderPDF.drawToFile(drawing, str(path.with_suffix(".pdf")))

# %%
