import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import numpy as np
# Select the Tk-based "TkAgg" backend; the figure window opened by plt.show()
# at the end of the script is drawn with it.
matplotlib.use('TkAgg')
# Apply the seaborn theme (unscaled fonts) and override three rcParams:
# render all text through LaTeX and use a serif font family.
# NOTE(review): the family is set to "serif", but the font name "Times" is
# assigned to "font.sans-serif" — presumably "font.serif" was intended.
# Confirm before changing: it would alter the font of the rendered figure.
sns.set(font_scale=1, rc={'text.usetex': True, "font.family": "serif", "font.sans-serif": "Times"})
# Load the pre-computed statistics; in every file the first CSV column is
# used as the DataFrame index.
df_diff_obs = pd.read_csv("diff_obs_stats.csv", index_col=0)
df_obs = pd.read_csv("number_obs_stats.csv", index_col=0)
df_return_abs_pos = pd.read_csv("abs_pos_obs_stats.csv", index_col=0)
df_abs_pos = pd.read_csv("test_abs_pos_obs.csv", index_col=0)

# "number_shifts" = how many of the four mean-shift columns are non-zero,
# computed as (number of columns) - (number of entries equal to 0) per row.
shift_columns = ["mean_xm1_shift", "mean_ym1_shift", "mean_xm2_shift", "mean_ym2_shift"]
zero_shift_count = df_abs_pos[shift_columns].isin([0]).sum(axis=1)
df_abs_pos["number_shifts"] = len(shift_columns) - zero_shift_count
# Base colours: seaborn's "colorblind" palette, echoed to stdout.
palette = sns.color_palette("colorblind")
print(palette)

# Keys 2000 and 4000, 5000, ..., 10000 mapped to the first eight colours.
palette_act = {2000: palette[0]}
palette_act.update({(i + 3) * 10**3: palette[i] for i in range(1, 8)})

# Keys 1..6 mapped to the first six colours; key 30 gets the last colour.
palette_obs = {i + 1: palette[i] for i in range(6)}
palette_obs[30] = palette[-1]

# Colours for the boolean "abs_pos_obs" hue.
palette_pre = {True: palette[0], False: palette[1]}

# Reduced palette: the first three colours followed by the three colours
# preceding the last one.
palette = sns.color_palette("colorblind")[0:3] + sns.color_palette("colorblind")[-4:-1]
palette_diff_obs = {
    "P_ave_0_P_max_0_x_max_0": palette[0],
    "P_ave_0_P_max_1_x_max_1": palette[1],
    "P_ave_1_P_max_0_x_max_0": palette[5],
    "P_ave_1_P_max_1_x_max_1": palette[2],
}


# One figure with a 2x2 grid of panels; the grid is flattened so the panels
# can be addressed as axes[0] .. axes[3] in row-major order.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 9), constrained_layout=True)
axes = axes.ravel()




# Panel (a): "mean_norm" over "rounded_step", one line per "number_obs".
ax_a = axes[0]
sns.lineplot(data=df_obs, x='rounded_step', y='mean_norm', hue="number_obs", legend="full",
             palette=palette_obs, ax=ax_a)
# Shade the band between "y_lower_norm" and "y_upper_norm" for each
# "number_obs" value, using the same colour as the corresponding line.
for n_obs in df_obs['number_obs'].unique():
    rows = df_obs.loc[df_obs['number_obs'] == n_obs]
    ax_a.fill_between(rows['rounded_step'], rows['y_lower_norm'], rows['y_upper_norm'],
                      color=palette_obs[n_obs], alpha=0.2)
ax_a.set_xlabel('training steps (rounded to 500)')
ax_a.set_ylabel('normalized return (smoothed)')
ax_a.legend(ncol=2, title="history length $n$")
# The title text serves as the panel label; it is repositioned to
# x=-0.125, y=0.99 (presumably axes-relative, i.e. left of the plot area).
panel_label = ax_a.set_title(r'\textbf{(a)}', fontsize=13)
panel_label.set_position(np.array([-0.125, 0.99]))


# Panel (b): "mean_norm" over "rounded_step", one line per "obs_kind".
ax_b = axes[1]
sns.lineplot(data=df_diff_obs, x="rounded_step", y="mean_norm", hue="obs_kind", legend="full",
             palette=palette_diff_obs, ax=ax_b)
# Shade the band between "y_lower_norm" and "y_upper_norm" for each
# observation kind, using the same colour as the corresponding line.
for obs_kind in df_diff_obs['obs_kind'].unique():
    subset = df_diff_obs[df_diff_obs['obs_kind'] == obs_kind]
    ax_b.fill_between(subset['rounded_step'], subset['y_lower_norm'], subset['y_upper_norm'], alpha=0.2,
                      color=palette_diff_obs[obs_kind])
ax_b.set_ylabel('normalized return (smoothed)')
ax_b.set_xlabel('training steps (rounded to 500)')

# Readable legend text, keyed by the raw "obs_kind" value. The labels are
# looked up per legend entry instead of being assigned positionally, so a
# CSV whose rows come in a different order cannot attach a label to the
# wrong curve.
obs_kind_labels = {
    "P_ave_0_P_max_0_x_max_0": 'without $P_{ave}, P_{max}, x_{max}$',
    "P_ave_0_P_max_1_x_max_1": 'incl. $P_{max}, x_{max}$',
    "P_ave_1_P_max_0_x_max_0": 'incl. $P_{ave}$',
    "P_ave_1_P_max_1_x_max_1": 'incl. $P_{ave}, P_{max}, x_{max}$',
}
handles1, labels1 = ax_b.get_legend_handles_labels()
# Unknown entries keep their raw label; underscores are escaped because text
# is rendered through LaTeX ('text.usetex' is enabled for this script).
new_labels1 = [obs_kind_labels.get(label, label.replace("_", r"\_")) for label in labels1]
ax_b.legend(handles=handles1, labels=new_labels1, title="observation")

# The title text serves as the panel label; it is repositioned to
# x=-0.125, y=0.99 (presumably axes-relative, i.e. left of the plot area).
title1 = ax_b.set_title(r'\textbf{(b)}', fontsize=13)
title1.set_position(np.array([-0.125, 0.99]))

# Panel (c): "mean_norm" over "rounded_step", split by the boolean
# "abs_pos_obs" column. The x-range is fixed to [0, 100000] before plotting.
ax_c = axes[2]
ax_c.set_xlim(0, 100000)
sns.lineplot(data=df_return_abs_pos, x='rounded_step', y='mean_norm', hue="abs_pos_obs",
             legend="full", palette=palette_pre, ax=ax_c)
# Shade the band between "y_lower_norm" and "y_upper_norm" for each hue
# value, using the same colour as the corresponding line.
for flag in df_return_abs_pos['abs_pos_obs'].unique():
    rows = df_return_abs_pos.loc[df_return_abs_pos['abs_pos_obs'] == flag]
    ax_c.fill_between(rows['rounded_step'], rows['y_lower_norm'], rows['y_upper_norm'],
                      color=palette_pre[flag], alpha=0.2)
ax_c.set_xlabel('training steps')
ax_c.set_ylabel('normalized return (smoothed)')
# Reuse seaborn's legend handles but replace the labels with readable text,
# assigned in the order the handles are returned.
legend_handles, _ = ax_c.get_legend_handles_labels()
ax_c.legend(handles=legend_handles, labels=["not part of observation", "part of observation"],
            title="absolute position")
# The title text serves as the panel label; it is repositioned to
# x=-0.125, y=0.99 (presumably axes-relative, i.e. left of the plot area).
panel_label = ax_c.set_title(r'\textbf{(c)}', fontsize=13)
panel_label.set_position(np.array([-0.125, 0.99]))


# Panel (d): point plot of "percentage_in_goal" against "number_shifts",
# split by the boolean "abs_pos_obs" column.
ax_d = axes[3]
sns.pointplot(data=df_abs_pos, x="number_shifts", y="percentage_in_goal", hue="abs_pos_obs",
              palette=palette_pre, ax=ax_d)
ax_d.set_ylabel("probability of reaching goal")
ax_d.set_xlabel("number of shifts")
# Reuse seaborn's legend handles but replace the labels with readable text,
# assigned in the order the handles are returned.
legend_handles, _ = ax_d.get_legend_handles_labels()
ax_d.legend(handles=legend_handles, labels=["not part of observation", "part of observation"],
            title="absolute position")
# The title text serves as the panel label; it is repositioned to
# x=-0.125, y=0.99 (presumably axes-relative, i.e. left of the plot area).
panel_label = ax_d.set_title(r'\textbf{(d)}', fontsize=13)
panel_label.set_position(np.array([-0.125, 0.99]))

# Write the four-panel figure to a PDF, cropped tightly around its content,
# then open the interactive window. savefig is called on the Figure directly:
# the former `fig.figure` indirection was redundant and only works when
# `Figure.figure` resolves to the figure itself.
fig.savefig("obs_plots.pdf", format="pdf", bbox_inches="tight")
plt.show()