import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')
import seaborn as sns
from matplotlib.lines import Line2D
import numpy as np


# Global seaborn theme; text is rendered through LaTeX because the panel
# titles below use \textbf and math mode.
sns.set(font_scale=1, rc={'text.usetex' : True})
df = pd.read_csv("noise_tests_stats.csv", index_col=0)

# Rows for each target goal probability. A plain boolean mask replaces the
# former groupby('goal').filter(...) round-trip (same rows, same order);
# np.isclose guards against float round-off in the CSV's 'goal' values.
df85 = df[np.isclose(df['goal'], 0.85)]
df90 = df[np.isclose(df['goal'], 0.9)]
print(df85)

# Base colours: the first three entries of the colorblind palette plus its
# [-4:-1] slice. The palette is built once instead of twice.
colorblind = sns.color_palette("colorblind")
base_colors = colorblind[0:3] + colorblind[-4:-1]
print(base_colors)
# Colour per noise_factor hue level. The keys must cover every noise_factor
# value in the data, because the plots look colours up as palette[level].
# NOTE(review): the non-sequential index order looks deliberate — confirm.
palette = {0: base_colors[0], 1: base_colors[3], 2: base_colors[2], 3: base_colors[1]}

def _plot_panel(ax, data, title, title_x, colors):
    """Draw one panel: a mean curve per noise factor plus its shaded band.

    ``data`` must provide the columns ``rounded_step``, ``mean_norm``,
    ``y_lower_norm``, ``y_upper_norm`` and ``noise_factor``. Every
    ``noise_factor`` value must be a key of ``colors``. ``title_x`` is the
    horizontal (axes-fraction) position of the panel title.
    """
    sns.lineplot(x='rounded_step', y='mean_norm', data=data, hue="noise_factor", legend=False,
                 palette=colors, ax=ax)
    # sort=False keeps the first-appearance order that unique() used to give.
    for level, subset in data.groupby('noise_factor', sort=False):
        # lineplot sorts its x values by default; sort the band the same way
        # so the filled polygon follows the curve even if the CSV rows are
        # not in step order.
        subset = subset.sort_values('rounded_step')
        ax.fill_between(subset['rounded_step'], subset['y_lower_norm'], subset['y_upper_norm'],
                        alpha=0.2, color=colors[level])
    ax.set_ylabel('normalized return (smoothed)')
    ax.set_xlabel('training steps (rounded to 500)')
    title_artist = ax.set_title(title, fontsize=13)
    # Move the title from the centre towards the left edge.
    title_artist.set_position((title_x, 0.99))


# Two plot panels plus a narrow third cell that later holds the legend.
fig, axes = plt.subplots(1, 3, figsize=(10, 3.5), gridspec_kw=dict(width_ratios=[2, 2, 0.5]), constrained_layout=True)

_plot_panel(axes[0], df85, r'\textbf{(a)} $\quad P_{goal}=0.85$', 0.03, palette)
_plot_panel(axes[1], df90, r'\textbf{(b)} $\quad P_{goal}=0.9$', 0.02, palette)


# The third grid cell only reserves room for the legend: read its position,
# anchor a figure-level legend there, then remove the empty axes.
orig_pos = axes[2].get_position(original=True)
# Derive labels and line handles from the palette so they cannot drift apart.
noise_levels = sorted(palette)
labels = [str(level) for level in noise_levels]
handles = [Line2D([], [], color=palette[level]) for level in noise_levels]
fig.legend(handles, labels, loc='center left', bbox_to_anchor=(orig_pos.x0 + 0.07, orig_pos.y0 + 0.45),
           title="noise factor")
# NOTE(review): presumably forces constrained_layout to settle the panel
# layout while the legend cell still exists — confirm before removing.
fig.canvas.draw()
fig.delaxes(axes[2])

# `fig` is already the Figure; the former `fig.figure` indirection was redundant.
fig.savefig("noise_plots.pdf", format="pdf", bbox_inches="tight")
plt.show()