# matplotlib stuff
import numpy as np

import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
import argparse

import os

# Global plot styling: no padding between the data and the axes edges, and
# large text for every figure produced below.
plt.rcParams.update({"axes.xmargin": 0, "axes.ymargin": 0, "font.size": 50})

env_name = "AdaptChain"

# Learning rates to plot; the i-th learning rate is drawn in the i-th colour.
lrs = [0.1, 0.01, 0.001, 0.0001]
colours = ["r", "g", "b", "orange", "purple"]

# Command-line options. --runs has no default, so `runs` is None when omitted.
parser = argparse.ArgumentParser()
parser.add_argument("--runs", type=int, help="number of runs")
parsed_args = parser.parse_args()
runs = parsed_args.runs
print(f"runs = {runs}")

# Where the saved gradient-check arrays are read from, and where figures go
# (the figure directory is created if it does not exist yet).
data_dir = "data/grad_check/"
fig_dir = f"figs/grad_check/{env_name}/"
os.makedirs(fig_dir, exist_ok=True)

for plot_type in ["cossim", "truepgcossim"]:
    # Each data file holds a 2-D array indexed as cos_sim[run, timestep]
    # (matching the per-run indexing below). Load each learning rate's array
    # once and reuse it for both figure families.
    cos_sims = {
        lr: np.load(f"{data_dir}{env_name}_lr={lr}_{plot_type}.npy")
        for lr in lrs
    }

    # Combined figure: the mean curve for every learning rate on one set of
    # axes. Only the autoscaled lower y-limit (None) is produced; add values
    # to this list to also produce close-up versions.
    for bottom_limit in [None]:
        plt.figure(figsize=(20, 15))
        for lr, colour in zip(lrs, colours):
            cos_sim = cos_sims[lr]
            if plot_type == "truepgcossim":
                # Horizontal reference line at the mean of the REINFORCE
                # results for the same learning rate.
                reinforce_cos_sim = np.load(
                    f"{data_dir}{env_name}_lr={lr}_reinforcecossim.npy"
                )
                # axhline's xmin/xmax are axes fractions in [0, 1], not data
                # coordinates, so the defaults already span the full width.
                plt.axhline(
                    y=reinforce_cos_sim.mean(),
                    color=colour,
                    linestyle="dotted",
                    linewidth=7,
                )
            # Per-timestep mean over runs, ignoring NaN entries. The label is
            # only visible if a legend is drawn; this script draws none.
            mean = np.nanmean(cos_sim, axis=0)
            plt.plot(mean, color=colour, label=f"lr = {lr}", linewidth=2)
            # Per-timestep standard error, counting only the non-NaN runs.
            num_non_nans = np.count_nonzero(~np.isnan(cos_sim), axis=0)
            std = np.nanstd(cos_sim, axis=0)
            max_std_err = np.nanmax(std / np.sqrt(num_non_nans))
            print(f"{plot_type}, lr = {lr}: max standard error = {max_std_err}")
        plt.xlabel("Timesteps")
        plt.ylabel("Cosine Similarity")
        plt.ylim(bottom=bottom_limit)
        plt.title(f"{env_name}")
        plt.savefig(f"{fig_dir}{env_name}_bot={bottom_limit}_{plot_type}.png")
        plt.close()

    # One figure per learning rate, showing every individual run.
    for bottom_limit in [None]:
        for lr, colour in zip(lrs, colours):
            cos_sim = cos_sims[lr]
            # --runs may be omitted (None) or exceed the number of runs stored
            # in the file; plot every available run in those cases instead of
            # raising TypeError / IndexError.
            available_runs = cos_sim.shape[0]
            num_runs = available_runs if runs is None else min(runs, available_runs)
            plt.figure(figsize=(20, 15))
            for run in range(num_runs):
                plt.plot(cos_sim[run, :], color=colour, alpha=0.3, linewidth=2)
            plt.xlabel("Timesteps")
            plt.ylabel("Cosine Similarity")
            plt.ylim(bottom=bottom_limit)
            plt.title(f"{env_name}, lr = {lr}")
            plt.savefig(f"{fig_dir}{env_name}_bot={bottom_limit}_lr={lr}_{plot_type}.png")
            plt.close()