import os
import logging
from collections import defaultdict
from typing import Dict, List

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from seaborn.rcmod import axes_style

from inference_rlhf.code.plot_hardness import get_minimum_hardness_to_avg_samples_to_get_correct
from inference_rlhf.code.helpers.io import json_dump, json_load

# Module-level logger.
log = logging.getLogger(__name__)

# Bin width on the hardness axis. It is forwarded as `delta` to
# get_minimum_hardness_to_avg_samples_to_get_correct, used as the width of the
# shaded band drawn for each bin, and used as the x-tick spacing.
DELTA = 0.1

# Input JSON files, one triple (elliptical / random / ref) per model-dataset pair.
# The paths are relative, so they resolve against the process working directory.
# NOTE(review): judging by the file names, "random" is the vanilla-sampling run
# and "ref" is a gpt-4o-mini reference run -- confirm against the producers.

# Phi-4 on game24
PHI_4_GAME24_ELLIPTICAL_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_lamb_1.0_sparse_dim_512_elliptical_feature_phi-4_mean_hidden_state_center_features.json"
PHI_4_GAME24_RANDOM_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/vanilla_samples_to_get_correct_temp_1.0_top_p_1.0.json"
PHI_4_GAME24_REF_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/vanilla_samples_to_get_correct_ref_gpt-4o-mini_temp_1.0_top_p_1.0.json"

# Qwen 14B on math
QWEN_14B_MATH_ELLIPTICAL_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0_lamb_1.0_sparse_dim_512_elliptical_feature_qwen-25-14b_mean_hidden_state_center_features.json"
QWEN_14B_MATH_RANDOM_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/vanilla_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0.json"
QWEN_14B_MATH_REF_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/vanilla_samples_to_get_correct_ref_gpt-4o-mini_temp_1.0_top_p_1.0_min_p_0.0.json"

# Qwen 7B on gsm8k
QWEN_7B_GSM8K_ELLIPTICAL_PATH = "anonymous/anonymous/inference-rlhf/figures/gsm8k/qwen-25-7b/elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_lamb_1.0_sparse_dim_512_elliptical_feature_qwen-25-7b_mean_hidden_state_center_features.json"
QWEN_7B_GSM8K_RANDOM_PATH = "anonymous/anonymous/inference-rlhf/figures/gsm8k/qwen-25-7b/vanilla_samples_to_get_correct_temp_1.0_top_p_1.0.json"
QWEN_7B_GSM8K_REF_PATH = "anonymous/anonymous/inference-rlhf/figures/gsm8k/qwen-25-7b/vanilla_samples_to_get_correct_ref_gpt-4o-mini_temp_1.0_top_p_1.0.json"

# Marker / line styling shared by every curve drawn by plot_method_data.
# Despite its name, FACE_COLOR is applied as the marker *edge* colour.
# Alternative value kept for reference: FACE_COLOR = "black"
FACE_COLOR = "#F7F7FF"
MARKER = "o"
LINEWIDTH = 1.7
MARKERSIZE = 8
MARKEREDGEWIDTH = 1.2

# Font sizes (in points) for the legend, subplot titles, axis labels and tick labels.
LEGEND_FONT_SIZE = 14
TITLE_FONT_SIZE = 14
XLABEL_FONT_SIZE = 14
YLABEL_FONT_SIZE = 14
TICK_LABEL_FONT_SIZE = 14


def plot_method_data(avg_data, sem_data, label: str, delta: float, ax):
    """Draw one method's per-bin values with error bars and shaded bin bands.

    Args:
        avg_data: Mapping from a bin's left edge (x value) to the value plotted
            for that bin.
        sem_data: Mapping with the same keys as ``avg_data`` giving the
            error-bar half-height for each bin.
        label: Legend label attached to the curve.
        delta: Bin width. Markers are placed at ``x + delta / 2`` and the
            shaded band of each bin spans ``[x, x + delta]``.
        ax: Drawing target exposing ``errorbar`` and ``fill_between``; either
            a matplotlib ``Axes`` or the ``matplotlib.pyplot`` module works.
    """
    xs = np.array(list(avg_data.keys()))
    ys = np.array([avg_data[x] for x in xs])
    sems = np.array([sem_data[x] for x in xs])

    container = ax.errorbar(
        xs + delta / 2,  # centre each marker inside its bin
        ys,
        yerr=sems,
        fmt=MARKER + "-",
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR,
        label=label,
        capsize=5,
        elinewidth=1,  # thinner error-bar stems
        capthick=1,  # thinner error-bar caps
    )

    # Take the colour from the data line returned by errorbar rather than from
    # ``ax.lines[-1]``: the latter depends on artist ordering and does not exist
    # when ``ax`` is the pyplot module.
    band_color = container.lines[0].get_color()

    # One band per bin so the shaded area spans the full bin width.
    for x, y, sem in zip(xs, ys, sems):
        ax.fill_between(
            [x, x + delta],
            [y - sem, y - sem],
            [y + sem, y + sem],
            alpha=0.1,
            color=band_color,
        )


def _load_int_keyed_json(path: str) -> dict:
    """Load the JSON object at ``path`` and convert its keys to ``int``."""
    return {int(key): value for key, value in json_load(path).items()}


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Plot samples-to-correct against hardness quantile for three model/dataset pairs.

    Loads the elliptical, random and reference JSON results for each pair,
    bins them with ``get_minimum_hardness_to_avg_samples_to_get_correct`` and
    writes a three-panel figure to ``figures/quantile_hardness_all.pdf``.

    Args:
        cfg: Hydra configuration; it is only printed, not otherwise used here.
    """
    print(OmegaConf.to_yaml(cfg))

    # (panel title, elliptical path, random path, reference path)
    experiments = [
        ("Phi-4 on Game of 24",
         PHI_4_GAME24_ELLIPTICAL_PATH, PHI_4_GAME24_RANDOM_PATH, PHI_4_GAME24_REF_PATH),
        ("Qwen-2.5-14B-Instruct on MATH",
         QWEN_14B_MATH_ELLIPTICAL_PATH, QWEN_14B_MATH_RANDOM_PATH, QWEN_14B_MATH_REF_PATH),
        ("Qwen-2.5-7B-Instruct on GSM8K",
         QWEN_7B_GSM8K_ELLIPTICAL_PATH, QWEN_7B_GSM8K_RANDOM_PATH, QWEN_7B_GSM8K_REF_PATH),
    ]

    # Load everything up front so a missing file fails before any plotting starts.
    all_data = []
    for title, elliptical_path, random_path, ref_path in experiments:
        # Elliptical results are reduced to one mean per key; the random and
        # reference results are used as stored.
        elliptical_data = {
            key: np.mean(values)
            for key, values in _load_int_keyed_json(elliptical_path).items()
        }
        random_data = _load_int_keyed_json(random_path)
        ref_data = _load_int_keyed_json(ref_path)
        all_data.append((elliptical_data, random_data, ref_data, title))

    # Derive ticks and their percentage labels from DELTA so both always have
    # the same length; linspace avoids the unstable element count of a
    # float-step arange.
    n_bins = int(round(1.0 / DELTA))
    tick_positions = np.linspace(0.0, 1.0, n_bins + 1)
    tick_labels = [f"{100 * tick:.0f}" for tick in tick_positions]

    sns.set_theme(style="whitegrid")
    fig, axes = plt.subplots(1, len(all_data), figsize=(15, 5))

    for i, (elliptical_data, random_data, ref_data, title) in enumerate(all_data):
        ax = axes[i]

        elliptical_avg, elliptical_sem = get_minimum_hardness_to_avg_samples_to_get_correct(
            elliptical_data, ref_data, hardness_style="quantile_num_to_correct", delta=DELTA
        )
        random_avg, random_sem = get_minimum_hardness_to_avg_samples_to_get_correct(
            random_data, ref_data, hardness_style="quantile_num_to_correct", delta=DELTA
        )

        # Pass the Axes itself: plot_method_data reads ``ax.lines``, which the
        # pyplot module does not provide.
        plot_method_data(elliptical_avg, elliptical_sem, "Elliptical", DELTA, ax)
        plot_method_data(random_avg, random_sem, "Random", DELTA, ax)

        # Specify plot details
        ax.set_title(f"{title}", fontweight="bold", fontsize=TITLE_FONT_SIZE)
        ax.set_xlabel('Hardness quantile (%)', fontsize=XLABEL_FONT_SIZE)
        if i == 0:  # Only the left-most panel carries the shared y label
            ax.set_ylabel('Samples-to-correct', fontsize=YLABEL_FONT_SIZE)
        ax.legend(fontsize=LEGEND_FONT_SIZE)
        ax.set_xlim(0, 1)
        ax.set_ylim(bottom=0)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.tick_params(axis="both", labelsize=TICK_LABEL_FONT_SIZE)

    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(os.path.join("figures", "quantile_hardness_all.pdf"), bbox_inches='tight')


if __name__ == "__main__":
    main()