import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Example data: one list of 30 values per model, each listed in descending
# order. The dictionary keys are the model names used as hue labels when the
# data is plotted.
results_avalon = {
    'GPT4.0': [
        7.25, 5.75, 4.125, 2.857142857, 2,
        1.75, 1.444444444, 1.428571429, 1.363636364, 1.3,
        0.714285714, 0.5, 0.166666667, 0.076923077, -0.416666667,
        -0.5, -0.666666667, -0.714285714, -0.846153846, -1,
        -1.142857143, -1.6, -2.25, -3.625, -3.875,
        -4.166666667, -4.75, -5.75, -6.333333333, -7.25,
    ],
    'GPT3.5': [
        5.833333333, 4.090909091, 4, 3.875, 2.444444444,
        2.222222222, 1.777777778, 1.5, 1.2, 1.090909091,
        0.222222222, 0.142857143, 0.058823529, -0.090909091, -0.142857143,
        -0.5, -0.571428571, -0.75, -0.8, -1,
        -1.2, -1.222222222, -1.3, -1.5, -1.5,
        -1.875, -2.222222222, -2.25, -6.090909091, -8.833333333,
    ],
}

# Optional per-model summary statistics (disabled). Written to iterate over
# whatever keys results_avalon contains rather than naming series explicitly.
# means_avalon = {name: np.mean(scores) for name, scores in results_avalon.items()}
# stds_avalon = {name: np.std(scores) for name, scores in results_avalon.items()}
# print('means_avalon = ', means_avalon)
# print('stds_avalon = ', stds_avalon)
# print()


# Reshape the per-series value lists into one long-format DataFrame
def prepare_data(results, *, exclude=('Random Rollout',)) -> pd.DataFrame:
    """Flatten a ``{series name: values}`` mapping into a long-format DataFrame.

    Every value of every retained series becomes one row, in the mapping's
    iteration order.

    Args:
        results: Mapping from series name to an iterable of values.
        exclude: Series name(s) to leave out of the result. Accepts a single
            name or an iterable of names. Defaults to ``('Random Rollout',)``;
            pass ``()`` to keep every series.

    Returns:
        A DataFrame with three columns: ``'Data'`` (the values),
        ``'Category'`` (the constant string ``'Data'`` on every row) and
        ``'Type'`` (the name of the series each value came from).
    """
    # A bare string would otherwise be split into single characters by set().
    skipped = {exclude} if isinstance(exclude, str) else set(exclude)

    data = []
    hue_labels = []
    for key, values in results.items():
        if key in skipped:
            continue
        # Materialise once so one-shot iterables (e.g. generators) work too.
        values = list(values)
        data.extend(values)
        hue_labels.extend([key] * len(values))

    return pd.DataFrame({
        'Data': data,
        'Category': ['Data'] * len(data),
        'Type': hue_labels,
    })

# Box fill colour per series; the keys match the 'Type' values in the frame.
palette = {'GPT4.0': 'LightBlue', 'GPT3.5': 'LightGreen'}

# Long-format frame consumed by seaborn.
df_avalon = prepare_data(results_avalon)

# One box per series, drawn side by side inside the single 'Data' category.
fig = plt.figure(figsize=(12, 9))
ax = sns.boxplot(
    data=df_avalon,
    x='Category',
    y='Data',
    hue='Type',
    palette=palette,
    width=0.1,
    linecolor='black',
    linewidth=6,
)
# Disabled reference line: results_avalon has no 'Random Rollout' entry.
# ax.axhline(y=results_avalon['Random Rollout'][0], color='red', linestyle='--', label='Random Rollout', linewidth=6)
ax.set_xlabel('Category', fontsize=28)
ax.set_ylabel('Score', fontsize=28)
ax.set_xticks([])  # no ticks on the x-axis
plt.setp(ax.get_yticklabels(), fontsize=28)
ax.set_xlim(-0.15, 0.15)  # narrow view around the single category
ax.legend(loc='upper left', fontsize=22)
fig.savefig('gpt_4_35.png')
