import glob
import seaborn as sns
import colorcet as cc
import pandas as pd
import os
from os import path
from tqdm import tqdm
import matplotlib.lines as mlines

from args import args
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
import sys
import numpy as np
from matplotlib.ticker import FormatStrFormatter

import torch
import random
from tqdm import tqdm
import matplotlib as mpl
from pathlib import Path
import matplotlib.patches as mpatches
import shutil

# Global figure defaults: render every figure at 200 dpi.
mpl.rcParams.update({"figure.dpi": 200})

# Seed Python's `random`, NumPy's global RNG, and torch (CPU and all CUDA
# devices) with the same value so reruns are reproducible.
seed = 0
for _seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    _seeder(seed)
del _seeder

# Apply seaborn's default theme globally.
sns.set_theme()


# NOTE(review): RANDOM_SCORE and all_rlts are not referenced anywhere in this file.
RANDOM_SCORE = 0.38


all_rlts = []

task = args.task
eval_name = args.eval_name

# Evaluation artifacts for this run are read from eval_video/<task>/<eval_name>.
path_dir = f"./eval_video/{task}/{eval_name}"

# Load everything onto the CPU, whatever device the tensors were saved from.
# Without this, GPU-saved tensors cannot be loaded on a CPU-only machine. On a
# CUDA machine, `mae[...] = 0.0` below would index a CPU tensor with CUDA
# indices derived from `masks`, which PyTorch rejects.
achieved_pos = torch.load(os.path.join(path_dir, "./aps.pt"), map_location="cpu")
command_pos = torch.load(os.path.join(path_dir, "./cps.pt"), map_location="cpu")
masks = torch.load(os.path.join(path_dir, "./masks.pt"), map_location="cpu")

# Per-step mean absolute difference between achieved and commanded positions.
# The first entry of achieved_pos along dim 0 is skipped (presumably the
# initial state — confirm). The mean is taken over the remaining last axis.
mae = (achieved_pos[1:, 0, :] - command_pos[:, :]).abs().mean(dim=1).to("cpu")

if masks is not None:
    # Zero the error at masked steps, ignoring indices past the end of `mae`.
    # NOTE(review): flattening argwhere() only yields step indices when `masks`
    # is 1-D — confirm its shape.
    masked_steps = masks.argwhere().flatten()
    masked_steps = masked_steps[masked_steps < mae.size(0)]
    mae[masked_steps] = 0.0

# Prepend a zero-error sample so the curve starts at t = 0.
mae = [0] + mae.tolist()

# Time axis in seconds. This assumes 20 steps per second — TODO confirm
# against the environment's control rate.
ds2 = pd.DataFrame(
    {
        "mae": mae,
        "second": [step / 20 for step in range(len(mae))],
    }
)


# Summary figure: tracking error over the whole episode, one small marker per step.
Path(f"./imgs/task/{task}").mkdir(parents=True, exist_ok=True)

plt.figure(figsize=(4.8, 4.8))
fig = sns.lineplot(data=ds2, x="second", y="mae", marker="o", markersize=2)
# Unit follows the task: metres for pointmaze, radians otherwise.
fig.set(
    xlabel="Second",
    ylabel="Precision (m)" if task == "pointmaze" else "Precision (radian)",
)
fig.ticklabel_format(axis="y", scilimits=(-2, 3))
fig.figure.savefig(
    f"./imgs/task/{task}/{eval_name}_mae.png",
    bbox_inches="tight",
    pad_inches=0.05,
)
plt.close()

# Render one image per step: the full tracking-error curve plus a red cursor
# at the current time.
# Named `tracking_dir` so that `path` keeps referring to os.path (imported at
# the top of the file) instead of being rebound to a pathlib.Path.
tracking_dir = Path(f"eval_video/{task}/{eval_name}/tracking_imgs")
# Start from an empty directory so frames left over from a longer previous run
# do not linger.
if tracking_dir.exists():
    shutil.rmtree(tracking_dir)
tracking_dir.mkdir(exist_ok=True, parents=True)

images = []  # NOTE(review): never populated or read in this file.
for i in tqdm(range(len(mae))):
    frame_no = i + 1
    plt.figure(figsize=(3.2, 3.2))
    ax = sns.lineplot(data=ds2, x="second", y="mae", lw=0.7)
    ax.set_xlabel("Second")
    ax.set_ylabel("Precision")
    ax.ticklabel_format(axis="y", scilimits=(-2, 3))

    # Cursor at frame_no / 20 s; assumes 20 steps per second — TODO confirm.
    # NOTE(review): the data spans 0 .. (len(mae) - 1) / 20 s, so the last
    # frame's cursor lies one step past the final point. Confirm this is intended.
    ax.axvline(x=frame_no / 20, lw=0.8, alpha=0.5, color="red")
    # Explicit ".png" so the output format no longer depends on
    # rcParams["savefig.format"]; file names match the previous default output.
    ax.figure.savefig(
        tracking_dir / f"img_{frame_no}.png",
        bbox_inches="tight",
        pad_inches=0.05,
    )

    plt.close()
