import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import argparse
sys.path.append("../src")

plt.style.use('../plot_params.dms')
# Conversion factors from physical lengths to inches (the unit matplotlib's figsize expects).
units_convert = {'cm': 1 / 2.54, 'mm': 1 / 2.54 / 10}


def _positive_int(value):
    """Parse a command-line value as a strictly positive integer.

    Used as an argparse ``type=`` hook so that ``--n_seeds 0`` (or a negative
    value) is rejected with a clear message. Otherwise such a value would yield
    an empty seed list and a confusing indexing error much later.

    Raises:
        argparse.ArgumentTypeError: if the parsed integer is < 1.
        ValueError: if ``value`` is not an integer literal (argparse reports it).
    """
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
    return number


parser = argparse.ArgumentParser()
parser.add_argument("--n_seeds", type=_positive_int, help="number of seeds", default=2)
args = parser.parse_args()

# Per-method, per-gain loss curves. Each entry is replaced below by an array of
# -log10(loss) values with one row per seed (the "log training performance").
losses = {"Fully trained": {0.9: [], 1: []},
          "Bias learning": {0.9: [], 1: []}}

# Folder for results. exist_ok avoids the check-then-create race of os.path.exists + makedirs.
result_folder = "../results/vanderpol"
os.makedirs(result_folder, exist_ok=True)

# Load data
seeds = np.arange(1, args.n_seeds + 1)
# One run-directory template per training regime; only the gain and the seed vary.
_run_templates = {
    "Fully trained": "../data/vanderpol/hiddensize25-lr0.0001-biaslearningFalse-biasinituniform-epochs30000-gain{gain}-gainout1.0-seed{seed}",
    "Bias learning": "../data/vanderpol/hiddensize675-lr0.1-biaslearningTrue-biasinituniform-epochs5000-gain{gain}-gainout1.0-seed{seed}",
}
# float(gain) renders 0.9 -> "0.9" and 1 -> "1.0", matching the on-disk folder names.
data_folders = {type_: {gain: [template.format(gain=float(gain), seed=seed) for seed in seeds]
                        for gain in losses[type_]}
                for type_, template in _run_templates.items()}

for type_, per_gain in losses.items():
    for gain in per_gain:
        curves = [np.load(os.path.join(folder, 'loss.npy')) for folder in data_folders[type_][gain]]
        # Negated log so that lower loss plots as a taller bar.
        per_gain[gain] = -np.log10(np.array(curves))

def _final_performance(curves):
    """Return ``(mean, sem)`` across seeds of the last-epoch value of ``curves``.

    ``curves`` is indexed as ``curves[seed, epoch]`` (one row per seed).
    The SEM is the sample standard deviation (ddof=1) divided by sqrt(n).
    It is NaN when only one seed is available, because the sample std is
    undefined; returning NaN explicitly avoids numpy's ddof RuntimeWarning.
    """
    final = curves[:, -1]
    n_runs = len(final)
    mean = np.mean(final)
    sem = np.std(final, ddof=1) / np.sqrt(n_runs) if n_runs > 1 else np.nan
    return mean, sem


fig, axes = plt.subplots(ncols=2, figsize=(60*units_convert['mm'], 45/1.25*units_convert['mm']), sharey=True, sharex=True)

# One panel per training regime: (key into `losses`, bar color(s)).
panels = [("Fully trained", (0, 0, 225/255)),
          ("Bias learning", [(230/255, 97/255, 0), 'orange'])]
for ax, (type_, color) in zip(axes, panels):
    gains = list(losses[type_])
    stats = [_final_performance(losses[type_][gain]) for gain in gains]
    # Bar height is the plain mean across seeds. The original divided it by
    # sqrt(n_seeds) by mistake; only the error bar (SEM) carries that factor.
    height = [mean for mean, _ in stats]
    yerr = [sem for _, sem in stats]
    ax.bar(gains, height, width=0.08, yerr=yerr, color=color)
    ax.set_xlabel("Gain recurrent init.")
    ax.set_title(type_)
axes[0].set_ylabel("Log training performance")
axes[0].set_xticks(list(losses["Fully trained"]))
plt.tight_layout()
plt.savefig(os.path.join(result_folder, "Performance_vs_gain_PanelD.png"))