# summarizing plot
import pandas as pd

# Load the summary of test errors (mean/std columns) per loss function and
# training-set size.
df_updated_full = pd.read_csv("output/cix_n_samples_test_errors_mean_std.csv")

# List the loss functions present in the data, in order of first appearance.
# A bare expression is only echoed inside a notebook; print it explicitly so
# the listing is also visible when this file is run as a script.
unique_loss_functions = df_updated_full['loss_function'].unique()
print(unique_loss_functions)
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Comparison figure: deterministic-driving-force error (top) and diffusion
# error (bottom, with a log-log inset) versus training-set size, one curve per
# loss function, WGAN excluded. Saved as PNG, SVG and PDF, then shown.

# Global seaborn style for every axes created below.
sns.set_style("whitegrid")

# Remove entries corresponding to the WGAN loss function.
df_no_wgan = df_updated_full[df_updated_full['loss_function'] != 'WGAN']
# Computed once and reused so every panel uses the same order and colours.
loss_functions = df_no_wgan['loss_function'].unique()

# One marker and one colour per loss function (reused cyclically if needed).
markers = ['o', 's', '^', 'D', '*', 'p', 'v', '<', '>']
colors = sns.color_palette("tab10", len(loss_functions))


def _plot_metric_by_loss(ax, df, loss_functions, metric, *, markers, colors,
                         markersize=6, with_labels=True):
    """Draw a mean line plus a +/-1 std band for each loss function on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        df: DataFrame with columns ``loss_function``, ``train_n_samples``,
            ``f"{metric}_mean"`` and ``f"{metric}_std"``.
        loss_functions: Loss-function names to plot, in drawing order.
        metric: Column prefix, e.g. ``"mse_f"`` or ``"mse_σ"``.
        markers: Marker styles, reused cyclically when there are more loss
            functions than markers.
        colors: Colours, reused cyclically in the same way.
        markersize: Marker size for the mean line.
        with_labels: If True, label each line with its loss-function name so
            it appears in a legend; the inset passes False.
    """
    mean_col, std_col = f"{metric}_mean", f"{metric}_std"
    for idx, loss_func in enumerate(loss_functions):
        # Sort by x so the line and the shaded band are drawn left to right,
        # even if the CSV rows are not ordered by sample count.
        subset = (df[df['loss_function'] == loss_func]
                  .sort_values('train_n_samples'))
        x = subset['train_n_samples']
        mean, std = subset[mean_col], subset[std_col]
        color = colors[idx % len(colors)]
        ax.plot(x, mean, label=loss_func if with_labels else None,
                marker=markers[idx % len(markers)], color=color,
                markersize=markersize, linestyle='-')
        ax.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)


# A single 2x1 figure holding both panels.
fig, (ax_f, ax_sigma) = plt.subplots(2, 1, figsize=(12, 10))

# Top panel: mse_f without WGAN.
_plot_metric_by_loss(ax_f, df_no_wgan, loss_functions, 'mse_f',
                     markers=markers, colors=colors)
ax_f.set_title('Comparison of Deterministic Driving Force Errors Without WGAN')
ax_f.set_xlabel('Number of Training Samples')
ax_f.set_ylabel('Mean Error in Deterministic Driving Force (mse_f)')
ax_f.legend()
ax_f.set_xscale('log')

# Bottom panel: mse_σ without WGAN.
_plot_metric_by_loss(ax_sigma, df_no_wgan, loss_functions, 'mse_σ',
                     markers=markers, colors=colors)
ax_sigma.set_xscale('log')
ax_sigma.set_title('Comparison of Diffusion Errors Without WGAN')
ax_sigma.set_xlabel('Number of Training Samples')
ax_sigma.set_ylabel('Mean Error in Diffusion (mse_σ)')
ax_sigma.legend()

# Lay out the subplots before adding the inset. Axes created by
# mpl_toolkits' inset_axes have no subplotspec, so tight_layout would warn
# about them; the inset's locator follows its parent axes at draw time anyway.
fig.tight_layout()

# Centered inset on the bottom panel with log-log scales.
# NOTE(review): where mean - std <= 0, the lower edge of the band cannot be
# represented on the log y-axis.
axins = inset_axes(ax_sigma, width="30%", height="30%", loc='center')
_plot_metric_by_loss(axins, df_no_wgan, loss_functions, 'mse_σ',
                     markers=markers, colors=colors, markersize=4,
                     with_labels=False)
axins.set_yscale('log')
axins.set_xscale('log')

for ext in ("png", "svg", "pdf"):
    fig.savefig(f"output/plot_comparison_no_wgan_inset_centered.{ext}",
                format=ext)
plt.show()
