import glob
import seaborn as sns
import colorcet as cc
import pandas as pd
import os
from os import path
from tqdm import tqdm
import matplotlib.lines as mlines

from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
import sys
import numpy as np
from matplotlib.ticker import FormatStrFormatter

import torch
import random
import matplotlib as mpl
from pathlib import Path
import matplotlib.patches as mpatches


# --- Global plotting / reproducibility setup --------------------------------
mpl.rcParams["figure.dpi"] = 200

# A single seed is pushed into every RNG this script imports so that repeated
# runs produce the same output.
seed = 0
for _seed_rng in (
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
):
    _seed_rng(seed)
sns.set_theme()

# RANDOM_SCORE = 0.38
# y-value of the dashed horizontal reference line drawn on the upper axis
# (labelled "random policy" in the legend).
RANDOM_SCORE = 1.78


# One dict per discovered run; only filled in "collect" mode.
all_rlts = []

# Operating mode taken from the first CLI argument:
#   "collect"     -> walk the current directory, parse TensorBoard event files,
#                    write ./data.csv and exit;
#   anything else -> load the pre-built CSV and fall through to plotting.
# When no argument is given the script plots instead of crashing with an
# IndexError on sys.argv[1].
kind = sys.argv[1] if len(sys.argv) > 1 else "plot"

if kind == "collect":
    # Imported lazily so that plotting does not require tbparse.
    from tbparse import SummaryReader

    # NOTE(review): 216 is a hard-coded total that only sizes the progress
    # bar -- presumably the expected number of runs; confirm.
    with tqdm(total=216) as pbar:
        for root, _subdirs, files in os.walk("./"):
            # Only directories that contain exactly one file, and whose file
            # name starts with "events", are processed as runs.
            if len(files) != 1:
                continue
            event_name = files[0]
            if not event_name.startswith("events"):
                continue
            log_dir = path.join(root, event_name)

            reader = SummaryReader(log_dir)

            # A second event file is looked up one directory above the run;
            # the "num. of params." text entry is read from it.
            parent_pattern = path.abspath(f"{path.dirname(log_dir)}/../events*")
            parent_events = glob.glob(parent_pattern)
            if not parent_events:
                raise FileNotFoundError(
                    f"no event file matching {parent_pattern!r} "
                    f"(needed for run {log_dir!r})"
                )
            reader2 = SummaryReader(parent_events[0])

            hp = reader.hparams
            sc = reader.scalars
            tx = reader2.text

            # One column per hyper-parameter tag; `.item()` requires exactly
            # one matching row per tag and raises otherwise.
            records = {tag: hp[hp.tag == tag].value.item() for tag in hp.tag}
            records["best_scores"] = sc[sc.tag == "best_grades"].value.item()
            # The text value's last character is dropped before the float
            # conversion (presumably a unit suffix such as "M" -- confirm).
            records["num_params"] = float(
                tx[tx.tag == "num. of params."].value.item()[:-1]
            )
            all_rlts.append(records)

            pbar.update(1)

    ds = pd.DataFrame(all_rlts)
    ds.to_csv("./data.csv")
    # sys.exit is used rather than the site-provided `exit` helper, which is
    # not guaranteed to exist (e.g. under `python -S`).
    sys.exit(0)
else:
    ds = pd.read_csv("./exp_rlt/scaling_law.csv")


# Mean best score for every (shrink, num_params) combination.
_score_columns = ds.filter(items=["num_params", "shrink", "best_scores"])
all_points = _score_columns.groupby(by=["shrink", "num_params"]).mean()

# Collapse the per-combination means down to one value per shrink level:
# the minimum ("best") and the mean across model sizes.
_by_shrink = all_points.filter(items=["shrink", "best_scores"]).groupby(
    by=["shrink"]
)
best_point = _by_shrink.min()
mean_point = _by_shrink.mean()

# Two vertically stacked axes sharing x: a short strip on top (ax2) and the
# main panel below (ax), with a 1:6 height ratio.
fig, _stacked_axes = plt.subplots(
    nrows=2,
    ncols=1,
    sharex=True,
    figsize=(6.8, 4.8),
    gridspec_kw={"height_ratios": (1, 6)},
)
ax2, ax = _stacked_axes


# Diverging blue -> grey -> red colour ramp, one entry per hue level.
_blue_ramp = ["#1984c5", "#22a7f0", "#63bff0", "#a7d5ed"]
_red_ramp = ["#e1a692", "#de6e56", "#e14b31", "#c23728"]
palette = _blue_ramp + ["#e2e2e2"] + _red_ramp

# Scatter of the per-(shrink, num_params) mean scores on the main axis,
# coloured by parameter count, with a logarithmic x-axis.
fig1 = sns.scatterplot(
    x="shrink",
    y="best_scores",
    hue="num_params",
    data=all_points,
    palette=palette,
    ax=ax,
)
fig1.set_xscale("log")

# Axis titles and x-range, applied through pyplot to the current axes.
plt.xlabel("Shrink Level")
plt.ylabel("Average Precision (m)")
ax.yaxis.set_label_coords(-.08, .62)
plt.xlim((0.8, 6000))

# Proxy artists for the "line" legend; nothing is drawn by them directly.
# Their own labels are superseded by the `labels=` list passed below.
pink_line = mlines.Line2D([], [], label="best", color="pink")
grey_line = mlines.Line2D([], [], label="average", color="#cfcfcf", marker="x")
random_line = mlines.Line2D(
    [], [], label="average", color="dimgrey", linestyle="dashed"
)
_line_legend = {
    "fitted average": grey_line,
    "fitted best": pink_line,
    "random policy": random_line,
}
art = plt.legend(
    handles=list(_line_legend.values()),
    labels=list(_line_legend.keys()),
    title="line",
    loc=(0.275, 0.845),
)
art._legend_box.align = "left"

# Re-attach the first legend as a plain artist so that the second
# `plt.legend` call below does not replace it.
ax.add_artist(art)
plt.legend(loc=(0.015, 0.383), title="num. of params.")

# Rewrite the hue legend entries, in order, as the sorted unique parameter
# counts formatted with two decimals and an "M" suffix.
_param_column = ds.filter(items=["num_params"]).values
num_params = np.unique(np.squeeze(_param_column, axis=1)).tolist()
for i, legend_text in enumerate(ax.legend_.texts):
    rounded = np.round(num_params[i], decimals=2)
    legend_text.set_text(f"{rounded:.2f}M")
ax.get_legend()._legend_box.align = "left"


# Regression of the best (minimum) score per shrink level: line only.
_best_by_shrink = best_point.to_dict()["best_scores"]
shrink_levels = list(_best_by_shrink.keys())
best_scores = list(_best_by_shrink.values())
# NOTE(review): regplot is called without `logx`, so the fit uses the raw x
# values even though the axis is log-scaled -- confirm this is intended.
fig2 = sns.regplot(
    x=shrink_levels,
    y=best_scores,
    scatter=False,
    color="pink",
    ax=ax,
)

# Regression of the mean score per shrink level, with "x" markers; it reuses
# the shrink levels taken from `best_point`.
_mean_by_shrink = mean_point.to_dict()["best_scores"]
mean_scores = list(_mean_by_shrink.values())
fig3 = sns.regplot(
    x=shrink_levels,
    y=mean_scores,
    scatter=True,
    marker="x",
    color="#cfcfcf",
    ax=ax,
)
fig3.set_xscale("log")

# Dashed horizontal reference line at RANDOM_SCORE on the upper axis.
ax2.axhline(RANDOM_SCORE, linestyle="dashed", color="dimgrey")
# Hide the two spines that face each other so the stacked axes visually join
# (presumably to read as one plot with a broken y-axis -- confirm).
ax.spines["top"].set_visible(False)
ax2.spines["bottom"].set_visible(False)
ax.tick_params(labeltop=False)  # don't put tick labels at the top

# Drop the bottom tick marks (major and minor) on the main axis; the tick
# labels set below are unaffected.
ax.tick_params(axis="x", which="both", bottom=False)

# One x tick per distinct integer shrink level present in the data.
ticks = np.unique([int(level) for level in ds["shrink"]])
ax.set_xticks(ticks)
ax.set_xticklabels([str(tick) for tick in ticks])

# Save BEFORE showing: with an interactive backend plt.show() blocks, and any
# zoom/pan/resize done in the window would otherwise end up in the saved file
# (and nothing would be saved if the window session is killed).
Path("./imgs").mkdir(exist_ok=True, parents=True)
fig.savefig("./imgs/scaling_law.png", bbox_inches="tight", pad_inches=0.05)

plt.show()

# Close this specific figure rather than whichever one happens to be current.
plt.close(fig)
