
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import seaborn as sns

from inference_rlhf.code.helpers.io import json_load

# Font sizes passed to matplotlib for the figure text.
LABEL_FONT_SIZE = 15       # axis labels
TITLE_FONT_SIZE = 16       # figure title
TICK_LABEL_FONT_SIZE = 14  # tick labels on both axes

# Alternative input files for MBPP, kept for reference (currently disabled).
# MEAN_PATH = "anonymous/anonymous/inference-rlhf/figures/mbpp/phi-4/elliptical_samples_to_get_correct_temp_1.0_top_p_0.95_min_p_0.0_lamb_1.0_sparse_dim_512_elliptical_feature_phi-4_mean_hidden_state_center_features.json"
# LAST_HIDDEN_PATH = "anonymous/anonymous/inference-rlhf/figures/mbpp/phi-4/ablation/elliptical_samples_to_get_correct_temp_1.0_top_p_0.95_min_p_0.0_lamb_1.0_sparse_dim_512_elliptical_feature_phi-4_last_hidden_state_center_features.json"
# SECOND_TO_LAST_HIDDEN_PATH = "anonymous/anonymous/inference-rlhf/figures/mbpp/phi-4/ablation/elliptical_samples_to_get_correct_temp_1.0_top_p_0.95_min_p_0.0_lamb_1.0_sparse_dim_512_elliptical_feature_phi-4_second_to_last_hidden_state_center_features.json"
# VANILLA_PATH = "anonymous/anonymous/inference-rlhf/figures/mbpp/phi-4/vanilla_samples_to_get_correct_temp_1.0_top_p_0.95_min_p_0.0.json"

# Active input files: Game of 24 results for phi-4.
# NOTE: unlike the other three paths, MEAN_PATH points into the "safe/"
# sub-directory and its file name has no "min_p_0.0" segment.
MEAN_PATH = (
    "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/safe/"
    "elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_lamb_1.0"
    "_sparse_dim_512_elliptical_feature_phi-4_mean_hidden_state_center_features.json"
)
LAST_HIDDEN_PATH = (
    "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/"
    "elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0_lamb_1.0"
    "_sparse_dim_512_elliptical_feature_phi-4_last_hidden_state_center_features.json"
)
SECOND_TO_LAST_HIDDEN_PATH = (
    "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/"
    "elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0_lamb_1.0"
    "_sparse_dim_512_elliptical_feature_phi-4_second_to_last_hidden_state_center_features.json"
)
VANILLA_PATH = (
    "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/"
    "vanilla_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0.json"
)

def _mean_and_sem(samples_by_key):
    """Collapse one results mapping into an overall mean and its standard error.

    Each value of ``samples_by_key`` is averaged on its own first; the
    returned statistics are then taken across those per-key averages, so
    every key carries equal weight regardless of how many entries its value
    holds.

    Args:
        samples_by_key: Mapping whose values are numeric sequences
            (presumably samples-to-correct counts per problem -- confirm
            against the producer of the JSON files). Keys are not used.

    Returns:
        A ``(mean, sem)`` tuple: ``np.mean`` and ``scipy.stats.sem`` of the
        per-key averages.
    """
    per_key_means = [np.mean(samples) for samples in samples_by_key.values()]
    return np.mean(per_key_means), stats.sem(per_key_means)


def main():
    """Plot the representation-type ablation and save it to ``ablation.pdf``.

    Loads the result files at ``MEAN_PATH``, ``LAST_HIDDEN_PATH`` and
    ``SECOND_TO_LAST_HIDDEN_PATH``, reduces each to a mean with a
    standard-error bar, and draws one bar per representation type. The
    figure is written to ``ablation.pdf`` in the current working directory.

    The vanilla baseline (``VANILLA_PATH``) is not part of this figure, so it
    is not loaded here.
    """
    # Bar label -> result file; insertion order fixes the left-to-right order
    # of the bars.
    sources = {
        'Mean': MEAN_PATH,
        'Last': LAST_HIDDEN_PATH,
        'Second-to-last': SECOND_TO_LAST_HIDDEN_PATH,
    }
    summaries = [_mean_and_sem(json_load(path)) for path in sources.values()]

    bar_labels = list(sources)
    bar_means = [mean for mean, _ in summaries]
    bar_sems = [sem for _, sem in summaries]

    # The theme must be set before the figure is created so that the axes
    # pick up the whitegrid style.
    sns.set_theme(style="whitegrid")
    bar_colors = sns.color_palette("colorblind", n_colors=len(bar_labels))

    # Use a dedicated figure rather than whatever pyplot considers "current",
    # so repeated or interactive calls never draw onto an unrelated figure.
    fig, ax = plt.subplots()
    ax.set_title('Phi-4 on Game of 24', fontsize=TITLE_FONT_SIZE)
    ax.bar(
        bar_labels,
        bar_means,
        yerr=bar_sems,
        capsize=5,
        color=bar_colors,
    )
    # tick_params applies to ticks created later as well (e.g. if the layout
    # pass below changes how many y ticks fit), unlike resizing the tick
    # label objects that happen to exist right now.
    ax.tick_params(axis="both", labelsize=TICK_LABEL_FONT_SIZE)
    ax.set_xlabel('Representation type', fontsize=LABEL_FONT_SIZE)
    ax.set_ylabel('Average samples-to-correct', fontsize=LABEL_FONT_SIZE)
    fig.tight_layout()
    fig.savefig('ablation.pdf')
    # Release the figure so it does not linger in pyplot's global registry.
    plt.close(fig)

# Generate the figure only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()