import os
import json
from typing import Dict, List, Tuple
from collections import defaultdict

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap, ttest_ind
from matplotlib.lines import Line2D

from inference_rlhf.code.helpers.io import json_load
from inference_rlhf.code.helpers.utils import rget_json_files_from_dir
from inference_rlhf.code.plot_teaser import get_x_task_model, get_y_task_model, MODEL_TO_COLOR, TASK_TO_NICE_NAME, TASK_TO_SHAPE, get_y_label, get_x_label, create_task_handles, MODEL_INFO_SORTED, create_model_handles

# --- Filesystem locations ---------------------------------------------------
DATA_DIR = "./data"
FIGURES_DIR = "./figures"  # main() writes the output PDF under this directory

# --- Plot behaviour toggles -------------------------------------------------
# NOTE(review): INCLUDE_ERROR_BARS, X_LOG_SCALE, Y_LOG_SCALE, TASKS, GEN_TYPE
# and DATA_DIR are not read by main() in this file; presumably kept for parity
# with sibling plotting scripts -- confirm before removing.
INCLUDE_ERROR_BARS = False
HIGHLIGHT_SIGNIFICANT = False  # when True, main() reads a p-value per point and outlines p < 0.05 in red
X_LOG_SCALE = False
Y_LOG_SCALE = False
TASKS = ["gsm8k", "mbpp", "aime_2025", "game24", "math"]
GEN_TYPE = "vanilla"

# --- Font sizes --------------------------------------------------------------
OUTER_LEGEND_FONT_SIZE = 10
XLABEL_FONT_SIZE = 12
YLABEL_FONT_SIZE = 12
TICK_LABEL_FONT_SIZE = 12

# --- Marker sizes ------------------------------------------------------------
# Every model currently shares one scatter marker size; the mapping is kept
# per-model because main() indexes it by model name and also hands it to the
# legend-handle builders.
_UNIFORM_MARKER_SIZE = 80
MODEL_TO_MARKER_SIZE = {
    model_name: _UNIFORM_MARKER_SIZE
    for model_name in (
        # Llama
        "llama-3-3b",    # 3B model
        "llama-3-8b",    # 8B model
        # Phi
        "phi-4",         # 14B model
        "phi-3-medium",  # 14B model
        # Qwen
        "qwen-25-05b",   # 0.5B model
        "qwen-25-3b",    # 3B model
        "qwen-25-7b",    # 7B model
        "qwen-25-14b",   # 14B model
        "qwen-25-32b",   # 32B model
        # Mistral
        "mistral-7b",    # 7B model
    )
}

def _add_direction_arrow(ax, tail_x, tip_x, label, arrow_y=-125, text_y=-117):
    """Draw a horizontal arrow from ``tail_x`` to ``tip_x`` with a bold caption centered above it."""
    ax.annotate(
        '',
        xy=(tip_x, arrow_y),        # arrow tip
        xytext=(tail_x, arrow_y),   # arrow tail
        arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
    )
    ax.text(
        (tail_x + tip_x) / 2, text_y,
        label,
        fontsize=10,
        fontweight="bold",
        alpha=0.9,
        ha="center",
    )


def main():
    """Plot percent improvement against p@1 rank and save the figure as a PDF.

    One scatter point is drawn per (task, model) pair: the marker shape comes
    from ``TASK_TO_SHAPE``, the fill color from ``MODEL_TO_COLOR`` and the size
    from ``MODEL_TO_MARKER_SIZE``. Only the first element of each y entry is
    plotted; when ``HIGHLIGHT_SIGNIFICANT`` is set, the fourth element is read
    as a p-value and points with p < 0.05 get a red outline.

    The figure is written to ``FIGURES_DIR/perc_improvements_vs_rank_p_at_1.pdf``;
    the directory is created if it does not exist.

    Raises:
        KeyError: if a (task, model) pair present in the y data is missing from
            the x data or from one of the style lookup tables.
    """
    # Called only for its side effects (the return value is discarded);
    # per the original author's note this makes sure p@1s are computed.
    get_x_task_model("p_at_1", remove_weak_models=False)

    x_task_model = get_x_task_model("rank_p_at_1", remove_weak_models=False)
    y_task_model = get_y_task_model("perc_improvement", remove_weak_models=False)

    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(8, 5))

    for task, y_by_model in y_task_model.items():
        for model, y_entry in y_by_model.items():
            p_value = y_entry[3] if HIGHLIGHT_SIGNIFICANT else None
            is_significant = p_value is not None and p_value < 0.05
            ax.scatter(
                x_task_model[task][model],
                y_entry[0],
                marker=TASK_TO_SHAPE[task],
                color=MODEL_TO_COLOR[model],
                s=MODEL_TO_MARKER_SIZE[model],
                edgecolor="red" if is_significant else "black",
            )

    ax.axhline(y=0, color='black', linestyle='--', linewidth=1.0)
    ax.set_xlim(-0.5, 8.5)
    ax.set_ylim(-175, 100)

    # Shade the background: green above the zero line, red below it.
    # The fill extends half a unit past the x-limits so it reaches both edges.
    x_lo, x_hi = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    fill_x = np.array([x_lo - 0.5, x_hi + 0.5])
    ax.fill_between(fill_x, 0, y_max, color='green', alpha=0.08, zorder=-2)
    ax.fill_between(fill_x, y_min, 0, color='red', alpha=0.08, zorder=-2)

    # Right-pointing arrow towards higher x, left-pointing arrow towards lower x.
    _add_direction_arrow(ax, tail_x=4.5, tip_x=6.5, label="Stronger models")
    _add_direction_arrow(ax, tail_x=3.5, tip_x=1.5, label="Weaker models")

    # Task legend. It must be re-added as a plain artist because the second
    # ax.legend() call below replaces the Axes' active legend.
    task_handles = create_task_handles(MODEL_TO_MARKER_SIZE)
    legend_1 = ax.legend(
        handles=task_handles,
        title="Task",
        loc="lower left",
        bbox_to_anchor=(1.02, -0.1),
        borderaxespad=0.0,
        fontsize=OUTER_LEGEND_FONT_SIZE,
        title_fontsize=OUTER_LEGEND_FONT_SIZE + 2,
    )
    ax.add_artist(legend_1)

    # Model legend. This one stays the Axes' active legend, so it is NOT passed
    # to add_artist(): doing so would register it twice and draw it twice.
    model_handles = create_model_handles(MODEL_INFO_SORTED, MODEL_TO_MARKER_SIZE)
    legend_2 = ax.legend(
        handles=model_handles,
        title="Model",
        loc="upper left",
        bbox_to_anchor=(1.02, 1.0),
        borderaxespad=0.0,
        fontsize=OUTER_LEGEND_FONT_SIZE,
        title_fontsize=OUTER_LEGEND_FONT_SIZE + 2,
    )

    ax.set_xlabel(get_x_label("rank_p_at_1"), fontsize=XLABEL_FONT_SIZE)
    ax.set_ylabel(get_y_label("perc_improvement"), fontsize=YLABEL_FONT_SIZE)

    # tick_params also applies to ticks created later (e.g. at draw time),
    # unlike setting the size on the currently existing label objects.
    ax.tick_params(axis="both", labelsize=TICK_LABEL_FONT_SIZE)

    # savefig does not create missing directories.
    os.makedirs(FIGURES_DIR, exist_ok=True)
    save_file_name = "perc_improvements_vs_rank_p_at_1.pdf"
    fig.savefig(
        os.path.join(FIGURES_DIR, save_file_name),
        bbox_inches='tight',
        # Both legends sit outside the axes; include them in the tight bbox.
        bbox_extra_artists=(legend_1, legend_2),
    )

# Generate the figure only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()