import glob
import seaborn as sns
import colorcet as cc
import pandas as pd
import os
from os import path
from tqdm import tqdm
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
import sys
import numpy as np
import torch
import random
import matplotlib as mpl
from pathlib import Path
import pickle

# Rendering defaults: high-resolution figures with the seaborn default theme.
# The rcParams assignment stays ahead of `sns.set_theme()`, as in the original.
mpl.rcParams["figure.dpi"] = 200
sns.set_theme()

# Seed every random number generator this script imports with the same value
# (stdlib `random`, NumPy, and PyTorch on CPU and all CUDA devices).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# Accumulator for the "collect" mode: one dict per (run, evaluation index).
all_rlts = []

# First CLI argument selects the mode: "collect" gathers TensorBoard logs into
# a CSV; any other value goes straight to plotting.
kind = sys.argv[1]

# Number of values read from each evaluation scalar tag per run.
CNT = 101

sns.set_theme()

# Hard-coded reference scores (mean over three values each). They are drawn as
# the "Random" bar for the corresponding metric in the plot below.
no_mask_score = np.mean(
    [
        0.5961807469526926,
        0.5844251612822214,
        0.577358235915502,
    ]
)
full_mask_score = np.mean(
    [
        0.5729121511632745,
        0.5729121511632745,
        0.5593602359294891,
    ]
)

if kind == "collect":
    # Expected number of runs; only used to size the progress bar.
    total = int(sys.argv[2])
    # Imported lazily so that plotting mode does not require tbparse.
    from tbparse import SummaryReader

    # A tuple (not a set) so the CSV column order is identical on every run;
    # set iteration order for strings varies with hash randomisation.
    scalars = (
        "eval/section/full-mask",
        "eval/section/half-mask",
        "eval/section/no-mask",
    )

    with tqdm(total=total) as pbar:
        for r, _, fs in os.walk("./"):
            # Only directories holding exactly one file, an "events*" file,
            # are treated as a run directory.
            if len(fs) != 1:
                continue
            f = fs[0]
            if not f.startswith("events"):
                continue
            log_dir = path.join(r, f)

            reader = SummaryReader(log_dir)

            # The evaluation scalars are read from an events file located in
            # the parent directory of the run directory.
            parent_pattern = path.abspath(f"{path.dirname(log_dir)}/../events*")
            # Sorted so the chosen file does not depend on directory listing
            # order when several events files match.
            parent_events = sorted(glob.glob(parent_pattern))
            if not parent_events:
                raise FileNotFoundError(
                    f"no events file matching {parent_pattern!r} for run {log_dir!r}"
                )
            reader2 = SummaryReader(parent_events[0])

            hp = reader.hparams
            sc = reader.scalars
            sc2 = reader2.scalars

            # Per-run values are identical for every evaluation index, so they
            # are extracted once instead of once per index.
            # `.item()` raises ValueError unless exactly one row matches a tag.
            base_record = {tag: hp[hp.tag == tag].value.item() for tag in hp.tag}
            base_record["best_scores"] = sc[sc.tag == "best_grades"].value.item()
            series = {k: sc2[sc2.tag == k].value.tolist() for k in scalars}

            # One output row per evaluation index; raises IndexError if a tag
            # has fewer than CNT logged values.
            for i in range(CNT):
                record = dict(base_record)
                for k in scalars:
                    record[k] = series[k][i]
                all_rlts.append(record)

            pbar.update(1)

    ds = pd.DataFrame(all_rlts)
    # NOTE(review): this writes ./data.csv while plotting mode reads
    # ./exp_rlt/mask.csv -- presumably the file is moved by hand; confirm.
    ds.to_csv("./data.csv")
    # sys.exit rather than the site-injected `exit` helper, which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)
else:
    ds = pd.read_csv("./exp_rlt/mask.csv")


# Second CLI argument selects which evaluation metric is plotted.
metric_part = sys.argv[2]
# Explicit validation instead of `assert`, which is stripped under `python -O`
# and would let an invalid metric name through to the column lookup below.
if metric_part not in ("full-mask", "no-mask"):
    raise ValueError(
        f"metric must be 'full-mask' or 'no-mask', got {metric_part!r}"
    )
metric = f"eval/section/{metric_part}"
# Group keys are handled as strings; `change_label` matches on "0.0" / "1.0".
ds["no_mask"] = ds["no_mask"].astype("string")
# metric = "best_scores"
# Minimum of the metric for each (no_mask, seed) combination.
all_points = ds[["no_mask", metric, "seed"]].groupby(by=["no_mask", "seed"]).min()
# Two stacked axes sharing x: `ax2` is the short top panel and `ax` the tall
# bottom panel (height ratio 1:6).
fig, (ax2, ax) = plt.subplots(
    2, 1, sharex=True, figsize=(6.8, 4.2), gridspec_kw={"height_ratios": (1, 6)}
)

# Blue-to-red colour list. Currently not passed to any plotting call.
palette = [
    "#1984c5",
    "#22a7f0",
    "#63bff0",
    "#de6e56",
    "#e14b31",
    "#c23728",
]


def change_label(no_mask: str) -> str:
    """Convert a stringified ``no_mask`` value into an x-axis tick label.

    Args:
        no_mask: The ``no_mask`` value as text, e.g. ``"0.0"`` or ``"0.15"``.

    Returns:
        ``"CL"`` for ``"0.0"``, ``"No"`` for ``"1.0"``, otherwise the value
        as a whole-number percentage such as ``"15%"``.

    Raises:
        ValueError: If ``no_mask`` cannot be parsed as a float.
    """
    if no_mask == "0.0":
        return "CL"
    if no_mask == "1.0":
        return "No"
    # round() rather than int(): binary floating point makes e.g. 0.29 * 100
    # evaluate to 28.999999999999996, which int() would truncate to 28.
    return f"{round(float(no_mask) * 100)}%"


# Flatten the grouped minima into (label, score) rows, one per (no_mask, seed)
# pair, then append the hard-coded reference score as a "Random" bar.
scores_by_group = all_points[metric].to_dict()
rows = [
    (change_label(no_mask), score)
    for (no_mask, _seed), score in scores_by_group.items()
]
rows.append(
    ("Random", full_mask_score if metric_part == "full-mask" else no_mask_score)
)

ds2 = pd.DataFrame(rows, columns=["no_mask", metric])

# The same bars are drawn on both panels; the differing y-limits below turn
# the pair into a broken y-axis.
bar_kwargs = dict(data=ds2, x="no_mask", y=metric)
fig1 = sns.barplot(ax=ax, **bar_kwargs)
fig2 = sns.barplot(ax=ax2, **bar_kwargs)

# Top panel shows 0.55-0.6, where the "Random" reference score lies.
ax2.set_ylim((0.55, 0.6))
ax2.set_yticks([0.55, 0.6])
ax2.spines["bottom"].set_visible(False)
ax2.set_xlabel("")
ax2.set_ylabel("")

# Bottom panel shows the low range occupied by the remaining bars.
ax.tick_params(labeltop=False)  # don't put tick labels at the top
ax.set_ylim((0, 0.25 if metric_part == "full-mask" else 0.15))
ax.spines["top"].set_visible(False)

# Set the labels on `ax` directly instead of via pyplot's current axes.
# `ax` is the last axes created by plt.subplots, so the target is unchanged.
ax.set_ylabel("Precision (radian)")
ax.set_xlabel("Mask")
# Shift the y label so it sits roughly centred across both panels.
ax.yaxis.set_label_coords(-0.09, 0.62)

# Save before showing: plt.show() blocks until the window is closed, and on
# GUI backends the canvas may already be torn down by the time it returns.
Path("./imgs").mkdir(exist_ok=True, parents=True)
fig.savefig(f"./imgs/mask_{metric_part}.png", bbox_inches="tight", pad_inches=0.05)
plt.show()

plt.close()
