import glob
import seaborn as sns
import colorcet as cc
import pandas as pd
import os
from os import path
from tqdm import tqdm
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
import sys
import numpy as np
import torch
import random
import matplotlib as mpl
import matplotlib.lines as mlines
from pathlib import Path

# Global figure resolution and seeding of every RNG the script may touch
# (numpy, torch CPU/CUDA, python `random`).
mpl.rcParams["figure.dpi"] = 200
seed = 0
np.random.seed(seed)
sns.set_theme()
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)


# Row dicts accumulated by the "collect" mode before being written to CSV.
all_rlts = []
# Baseline values drawn as the dashed line in the top plot panel: the mean of
# three scores each (presumably three runs).
# NOTE(review): the first two full-mask values are identical -- confirm this
# is not a copy-paste slip.
no_mask_score = np.mean([0.5961807469526926, 0.5844251612822214, 0.577358235915502])
full_mask_score = np.mean([0.5729121511632745, 0.5729121511632745, 0.5593602359294891])

if len(sys.argv) < 2:
    # Fail with a readable message instead of a bare IndexError traceback.
    sys.exit(
        f"usage: {sys.argv[0]} collect <num_event_files>\n"
        f"       {sys.argv[0]} <plot> {{full-mask|no-mask}}"
    )
# "collect" scrapes TensorBoard logs; any other value switches to plot mode.
kind = sys.argv[1]
# Number of per-step evaluation values read from each run in collect mode.
CNT = 101

if kind == "collect":
    # Collect mode: find every directory holding a single "events*" file,
    # pair it with the events file one level up, flatten the run into CNT
    # rows (one per evaluation step) and write all rows to ./data.csv.
    total = int(sys.argv[2])  # only used as the progress-bar total
    from tbparse import SummaryReader

    # A tuple rather than a set: set iteration order of str depends on hash
    # randomization, which made the CSV column order differ between runs.
    scalars = (
        "eval/section/full-mask",
        "eval/section/half-mask",
        "eval/section/no-mask",
    )

    with tqdm(total=total) as pbar:
        for root, _dirs, files in os.walk("./"):
            # Only directories that contain exactly one file, an events file.
            if len(files) != 1 or not files[0].startswith("events"):
                continue
            log_dir = path.join(root, files[0])

            reader = SummaryReader(log_dir)
            # The per-step eval scalars are read from an events file located
            # in the parent directory of this one.
            pattern = path.abspath(f"{path.dirname(log_dir)}/../events*")
            candidates = sorted(glob.glob(pattern))  # sorted: deterministic pick
            if not candidates:
                sys.exit(f"no events file matching {pattern} (needed by {log_dir})")
            reader2 = SummaryReader(candidates[0])

            hp = reader.hparams
            sc = reader.scalars
            sc2 = reader2.scalars

            # These values are identical for every step of the run, so they
            # are computed once instead of CNT times.
            base = {tag: hp[hp.tag == tag].value.item() for tag in hp.tag}
            base["best_scores"] = sc[sc.tag == "best_grades"].value.item()
            series = {k: sc2[sc2.tag == k].value.tolist() for k in scalars}
            for k, values in series.items():
                if len(values) < CNT:
                    sys.exit(
                        f"{log_dir}: tag {k!r} has {len(values)} points, "
                        f"expected at least {CNT}"
                    )

            for i in range(CNT):
                records = dict(base)  # fresh dict per row
                for k in scalars:
                    records[k] = series[k][i]
                all_rlts.append(records)

            pbar.update(1)

    ds = pd.DataFrame(all_rlts)
    # NOTE(review): plot mode reads ./exp_rlt/modality.csv, not this file;
    # presumably it is moved/renamed by hand after collection.
    ds.to_csv("./data.csv")
    sys.exit(0)
else:
    ds = pd.read_csv("./exp_rlt/modality.csv")


# Plot mode: the second CLI argument selects which eval scalar is plotted.
metric_part = sys.argv[2] if len(sys.argv) > 2 else None
if metric_part not in ("full-mask", "no-mask"):
    # Explicit exit instead of `assert`, which is stripped under `python -O`.
    sys.exit(f"second argument must be 'full-mask' or 'no-mask', got {metric_part!r}")
metric = f"eval/section/{metric_part}"

# One point per (num_centers, seed): the minimum of the metric over all rows
# of that group; reset_index turns both keys back into plain columns.
all_points = (
    ds[["num_centers", metric, "seed"]]
    .groupby(["num_centers", "seed"])
    .min()
    .reset_index()
)

# Two stacked panels sharing the x axis: a short top panel (ax2) holding the
# dashed baseline and a tall bottom panel (ax) holding the measured curve.
fig, (ax2, ax) = plt.subplots(
    2, 1, sharex=True, figsize=(6.8, 4.2), gridspec_kw={"height_ratios": (1, 6)}
)
ax.set_xscale("log")  # shared with ax2 via sharex
ax.spines["top"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
ax.tick_params(labeltop=False)  # don't put tick labels at the top

ticks = [1, 2, 5, 10, 32, 48, 64, 96]
ax.set_xticks(ticks)
ax.set_xticklabels([str(t) for t in ticks])

baseline = full_mask_score if metric_part == "full-mask" else no_mask_score
ax2.axhline(baseline, color="dimgrey", linestyle="dashed")
ax2.set_ylim((0.55, 0.6))
ax2.set_yticks((0.55, 0.6))

sns.lineplot(
    data=all_points,
    err_style="bars",
    x="num_centers",
    y=metric,
    marker="o",
    ax=ax,
)
ax.set_ylabel("Precision (radian)")
ax.set_xlabel("#. of action bins")
ax.yaxis.set_label_coords(-0.105, 0.62)

# Proxy handles for the legend: a solid line for the curve and a dashed grey
# line for the baseline.
blue_line = mlines.Line2D([], [], color="#4e72ab", label="best")
random_line = mlines.Line2D(
    [], [], color="dimgrey", label="average", linestyle="dashed"
)
# ax.legend already attaches the legend to ax; no extra add_artist is needed
# (that would draw it a second time).
art = ax.legend(
    handles=[blue_line, random_line],
    labels=["RoBERT", "random policy"],
    loc=(0.01, 0.845),
    title="line",
)
# Private attribute; matplotlib >= 3.6 also accepts legend(alignment="left").
art._legend_box.align = "left"

# Save before plt.show(): show() may block until the window is closed, and
# matplotlib recommends saving first.
Path("./imgs").mkdir(exist_ok=True, parents=True)
fig.savefig(f"./imgs/modality_{metric_part}.png", bbox_inches="tight", pad_inches=0.05)
plt.show()

plt.close()
