# numerical and plotting dependencies
import numpy as np

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import argparse

import os

# Global figure styling: no padding between the data and the axes, and a large
# font. `plt.rcParams` and `matplotlib.rcParams` are the same object, so a
# single update covers all three settings.
plt.rcParams.update({
    'axes.xmargin': 0,
    'axes.ymargin': 0,
    'font.size': 45,
})

# Environment whose results are plotted; used in input/output file names and titles.
env_name = "CartPole-v1"

# Command-line interface.
# The learning rate is deliberately not a CLI option: every value in `lrs`
# below is plotted on the same figure.
parser = argparse.ArgumentParser()
# `--runs` is required: it is fed to `range(runs)` when plotting, so leaving it
# unset (None) would only fail later with an opaque TypeError.
parser.add_argument("--runs", type=int, required=True, help="number of runs")
parser.add_argument("--hidden", type=int, help="hidden layer size")
parsed_args = parser.parse_args()
runs = parsed_args.runs
hidden = parsed_args.hidden

# Learning rates to plot, one colour each.
lrs = [0.01, 0.001, 0.0001, 0.00001]

print(f"hidden = {hidden} | runs = {runs}")

# Input arrays are read from `data_dir`; figures are written to `fig_dir`,
# which is created on demand.
data_dir = "data/grad_check/"
fig_dir = f"figs/grad_check/{env_name}/"
os.makedirs(fig_dir, exist_ok=True)

# Colour palette; the i-th learning rate in `lrs` is drawn with the i-th colour.
colours = ["r", "g", "b", "orange", "purple"]

print(f"env = {env_name} | lrs = {lrs}")

# Fail early with a clear message if there are more learning rates than
# colours, instead of hitting an IndexError halfway through plotting.
if len(lrs) > len(colours):
    raise ValueError(
        f"need one colour per learning rate: {len(lrs)} learning rates, "
        f"{len(colours)} colours"
    )
# learning rate -> colour, shared by both figures and the legend.
lr_colours = dict(zip(lrs, colours))

# Load each learning rate's array once; both figures reuse it.
# Rows are indexed by run (see `cos_sim[run, :]` below); the second axis is
# presumably the timestep, matching the x-axis label -- confirm against the
# script that writes these files.
cos_sims = {
    lr: np.load(f"{data_dir}{env_name}_lr={lr}_hidden={hidden}_cossim.npy")
    for lr in lrs
}


def _finish_figure(out_path, bottom_limit):
    """Label the current figure, save it to ``out_path`` and close it.

    ``bottom_limit`` is the lower y-axis bound; ``None`` leaves matplotlib's
    automatically chosen limit in place.
    """
    plt.xlabel("Timesteps")
    plt.ylabel("Cosine Similarity of the Traces")
    plt.ylim(bottom=bottom_limit)
    plt.title(f"{env_name}, hidden layer = {hidden}")
    plt.savefig(out_path)
    # Close explicitly so figures do not accumulate in pyplot's global state.
    plt.close()


# One pair of figures is produced per y-axis lower bound. Only the automatic
# limit (None) is used at the moment; adding a number here would also produce
# a close-up version.
for bottom_limit in [None]:
    figsize = (20, 15)

    # --- Figure 1: every individual run, one translucent line per run ---
    plt.figure(figsize=figsize)
    for lr in lrs:
        cos_sim = cos_sims[lr]
        for run in range(runs):
            plt.plot(
                cos_sim[run, :],
                color=lr_colours[lr],
                alpha=0.3,
                label=f"lr = {lr}",
                linewidth=2,
            )
    # One opaque proxy handle per plotted learning rate. Building the handles
    # from `lrs` (not from the whole palette) keeps handles and labels the
    # same length.
    plt.legend(
        [Line2D([0], [0], color=lr_colours[lr], lw=4) for lr in lrs],
        lrs,
        title="LRs",
    )
    _finish_figure(
        f"{fig_dir}{env_name}_hidden={hidden}_bot={bottom_limit}_runs={runs}.png",
        bottom_limit,
    )

    # --- Figure 2: NaN-aware mean across runs for each learning rate ---
    # No legend is drawn here; colours follow the same mapping as figure 1.
    plt.figure(figsize=figsize)
    for lr in lrs:
        cos_sim = cos_sims[lr]
        mean = np.nanmean(cos_sim, axis=0)
        std = np.nanstd(cos_sim, axis=0)
        plt.plot(mean, color=lr_colours[lr], label=f"lr = {lr}", linewidth=2)
        # Number of non-NaN entries per column, used as the sample size for
        # the standard error. The standard error is only reported on stdout,
        # not drawn as a band.
        num_non_nans = np.count_nonzero(~np.isnan(cos_sim), axis=0)
        print(num_non_nans)
        print(f"max standard error = {np.nanmax(std / np.sqrt(num_non_nans))}")
    _finish_figure(
        f"{fig_dir}{env_name}_means_hidden={hidden}_bot={bottom_limit}_runs={runs}.png",
        bottom_limit,
    )