import os
import json
from typing import Dict, List, Tuple
from collections import defaultdict

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap, ttest_ind
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch, ConnectionPatch
from scipy.stats import linregress
from matplotlib.ticker import LogLocator, NullFormatter

from inference_rlhf.code.helpers.io import json_load
from inference_rlhf.code.helpers.utils import rget_json_files_from_dir
from inference_rlhf.code.plot_hardness import get_minimum_hardness_to_avg_samples_to_get_correct
from inference_rlhf.code.plot_all_hardness import plot_method_data

DATA_DIR = "./data"
FIGURES_DIR = "./figures"

# Which figure to build: "teaser" (default) or "model_strength_plot"; the
# latter overrides several of the settings below.
PLOT_TYPE = "teaser"
# Fraction of prompts kept by the "*_quantile" x-axis types: the hardest
# prompts, as ranked by REF_MODEL's vanilla samples-to-correct.
QUANTILE_PERC = 1.0
REF_MODEL = "gpt-4o-mini"
# Metric names for each axis. The accepted values are the branches of
# get_x_task_model / get_y_task_model (labels come from get_x_label / get_y_label).
X_AXIS_TYPE = "samples_to_correct"
Y_AXIS_TYPE = "samples_to_correct"
INCLUDE_ERROR_BARS = False
HIGHLIGHT_SIGNIFICANT = False
X_LOG_SCALE = True
Y_LOG_SCALE = True
# Drop the models in WEAK_MODELS from both axes' data (see filter_models).
REMOVE_WEAK_MODELS = True
TASKS = ["gsm8k", "mbpp", "aime_2025", "game24", "math"]
# Key into the TASK_GEN_TYPE_TO_* sampling tables.
GEN_TYPE = "vanilla"
WEAK_MODELS = ["qwen-25-05b", "mistral-7b", "llama-3-3b"]
XLABEL_FONT_SIZE = 18
YLABEL_FONT_SIZE = 18
OUTER_LEGEND_FONT_SIZE = 16
TICK_LABEL_FONT_SIZE = 15

# Hardness-plot settings read by plot_hardness(). Defined unconditionally so
# that plot_hardness can be called in every PLOT_TYPE without a NameError.
DELTA = 0.1  # hardness-quantile step: passed as `delta` to the binning helper and used as x-tick spacing
HARDNESS_LEGEND_FONT_SIZE = 14
HARDNESS_TITLE_FONT_SIZE = 14

if PLOT_TYPE == "model_strength_plot":
    X_AXIS_TYPE = "rank_p_at_1"
    Y_AXIS_TYPE = "perc_improvement"
    X_LOG_SCALE = False
    Y_LOG_SCALE = False
    REMOVE_WEAK_MODELS = False

    # Phi-4 on game24
    PHI_4_GAME24_ELLIPTICAL_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_lamb_1.0_sparse_dim_512_elliptical_feature_phi-4_mean_hidden_state_center_features.json"
    PHI_4_GAME24_RANDOM_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/vanilla_samples_to_get_correct_temp_1.0_top_p_1.0.json"
    PHI_4_GAME24_REF_PATH = "anonymous/anonymous/inference-rlhf/figures/game24/phi-4/vanilla_samples_to_get_correct_ref_gpt-4o-mini_temp_1.0_top_p_1.0.json"

    # Qwen 14B on math
    QWEN_14B_MATH_ELLIPTICAL_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/elliptical_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0_lamb_1.0_sparse_dim_512_elliptical_feature_qwen-25-14b_mean_hidden_state_center_features.json"
    QWEN_14B_MATH_RANDOM_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/vanilla_samples_to_get_correct_temp_1.0_top_p_1.0_min_p_0.0.json"
    QWEN_14B_MATH_REF_PATH = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-14b/vanilla_samples_to_get_correct_ref_gpt-4o-mini_temp_1.0_top_p_1.0_min_p_0.0.json"


# Sampling hyper-parameters used to locate result files, keyed first by task and
# then by generation type (GEN_TYPE). Only "math" has non-vanilla entries.
TASK_GEN_TYPE_TO_TEMP = dict(
    math=dict(vanilla=1.0, low_temp=0.6, high_temp=1.5, min_p=1.5, nucleus=1.0),
    gsm8k=dict(vanilla=1.0),
    mbpp=dict(vanilla=1.0),
    aime_2025=dict(vanilla=1.0),
    game24=dict(vanilla=1.0),
)

TASK_GEN_TYPE_TO_TOP_P = dict(
    math=dict(vanilla=1.0, low_temp=1.0, high_temp=1.0, min_p=1.0, nucleus=0.90),
    gsm8k=dict(vanilla=1.0),
    mbpp=dict(vanilla=0.95),
    aime_2025=dict(vanilla=1.0),
    game24=dict(vanilla=1.0),
)

TASK_GEN_TYPE_TO_MIN_P = dict(
    math=dict(vanilla=0.0, low_temp=0.0, high_temp=0.0, min_p=0.05, nucleus=0.0),
    gsm8k=dict(vanilla=0.0),
    mbpp=dict(vanilla=0.0),
    aime_2025=dict(vanilla=0.0),
    game24=dict(vanilla=0.0),
)

# One row per task: (task key, scatter marker, display name). Both lookup
# tables are derived from it so marker and name stay in sync.
_TASK_DISPLAY = [
    ("math", "o", "MATH"),
    ("gsm8k", "s", "GSM8K"),
    ("mbpp", "^", "MBPP+"),
    ("aime_2025", "D", "AIME 2025"),
    ("game24", "P", "Game of 24"),
]

TASK_TO_SHAPE = {task: marker for task, marker, _ in _TASK_DISPLAY}

TASK_TO_NICE_NAME = {task: nice_name for task, _, nice_name in _TASK_DISPLAY}

# Four-step viridis ramp shared by the Qwen sizes: index 3 is the lightest
# (smallest model), index 0 the darkest (14B).
_QWEN_VIRIDIS = sns.color_palette("viridis", 4)

MODEL_TO_COLOR = {
    # Llama: coral reds, lighter for the smaller model
    "llama-3-3b": "#FF8A8A",
    "llama-3-8b": "#E55A5A",
    # Phi: light then dark orange
    "phi-3-medium": "#FFB366",
    "phi-4": "#FF8C00",
    # Qwen: viridis ramp, darker as the model grows; 32B is plain black
    "qwen-25-05b": _QWEN_VIRIDIS[3],
    "qwen-25-3b": _QWEN_VIRIDIS[2],
    "qwen-25-7b": _QWEN_VIRIDIS[1],
    "qwen-25-14b": _QWEN_VIRIDIS[0],
    "qwen-25-32b": "black",
    # Mistral: lavender
    "mistral-7b": "#BB8FCE",
}

# Scatter marker size per model; legend handles use its square root as the
# Line2D markersize. Every model currently uses the same value.
MODEL_TO_MARKER_SIZE = dict.fromkeys(
    [
        "llama-3-3b", "llama-3-8b",
        "phi-4", "phi-3-medium",
        "qwen-25-05b", "qwen-25-3b", "qwen-25-7b", "qwen-25-14b", "qwen-25-32b",
        "mistral-7b",
    ],
    200,
)

# Human-readable legend text per model.
MODEL_TO_LEGEND_LABEL = dict([
    ("llama-3-3b", "Llama 3B"),
    ("llama-3-8b", "Llama 8B"),
    ("phi-4", "Phi 4"),
    ("phi-3-medium", "Phi 3 medium"),
    ("qwen-25-05b", "Qwen 0.5B"),
    ("qwen-25-3b", "Qwen 3B"),
    ("qwen-25-7b", "Qwen 7B"),
    ("qwen-25-14b", "Qwen 14B"),
    ("qwen-25-32b", "Qwen 32B"),
    ("mistral-7b", "Mistral 7B"),
])

# (family, size in billions of parameters, model key), grouped by family in
# legend order and ascending size within each family.
MODEL_INFO_SORTED = [
    (family, size, name)
    for family, members in (
        ('llama', ((3, 'llama-3-3b'), (8, 'llama-3-8b'))),
        ('phi', ((14, 'phi-3-medium'), (14, 'phi-4'))),
        ('qwen', ((0.5, 'qwen-25-05b'), (3, 'qwen-25-3b'), (7, 'qwen-25-7b'),
                  (14, 'qwen-25-14b'), (32, 'qwen-25-32b'))),
        ('mistral', ((7, 'mistral-7b'),)),
    )
    for size, name in members
]

import matplotlib.transforms as mtransforms
def draw_square_pixels(ax, x, y, size_px=20, **kwargs):
    """
    Outline a square of ``size_px`` screen pixels centred on the data point (x, y).

    The corners are computed in display (pixel) space with the axes' current
    data transform and converted back to data coordinates once, so the square
    is only pixel-exact for the view limits/figure size at call time.
    Only ``color`` (default "black") and ``linewidth`` (default 1) are read
    from ``kwargs``.
    """
    from matplotlib.patches import Polygon

    half = size_px / 2.0
    cx, cy = ax.transData.transform((x, y))
    # Bottom-left, bottom-right, top-right, top-left offsets in pixels.
    offsets = ((-half, -half), (half, -half), (half, half), (-half, half))
    pixel_corners = [(cx + dx, cy + dy) for dx, dy in offsets]
    data_corners = ax.transData.inverted().transform(pixel_corners)

    outline = Polygon(
        data_corners,
        closed=True,
        fill=False,
        edgecolor=kwargs.get("color", "black"),
        linewidth=kwargs.get("linewidth", 1),
        zorder=5,
        transform=ax.transData,
    )
    ax.add_patch(outline)

def get_pixel_square_bounds(ax, x, y, size_px=20):
    """
    Data-coordinate extent of a ``size_px``-pixel square centred on (x, y).

    Returns:
        (x_min, x_max, y_min, y_max), computed with the axes' current data
        transform (so it is only valid for the current view).
    """
    half = size_px / 2.0
    centre = ax.transData.transform((x, y))
    lower_left = (centre[0] - half, centre[1] - half)
    upper_right = (centre[0] + half, centre[1] + half)

    to_data = ax.transData.inverted()
    low, high = to_data.transform([lower_left, upper_right])
    return low[0], high[0], low[1], high[1]

def draw_log_square(ax, x, y, size=0.1, **kwargs):
    """
    Outline a square around the marker at (x, y) on log-log axes.

    The square is centred on (x, y) in log10 space and spans ``size`` log10
    units (decades) on each axis: from ``v * 10**(-size/2)`` to
    ``v * 10**(size/2)`` for v in (x, y).

    Args:
        ax: Matplotlib axes to draw on (intended for log-scaled x and y).
        x, y: Centre of the square in data coordinates (must be > 0).
        size: Side length in log10 units.
        **kwargs: ``color`` (default "black") and ``linewidth`` (default 1)
            are honoured, matching draw_square_pixels.
    """
    # Multiplicative factors for the lower and upper edges in data space.
    lower = 10 ** (-size / 2)
    upper = 10 ** (size / 2)
    rect = plt.Rectangle(
        (x * lower, y * lower),
        x * (upper - lower),
        y * (upper - lower),
        fill=False,
        edgecolor=kwargs.get("color", "black"),
        linewidth=kwargs.get("linewidth", 1),
        zorder=5,
        # A blended transform built from transData on both axes is just transData.
        transform=ax.transData,
    )
    ax.add_patch(rect)

def load_json_files(load_path: str, task: str) -> List[str]:
    """
    Return the paths of the per-prompt result JSON files for ``task`` under ``load_path``.

    Despite the name, the files are not opened here; the caller reads them.
    A path is kept only if it:
      - does not contain "coreset" (LLM direct coreset files) or
        "eval_results" (MBPP evaluation dumps),
      - contains the temperature (``--temp-<t>-`` or ``temp_<t>-``) and the
        top-p (``--top-p-<p>``) configured for ``task`` / ``GEN_TYPE`` in
        TASK_GEN_TYPE_TO_TEMP / TASK_GEN_TYPE_TO_TOP_P,
      - does not contain "--min-p",
      - for MBPP only, contains "--CHECKED".

    Raises:
        AssertionError: if, for math / gsm8k / aime_2025 / game24, the number of
            kept files differs from the expected number of prompts.
    """
    temp = TASK_GEN_TYPE_TO_TEMP[task][GEN_TYPE]
    top_p = TASK_GEN_TYPE_TO_TOP_P[task][GEN_TYPE]
    temp_markers = (f'--temp-{temp}-', f'temp_{temp}-')
    top_p_marker = f'--top-p-{top_p}'

    def keep(path: str) -> bool:
        if 'coreset' in path or 'eval_results' in path:
            return False
        if not any(marker in path for marker in temp_markers):
            return False
        if top_p_marker not in path or '--min-p' in path:
            return False
        return task != 'mbpp' or '--CHECKED' in path

    json_files = [path for path in rget_json_files_from_dir(load_path) if keep(path)]

    # Expected number of prompts per task. MBPP is not checked (its expected
    # count, 378, is currently not enforced).
    expected_counts = {'math': 5000, 'gsm8k': 1319, 'aime_2025': 30, 'game24': 1362}
    expected = expected_counts.get(task)
    if expected is not None:
        assert len(json_files) == expected, (
            f"Expected {expected} json files for task {task} under {load_path}, "
            f"found {len(json_files)}"
        )

    return json_files

def build_baseline_file_name(task: str, ref: bool) -> str:
    """
    File name of the vanilla (random sampling) samples-to-correct JSON for ``task``.

    Args:
        task: Task key into the TASK_GEN_TYPE_TO_* tables (looked up with GEN_TYPE).
        ref: If True, name the file holding REF_MODEL's results instead.

    Returns:
        ``vanilla_samples_to_get_correct[_ref_<REF_MODEL>]_temp_<t>_top_p_<p>[_min_p_<m>].json``;
        the min-p part is only present for math and mbpp.
    """
    pieces = ["vanilla_samples_to_get_correct"]
    if ref:
        pieces.append(f"_ref_{REF_MODEL}")
    pieces.append(f"_temp_{TASK_GEN_TYPE_TO_TEMP[task][GEN_TYPE]}")
    pieces.append(f"_top_p_{TASK_GEN_TYPE_TO_TOP_P[task][GEN_TYPE]}")
    if task in ('math', 'mbpp'):  # TODO: add other tasks
        pieces.append(f"_min_p_{TASK_GEN_TYPE_TO_MIN_P[task][GEN_TYPE]}")
    pieces.append(".json")
    return "".join(pieces)

def build_elliptical_file_name(task: str, model) -> str:
    """
    Return the name of the elliptical (RepExp) samples-to-correct JSON for ``task``/``model``.

    Searches ``FIGURES_DIR/<task>/<model>`` for files whose names start with
    "elliptical_samples_to_get_correct" and contain the sampling configuration
    of GEN_TYPE from the TASK_GEN_TYPE_TO_* tables: ``temp_<t>_top_p_<p>``,
    plus ``_min_p_<m>`` for math and mbpp (the same tasks for which
    build_baseline_file_name adds a min-p suffix). Matches are sorted so the
    choice does not depend on directory listing order. If several files match,
    the first is used and a warning is printed.

    Returns:
        The matching file name (without directory).

    Raises:
        FileNotFoundError: if no file matches.
    """
    model_dir = os.path.join(FIGURES_DIR, task, model)
    prefix = "elliptical_samples_to_get_correct"

    config = (
        f"temp_{TASK_GEN_TYPE_TO_TEMP[task][GEN_TYPE]}"
        f"_top_p_{TASK_GEN_TYPE_TO_TOP_P[task][GEN_TYPE]}"
    )
    if task in ['math', 'mbpp']:  # TODO: add other tasks
        config += f"_min_p_{TASK_GEN_TYPE_TO_MIN_P[task][GEN_TYPE]}"

    matches = sorted(
        f for f in os.listdir(model_dir) if f.startswith(prefix) and config in f
    )
    if not matches:
        raise FileNotFoundError(f"No {prefix}* file containing '{config}' in {model_dir}")
    if len(matches) > 1:
        print(
            f"Warning: {len(matches)} elliptical files contain '{config}' in "
            f"{model_dir}; using {matches[0]}"
        )
    return matches[0]

def get_y_label(y_axis_type: str) -> str:
    """Axis label for a y-axis metric name, or None if the name has no label."""
    labels = {
        "perc_improvement": "% improvement over random",
        "samples_to_correct": "RepExp samples-to-correct",
        "samples_to_correct_diff": "Samples saved by RepExp",
        "auc_perc_improvement": "% AUC improvement",
        "log_auc_perc_improvement": "% log AUC improvement",
    }
    return labels.get(y_axis_type)

def get_x_label(x_axis_type: str) -> str:
    """Axis label for an x-axis metric name, or None (e.g. for the quantile variants)."""
    labels = {
        "p_at_1": "Pass@1",
        "rank_p_at_1": "Relative rank (based on pass@1)",
        "samples_to_correct": "Random samples-to-correct",
    }
    return labels.get(x_axis_type)

def filter_models(models: List[str], remove_weak_models: bool = False) -> List[str]:
    """
    Drop models that should not appear in the plots, preserving input order.

    A model is removed if its name contains "coder" or "gpt", or, when
    ``remove_weak_models`` is truthy, if it is listed in WEAK_MODELS.

    Args:
        models: Model names (directory names).
        remove_weak_models: Also drop the models in WEAK_MODELS.

    Returns:
        The models that pass the filter.
    """
    def is_excluded(name: str) -> bool:
        if 'coder' in name or 'gpt' in name:
            return True
        # WEAK_MODELS is only consulted when weak-model removal is requested.
        return remove_weak_models and name in WEAK_MODELS

    return [name for name in models if not is_excluded(name)]

def _hardest_quantile_prompt_idxs(task: str, model: str) -> List[str]:
    """Prompt ids in the hardest QUANTILE_PERC fraction, ranked by REF_MODEL's vanilla samples-to-correct."""
    ref_path = os.path.join(FIGURES_DIR, task, model, build_baseline_file_name(task, ref=True))
    ref_samples_to_correct = json_load(ref_path)
    # Ascending samples-to-correct, so the hardest prompts come last.
    sorted_prompt_idxs = sorted(ref_samples_to_correct.keys(), key=ref_samples_to_correct.get)
    num_to_take = int(len(sorted_prompt_idxs) * QUANTILE_PERC)
    if num_to_take == 0:
        # sorted_prompt_idxs[-0:] would silently return *every* prompt.
        raise ValueError(
            f"QUANTILE_PERC={QUANTILE_PERC} selects no prompts for task {task}, model {model}"
        )
    return sorted_prompt_idxs[-num_to_take:]


def _mean_p_at_1_per_prompt(task: str, model: str) -> Dict[int, float]:
    """
    Fraction of correct samples per prompt for ``task``/``model``.

    Keys are the ``prompt_idx`` stored in each result file (assumed int). Each
    sample's "strict_correct" is used when present, otherwise "correct". Files
    with a sample lacking both are reported and skipped.
    """
    json_files = load_json_files(os.path.join(DATA_DIR, task, model), task)
    p_at_1s = {}
    for file in tqdm(json_files, desc="Reading correctness for all json files ..."):
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        prompt_idx = data[0]['prompt_idx']
        try:
            results = [d['strict_correct'] if 'strict_correct' in d else d['correct'] for d in data]
        except KeyError:
            print(f"Skipping {file}: a sample has neither 'strict_correct' nor 'correct'")
            continue
        p_at_1s[prompt_idx] = sum(results) / len(results)
    return p_at_1s


def get_x_task_model(x_axis_type: str, remove_weak_models: bool = False) -> Dict[str, Dict[str, float]]:
    """
    Get x-axis data for each task and model, cached as JSON in DATA_DIR.

    Args:
        x_axis_type: Type of x-axis data to collect.
            - "p_at_1": Pass@1 averaged over all prompts.
            - "p_at_1_quantile": Pass@1 averaged over the hardest QUANTILE_PERC of prompts.
            - "samples_to_correct": Vanilla samples-to-correct averaged over all prompts.
            - "samples_to_correct_quantile": Same, over the hardest QUANTILE_PERC of prompts.
            - "rank_p_at_1": Rank (0 = lowest pass@1) of the model among all models in
              ``DATA_DIR/p_at_1_task_model.json`` for that task.
            "Hardest" is determined by REF_MODEL's vanilla samples-to-correct.
        remove_weak_models: Skip the models listed in WEAK_MODELS.

    Returns:
        x_task_model: ``x_task_model[task][model]`` is the x value. Entries already
        present in the cache file are reused, and the merged result is written back.

    Raises:
        ValueError: if ``x_axis_type`` is not one of the values above.
    """
    supported = ("p_at_1", "p_at_1_quantile", "samples_to_correct", "samples_to_correct_quantile", "rank_p_at_1")
    if x_axis_type not in supported:
        raise ValueError(f"Unknown x_axis_type {x_axis_type!r}; expected one of {supported}")

    # Build cache file name (quantile variants also encode QUANTILE_PERC)
    file_name = f"{x_axis_type}_task_model.json"
    if "quantile" in x_axis_type:
        file_name = file_name.replace('.json', f'_{QUANTILE_PERC}.json')

    # Start from cached data if it exists
    cache_path = os.path.join(DATA_DIR, file_name)
    if os.path.exists(cache_path):
        print(f"Loading cached data from {cache_path} ...")
        x_task_model = defaultdict(dict, json_load(cache_path))
    else:
        x_task_model = defaultdict(dict)

    for task in TASKS:
        models = os.listdir(os.path.join(DATA_DIR, task))
        models = filter_models(models, remove_weak_models=remove_weak_models)
        p_at_1_rank = None  # "rank_p_at_1" only: computed lazily, once per task
        for model in tqdm(models):

            if task in x_task_model and model in x_task_model[task]:
                print(f"Skipping task {task} for model {model} because it already exists")
                continue

            if x_axis_type in ("samples_to_correct", "samples_to_correct_quantile"):
                baseline_file_name = build_baseline_file_name(task, ref=False)
                vanilla_samples_to_correct = json_load(os.path.join(FIGURES_DIR, task, model, baseline_file_name))
                if x_axis_type == "samples_to_correct":
                    values = list(vanilla_samples_to_correct.values())
                else:
                    values = [vanilla_samples_to_correct[idx] for idx in _hardest_quantile_prompt_idxs(task, model)]
                x_task_model[task][model] = np.mean(values)

            elif x_axis_type in ("p_at_1", "p_at_1_quantile"):
                p_at_1s = _mean_p_at_1_per_prompt(task, model)
                if x_axis_type == "p_at_1":
                    values = list(p_at_1s.values())
                else:
                    # JSON object keys are strings; int() matches them to prompt_idx.
                    values = [p_at_1s[int(idx)] for idx in _hardest_quantile_prompt_idxs(task, model)]
                x_task_model[task][model] = np.mean(values)

            else:  # "rank_p_at_1"
                if p_at_1_rank is None:
                    p_at_1_task_model = json_load(os.path.join(DATA_DIR, "p_at_1_task_model.json"))
                    ascending = sorted(p_at_1_task_model[task].items(), key=lambda item: item[1])
                    p_at_1_rank = {name: rank for rank, (name, _) in enumerate(ascending)}
                x_task_model[task][model] = p_at_1_rank[model]

    with open(cache_path, "w") as f:
        json.dump(x_task_model, f)

    return x_task_model
    
def _per_prompt_means(samples_to_correct: Dict[str, object]) -> List[float]:
    """``np.mean`` of each prompt's value(s), in dict order (scalars map to themselves)."""
    return [np.mean(value) for value in samples_to_correct.values()]


def _absolute_improvement(vanilla_sample, elliptical_sample, axis=-1):
    """Bootstrap statistic: mean vanilla minus mean elliptical samples-to-correct."""
    return np.mean(vanilla_sample, axis=axis) - np.mean(elliptical_sample, axis=axis)


def _fit_auc_perc_improvement(vanilla: np.ndarray, elliptical: np.ndarray, log_x: bool) -> float:
    """
    Percent reduction in area under a log-log power-law fit relative to y = x.

    Fits log(elliptical) = slope * log(vanilla) + intercept. Both the fit and
    y = x are then integrated over [min(vanilla), max(vanilla)]: in linear
    space when ``log_x`` is False, and over log(x) with log(y) when it is True.
    """
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    slope, intercept = np.polyfit(np.log(vanilla), np.log(elliptical), 1)
    xs = np.linspace(np.min(vanilla), np.max(vanilla), 5000)
    if log_x:
        xs = np.log(xs)
        ys = slope * xs + intercept
    else:
        ys = np.exp(slope * np.log(xs) + intercept)
    elliptical_auc = trapezoid(ys, xs)
    control_auc = trapezoid(xs, xs)
    return (control_auc - elliptical_auc) / control_auc * 100


def get_y_task_model(y_axis_type: str, remove_weak_models: bool = False) -> Dict[str, Dict[str, Tuple[float, float, float, float]]]:
    """
    Get y-axis data for each task and model.

    For every model directory under ``FIGURES_DIR/<task>`` that passes
    filter_models, this loads two files: the vanilla samples-to-correct JSON
    (build_baseline_file_name) and the elliptical/RepExp one
    (build_elliptical_file_name). Elliptical entries may hold several values
    per prompt. They are averaged per prompt before any aggregation.

    Args:
        y_axis_type: One of
            - "perc_improvement": % reduction of mean samples-to-correct vs. random.
            - "samples_to_correct_diff": mean random minus mean RepExp
              samples-to-correct, with a bootstrap confidence interval.
            - "samples_to_correct": mean RepExp samples-to-correct.
            - "auc_perc_improvement" / "log_auc_perc_improvement": % reduction in
              area under a log-log power-law fit of RepExp vs. random, relative
              to y = x, integrated in linear / log space.
        remove_weak_models: Skip the models listed in WEAK_MODELS.

    Returns:
        y_task_model[task][model] = (value, ci_low, ci_high, p_value). Slots that
        are not computed are None; the p-value slot is always None.

    Raises:
        ValueError: if ``y_axis_type`` is not one of the values above.
    """
    supported = ("perc_improvement", "samples_to_correct_diff", "samples_to_correct",
                 "auc_perc_improvement", "log_auc_perc_improvement")
    if y_axis_type not in supported:
        raise ValueError(f"Unknown y_axis_type {y_axis_type!r}; expected one of {supported}")

    y_task_model = defaultdict(dict)
    for task in TASKS:
        models = os.listdir(os.path.join(FIGURES_DIR, task))
        models = filter_models(models, remove_weak_models=remove_weak_models)

        for model in tqdm(models):
            model_dir = os.path.join(FIGURES_DIR, task, model)
            if not os.path.isdir(model_dir):
                continue

            vanilla_samples_to_correct = json_load(os.path.join(model_dir, build_baseline_file_name(task, ref=False)))
            elliptical_samples_to_correct = json_load(os.path.join(model_dir, build_elliptical_file_name(task, model)))

            if y_axis_type == "perc_improvement":
                vanilla_mean = np.mean(list(vanilla_samples_to_correct.values()))
                elliptical_mean = np.mean(_per_prompt_means(elliptical_samples_to_correct))
                perc_improvement = (vanilla_mean - elliptical_mean) / vanilla_mean * 100
                y_task_model[task][model] = (perc_improvement, None, None, None)

            elif y_axis_type == "samples_to_correct_diff":
                vanilla_data = list(vanilla_samples_to_correct.values())
                # Per-prompt means keep the bootstrap 1-D, so the CI is a scalar.
                elliptical_data = _per_prompt_means(elliptical_samples_to_correct)
                samples_saved = np.mean(vanilla_data) - np.mean(elliptical_data)
                res = bootstrap((vanilla_data, elliptical_data), _absolute_improvement, n_resamples=10000)
                y_task_model[task][model] = (samples_saved, res.confidence_interval.low, res.confidence_interval.high, None)

            elif y_axis_type == "samples_to_correct":
                assert elliptical_samples_to_correct.keys() == vanilla_samples_to_correct.keys()
                elliptical_data = [np.mean(elliptical_samples_to_correct[prompt_idx]) for prompt_idx in vanilla_samples_to_correct.keys()]
                y_task_model[task][model] = (np.mean(elliptical_data), None, None, None)

            else:  # "auc_perc_improvement" / "log_auc_perc_improvement"
                prompt_idxs = list(elliptical_samples_to_correct.keys())
                elliptical_arr = np.array([np.mean(elliptical_samples_to_correct[k]) for k in prompt_idxs])
                vanilla_arr = np.array([vanilla_samples_to_correct[k] for k in prompt_idxs])
                auc_improvement = _fit_auc_perc_improvement(
                    vanilla_arr, elliptical_arr, log_x=(y_axis_type == "log_auc_perc_improvement")
                )
                y_task_model[task][model] = (auc_improvement, None, None, None)

    return y_task_model

def plot_scatter(task, model, elliptical_path, vanilla_path, ax, set_xlabel=True, xlim=7000):
    """
    Per-prompt scatter of RepExp vs. random samples-to-correct on log-log axes.

    Loads ``FIGURES_DIR/<task>/<model>/<elliptical_path>`` and ``.../<vanilla_path>``.
    It then draws one point per prompt (x = random, y = RepExp), a red power-law
    fit ``y = a * x**b`` from linear regression in log-log space, and the dashed
    y = x line.

    Args:
        task, model: Sub-directories of FIGURES_DIR; also select marker shape/colour.
        elliptical_path, vanilla_path: File names inside that directory.
        ax: Matplotlib axes to draw on.
        set_xlabel: Whether to label the x-axis.
        xlim: Upper limit of both axes and extent of the fit / y = x lines
            (defaults to the previously hard-coded 7000).

    Returns:
        The axes legend.
    """
    model_dir = os.path.join(FIGURES_DIR, task, model)
    elliptical_data = json_load(os.path.join(model_dir, elliptical_path))
    vanilla_data = json_load(os.path.join(model_dir, vanilla_path))

    # Pair prompts by id; elliptical entries may hold several values per prompt,
    # which are averaged.
    prompt_idxs = list(elliptical_data.keys())
    elliptical_samples_to_correct = np.array([np.mean(elliptical_data[k]) for k in prompt_idxs])
    vanilla_samples_to_correct = np.array([vanilla_data[k] for k in prompt_idxs])

    # Fit log(y) = slope * log(x) + intercept, i.e. y = exp(intercept) * x**slope.
    fit = linregress(np.log(vanilla_samples_to_correct), np.log(elliptical_samples_to_correct))
    grid = np.linspace(1, xlim, int(xlim))

    ax.plot(
        grid,
        np.exp(fit.slope * np.log(grid) + fit.intercept),
        color='red',
        label=f'$y = {np.exp(fit.intercept):.2f}x^{{{fit.slope:.2f}}}$',
        linewidth=2.5,
    )
    # Reference line y = x: points below it needed fewer samples with RepExp.
    ax.plot(grid, grid, color='black', linestyle='--')
    ax.scatter(
        vanilla_samples_to_correct,
        elliptical_samples_to_correct,
        s=40,
        marker=TASK_TO_SHAPE[task],
        facecolor=MODEL_TO_COLOR[model],
        edgecolor='black',
        linewidths=0.5,
        alpha=0.7,
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlim(left=1.0, right=xlim)
    ax.set_ylim(bottom=1.0, top=xlim)
    if set_xlabel:
        ax.set_xlabel("Random samples-to-correct", fontsize=XLABEL_FONT_SIZE)
    ax.set_ylabel("RepExp samples-to-correct", fontsize=YLABEL_FONT_SIZE)
    return ax.legend(fontsize=14)

def create_task_handles(model_to_marker_size):
    """
    Build one legend proxy per task, in TASK_TO_SHAPE order.

    Each proxy is a grey marker (black edge) of the task's shape, labelled with
    TASK_TO_NICE_NAME. Its markersize is the square root of the first value in
    ``model_to_marker_size``.
    """
    handles = []
    for task_name, shape in TASK_TO_SHAPE.items():
        first_size = list(model_to_marker_size.values())[0]
        handles.append(
            Line2D(
                [0], [0],
                linestyle='None',
                marker=shape,
                markersize=np.sqrt(first_size),
                color='gray',
                markerfacecolor='gray',
                markeredgecolor='black',
                label=TASK_TO_NICE_NAME[task_name],
            )
        )
    return handles

def create_model_handles(model_info_sorted, model_to_marker_size):
    """
    Build legend proxies for the models, grouped by family.

    ``model_info_sorted`` yields (family, size, model) tuples. A thin
    light-grey line is inserted before each new family except the first.
    Each model gets a circle in its MODEL_TO_COLOR colour, sized by the square
    root of ``model_to_marker_size[model]`` and labelled with
    MODEL_TO_LEGEND_LABEL.
    """
    handles = []
    previous_family = None
    for family, _size, model in model_info_sorted:
        if family != previous_family and previous_family is not None:
            separator = Line2D([0], [0], color='lightgray', linestyle='-', linewidth=0.5, label="")
            handles.append(separator)
        previous_family = family

        colour = MODEL_TO_COLOR[model]
        handles.append(
            Line2D(
                [0], [0],
                marker='o',
                color=colour,
                markerfacecolor=colour,
                markersize=np.sqrt(model_to_marker_size[model]),
                linestyle='None',
                label=MODEL_TO_LEGEND_LABEL[model],
            )
        )
    return handles

def plot_hardness(elliptical_data, random_data, ref_data, title, ax, set_xlabel=True):
    """
    Plot samples-to-correct against hardness quantile for RepExp and random sampling.

    Both methods are binned by ``get_minimum_hardness_to_avg_samples_to_get_correct``
    (``hardness_style="quantile_num_to_correct"``, ``delta=DELTA``), with
    ``ref_data`` passed as the reference for hardness. They are drawn on ``ax``
    with ``plot_method_data`` as "RepExp" and "Random".

    Args:
        elliptical_data, random_data: Per-method data for the binning helper.
        ref_data: Reference data passed to the binning helper.
        title: Axes title (bold).
        ax: Matplotlib axes to draw on.
        set_xlabel: Whether to label the x-axis.

    Returns:
        The axes legend.
    """
    elliptical_bin_to_avg_samples_to_correct, elliptical_bin_to_sem_samples_to_get_correct = get_minimum_hardness_to_avg_samples_to_get_correct(
        elliptical_data, ref_data, hardness_style="quantile_num_to_correct", delta=DELTA
    )
    random_bin_to_avg_samples_to_correct, random_bin_to_sem_samples_to_get_correct = get_minimum_hardness_to_avg_samples_to_get_correct(
        random_data, ref_data, hardness_style="quantile_num_to_correct", delta=DELTA
    )

    plot_method_data(elliptical_bin_to_avg_samples_to_correct, elliptical_bin_to_sem_samples_to_get_correct, "RepExp", DELTA, ax)
    plot_method_data(random_bin_to_avg_samples_to_correct, random_bin_to_sem_samples_to_get_correct, "Random", DELTA, ax)

    # Specify plot details
    ax.set_title(f"{title}", fontweight="bold", fontsize=HARDNESS_TITLE_FONT_SIZE)
    if set_xlabel:
        ax.set_xlabel('Hardness quantile (%)', fontsize=XLABEL_FONT_SIZE)
    ax.set_ylabel('Samples-to-correct', fontsize=YLABEL_FONT_SIZE)
    ax.set_xlim(0, 1)
    ax.set_ylim(bottom=0)

    # One tick every DELTA on [0, 1], labelled in percent. The labels are derived
    # from the ticks so they stay aligned for any DELTA; a fixed list of 11 labels
    # only fits DELTA = 0.1. Ticks past 1 (arange overshoots when DELTA does not
    # divide 1) are dropped.
    ticks = np.arange(0, 1.0 + DELTA, DELTA)
    ticks = ticks[ticks <= 1.0 + 1e-9]
    ax.set_xticks(ticks)
    ax.set_xticklabels([f"{tick * 100:.0f}" for tick in ticks])

    for label in ax.get_xticklabels():
        label.set_fontsize(TICK_LABEL_FONT_SIZE)
    for label in ax.get_yticklabels():
        label.set_fontsize(TICK_LABEL_FONT_SIZE)

    return ax.legend(fontsize=HARDNESS_LEGEND_FONT_SIZE)

def _find_result_file(task, model, include, exclude=None):
    """Return the first ``.json`` filename in ``FIGURES_DIR/<task>/<model>`` containing ``include``.

    Args:
        task: Task sub-directory name.
        model: Model sub-directory name.
        include: Substring the filename must contain.
        exclude: Optional substring the filename must NOT contain.

    Returns:
        The bare filename as yielded by ``os.listdir``, not a joined path. Files
        are scanned in ``os.listdir`` order. If several files match, which one
        wins is filesystem-dependent.

    Raises:
        FileNotFoundError: If no file in the directory matches.
    """
    model_dir = os.path.join(FIGURES_DIR, task, model)
    for fname in os.listdir(model_dir):
        if not fname.endswith(".json") or include not in fname:
            continue
        if exclude is not None and exclude in fname:
            continue
        return fname
    exclusion = f" (excluding {exclude!r})" if exclude is not None else ""
    raise FileNotFoundError(f"No .json file containing {include!r}{exclusion} found in {model_dir}")


def _find_elliptical_and_vanilla_files(task, model):
    """Locate and print the elliptical and vanilla result filenames for ``task``/``model``.

    Vanilla files whose name contains ``"ref"`` are skipped.
    Returns ``(elliptical_filename, vanilla_filename)``.
    """
    elliptical_path = _find_result_file(task, model, "elliptical")
    vanilla_path = _find_result_file(task, model, "vanilla", exclude="ref")
    print("found elliptical path: ", elliptical_path)
    print("found vanilla path: ", vanilla_path)
    return elliptical_path, vanilla_path


def _log_square_bounds(x0, y0, size_log10):
    """Return ``(x_min, x_max, y_min, y_max)`` of a square centred on ``(x0, y0)`` in log10 space.

    The square spans ``size_log10`` decades per axis, so each bound is the
    centre multiplied by ``10 ** (-size_log10 / 2)`` or ``10 ** (size_log10 / 2)``.
    """
    shrink = 10 ** (-size_log10 / 2)
    grow = 10 ** (size_log10 / 2)
    return x0 * shrink, x0 * grow, y0 * shrink, y0 * grow


def _draw_zoom_connectors(ax, inset_ax, x_min, x_max, y_min, y_max):
    """Draw dashed lines from a highlighted square on ``ax`` to the corners of ``inset_ax``.

    The square's upper-left corner (data coordinates) is joined to the inset's
    upper-left corner, and its lower-right corner to the inset's lower-right
    corner (axes-fraction coordinates). The lines are added to ``ax`` with a low
    zorder. They are unclipped and excluded from layout computations.
    """
    corner_pairs = (
        ((x_min, y_max), (0.0, 1.0)),  # upper-left -> upper-left
        ((x_max, y_min), (1.0, 0.0)),  # lower-right -> lower-right
    )
    for data_xy, axes_xy in corner_pairs:
        line = ConnectionPatch(
            xyA=data_xy, xyB=axes_xy,
            coordsA="data", coordsB="axes fraction",
            axesA=ax, axesB=inset_ax,
            arrowstyle="-",
            color="black",
            linewidth=1,
            zorder=0.5,
            linestyle="dashed",
        )
        line.set_clip_on(False)
        line.set_in_layout(False)
        ax.add_artist(line)


def _zoom_log_square(ax, inset_ax, x0, y0, size_log10=0.15):
    """Mark ``(x0, y0)`` on ``ax`` with a log-space square and connect it to ``inset_ax``."""
    draw_log_square(ax, x0, y0, size=size_log10, color="black")
    x_min, x_max, y_min, y_max = _log_square_bounds(x0, y0, size_log10)
    _draw_zoom_connectors(ax, inset_ax, x_min, x_max, y_min, y_max)


def _zoom_pixel_square(ax, inset_ax, x0, y0, size_px=30):
    """Mark ``(x0, y0)`` on ``ax`` with a pixel-sized square and connect it to ``inset_ax``."""
    draw_square_pixels(ax, x0, y0, size_px=size_px, color="black")
    x_min, x_max, y_min, y_max = get_pixel_square_bounds(ax, x0, y0, size_px=size_px)
    _draw_zoom_connectors(ax, inset_ax, x_min, x_max, y_min, y_max)


def _load_hardness_data(elliptical_path, random_path, ref_path):
    """Load elliptical / random / reference JSON results and convert their keys to ``int``.

    Elliptical values are reduced with ``np.mean``. Random and reference values
    are kept as loaded. Returns ``(elliptical, random, ref)`` dicts.
    """
    elliptical_data = json_load(elliptical_path)
    random_data = json_load(random_path)
    ref_data = json_load(ref_path)
    elliptical_data = {int(k): np.mean(v) for k, v in elliptical_data.items()}
    random_data = {int(k): v for k, v in random_data.items()}
    ref_data = {int(k): v for k, v in ref_data.items()}
    return elliptical_data, random_data, ref_data


def main():
    """Build and save the summary figure selected by ``PLOT_TYPE``.

    Layout: a large scatter plot on the left (one point per task/model; x from
    ``X_AXIS_TYPE``, y from ``Y_AXIS_TYPE``) and two stacked panels on the right.

    - For ``"teaser"``, the right panels are ``plot_scatter`` plots.
    - For ``"model_strength_plot"``, they are ``plot_hardness`` curves.

    Each panel is linked to its point on the main plot by a square and two
    dashed connector lines. The figure is saved as a PDF in ``FIGURES_DIR``.
    """
    # Get x and y values. Each y entry is indexed as (value, low, high[, p_value]).
    x_task_model = get_x_task_model(X_AXIS_TYPE, remove_weak_models=REMOVE_WEAK_MODELS)
    y_task_model = get_y_task_model(Y_AXIS_TYPE, remove_weak_models=REMOVE_WEAK_MODELS)

    # Extract data for plotting
    all_xs = defaultdict(dict)
    all_ys = defaultdict(dict)
    all_ys_low = defaultdict(dict)
    all_ys_high = defaultdict(dict)
    all_p_values = defaultdict(dict)
    for task, model_to_y in y_task_model.items():
        for model, y_stats in model_to_y.items():
            all_xs[task][model] = x_task_model[task][model]
            all_ys[task][model] = y_stats[0]
            all_ys_low[task][model] = y_stats[1]
            all_ys_high[task][model] = y_stats[2]
            all_p_values[task][model] = y_stats[3] if HIGHLIGHT_SIGNIFICANT else None

    sns.set_style("whitegrid")
    from matplotlib.gridspec import GridSpec

    # One large plot on the left spanning both rows; two stacked panels on the right.
    fig = plt.figure(figsize=(16, 8))
    gs = GridSpec(2, 2, width_ratios=[2, 1], height_ratios=[1, 1], figure=fig)
    ax = fig.add_subplot(gs[:, 0])
    ax_top_right = fig.add_subplot(gs[0, 1])
    ax_bottom_right = fig.add_subplot(gs[1, 1])

    # Render the right panels (with opaque backgrounds) above the cross-axes
    # connector lines, and give them a bold black border.
    for inset in (ax_top_right, ax_bottom_right):
        inset.set_zorder(20)
        inset.patch.set_alpha(1.0)
        inset.patch.set_zorder(21)
        for spine in inset.spines.values():
            spine.set_linewidth(2)
            spine.set_edgecolor('black')

    # Legends returned by the panel plots. They are re-added later because
    # subsequent ax.legend() calls replace an axes' current legend.
    scatter_legend_top_right = None
    scatter_legend_bottom_right = None
    hardness_legend_top_right = None
    hardness_legend_bottom_right = None

    for task, model_to_x in all_xs.items():
        for model, x_val in model_to_x.items():
            y_val = all_ys[task][model]
            p_value = all_p_values[task][model]
            ax.scatter(
                x_val,
                y_val,
                marker=TASK_TO_SHAPE[task],
                color=MODEL_TO_COLOR[model],
                s=MODEL_TO_MARKER_SIZE[model],
                # Outline significant points (p < 0.05) in red.
                edgecolor="red" if p_value is not None and p_value < 0.05 else "black",
                zorder=3,
            )

            # A relative reduction only makes sense when both axes are
            # samples-to-correct. Other axis types (e.g. a rank on x) would give a
            # meaningless ratio and could divide by zero.
            if (
                model == 'qwen-25-14b'
                and X_AXIS_TYPE == "samples_to_correct"
                and Y_AXIS_TYPE == "samples_to_correct"
                and x_val != 0
            ):
                print(f"relative improvement {task}: {(x_val - y_val) / x_val * 100:.2f}%")

            if PLOT_TYPE == "teaser":
                if task == "game24" and model == "qwen-25-7b":
                    elliptical_path, vanilla_path = _find_elliptical_and_vanilla_files(task, model)
                    scatter_legend_top_right = plot_scatter(task, model, elliptical_path, vanilla_path, ax_top_right, set_xlabel=False)
                    _zoom_log_square(ax, ax_top_right, x_val, y_val, size_log10=0.15)

                if task == "mbpp" and model == "phi-4":
                    elliptical_path, vanilla_path = _find_elliptical_and_vanilla_files(task, model)
                    scatter_legend_bottom_right = plot_scatter(task, model, elliptical_path, vanilla_path, ax_bottom_right)
                    _zoom_log_square(ax, ax_bottom_right, x_val, y_val, size_log10=0.15)

            elif PLOT_TYPE == "model_strength_plot":
                if task == "game24" and model == "phi-4":
                    elliptical_data, random_data, ref_data = _load_hardness_data(
                        PHI_4_GAME24_ELLIPTICAL_PATH, PHI_4_GAME24_RANDOM_PATH, PHI_4_GAME24_REF_PATH
                    )
                    hardness_legend_bottom_right = plot_hardness(
                        elliptical_data, random_data, ref_data, "Phi-4 on Game of 24", ax_bottom_right, set_xlabel=True
                    )
                    _zoom_pixel_square(ax, ax_bottom_right, x_val, y_val, size_px=30)

                if task == "math" and model == "qwen-25-14b":
                    elliptical_data, random_data, ref_data = _load_hardness_data(
                        QWEN_14B_MATH_ELLIPTICAL_PATH, QWEN_14B_MATH_RANDOM_PATH, QWEN_14B_MATH_REF_PATH
                    )
                    hardness_legend_top_right = plot_hardness(
                        elliptical_data, random_data, ref_data, "Qwen-2.5-14B-Instruct on MATH", ax_top_right, set_xlabel=False
                    )
                    _zoom_pixel_square(ax, ax_top_right, x_val, y_val, size_px=30)

            # Optionally add error bars
            if INCLUDE_ERROR_BARS:
                ax.errorbar(
                    x_val,
                    y_val,
                    yerr=[[y_val - all_ys_low[task][model]], [all_ys_high[task][model] - y_val]],
                    color='lightgray',
                    fmt='none',
                    capsize=5,
                    zorder=-1
                )

    # Reference line and background shading
    if Y_AXIS_TYPE == "samples_to_correct" and X_AXIS_TYPE == "samples_to_correct":
        # Draw the y = x diagonal over a fixed range [min_val, max_val]. The region
        # below it is shaded green and labelled "Exploration helps"; the region
        # above it is shaded red.
        max_val = 2000
        min_val = 1.0 if (X_LOG_SCALE and Y_LOG_SCALE) else 0.0
        fill_x = np.linspace(min_val, max_val, 500)
        fill_y = fill_x

        # Green below the line
        ax.fill_between(fill_x, min_val, fill_y, color='green', alpha=0.08, zorder=-2)
        # Red above the line
        ax.fill_between(fill_x, fill_y, max_val, color='red', alpha=0.08, zorder=-2)
        ax.plot([min_val, max_val], [min_val, max_val], color='black', linestyle='--', linewidth=1.0)

        if X_LOG_SCALE and Y_LOG_SCALE:
            ax.set_xlim(1.0, max_val)
            ax.set_ylim(1.0, max_val)
            ax.text(
                0.95, 0.05,
                "Exploration helps",
                transform=ax.transAxes,
                ha="right", va="bottom",
                fontsize=20,
                fontweight="bold",
                alpha=0.7,
                rotation=0,
                color="green",
                zorder=10
            )
        else:
            # "exploration helps" in the green region (bottom right), in data coordinates
            ax.text(
                max_val / 1.2, min_val * 2,
                "exploration helps",
                color="green",
                fontsize=18,
                fontweight="bold",
                alpha=0.7,
                rotation=0,
                ha="right",
                va="bottom",
                zorder=10
            )

    elif PLOT_TYPE == "model_strength_plot":
        ax.axhline(y=0, color='black', linestyle='--', linewidth=1.0)
        ax.set_xlim(-0.5, 8.5)
        ax.set_ylim(-175, 100)

        # Shade green above y = 0 and red below it. fill_x spans past the x-limits
        # so the shading reaches both edges.
        fill_x = np.linspace(-1, 9, 500)
        y_min, y_max = ax.get_ylim()
        ax.fill_between(fill_x, 0, y_max, color='green', alpha=0.08, zorder=-2)
        ax.fill_between(fill_x, y_min, 0, color='red', alpha=0.08, zorder=-2)

        # Labelled arrows along the bottom: right-pointing "Stronger models",
        # left-pointing "Weaker models". Each label is centred above its arrow.
        for tail_x, tip_x, arrow_label in ((4.5, 6.5, "Stronger models"), (3.5, 1.5, "Weaker models")):
            ax.annotate(
                '',
                xy=(tip_x, -125),
                xytext=(tail_x, -125),
                arrowprops=dict(arrowstyle='->,head_length=0.5,head_width=0.5', color='black', lw=3)
            )
            ax.text(
                (tail_x + tip_x) / 2, -117,
                arrow_label,
                fontsize=16,
                fontweight="bold",
                alpha=0.9,
                ha="center"
            )
    else:
        ax.axhline(y=0, color='black', linestyle='--', linewidth=1.0)

    if X_LOG_SCALE:
        ax.set_xscale('log')
    if Y_LOG_SCALE:
        ax.set_yscale('log')

    # Turn on unlabeled minor ticks on all three plots when any axis is log-scaled.
    if X_LOG_SCALE or Y_LOG_SCALE:
        for axis in (ax, ax_top_right, ax_bottom_right):
            axis.minorticks_on()
            axis.tick_params(axis='both', which='minor', length=4, width=1.5, bottom=True, left=True)

    # Task legend, placed outside the right edge of the panels
    task_handles = create_task_handles(MODEL_TO_MARKER_SIZE)
    if PLOT_TYPE == "model_strength_plot":
        legend1 = ax_bottom_right.legend(handles=task_handles, title="Task", loc="lower left", bbox_to_anchor=(1.06, 0), borderaxespad=0.0, fontsize=OUTER_LEGEND_FONT_SIZE, title_fontsize=OUTER_LEGEND_FONT_SIZE + 2)
    else:
        legend1 = ax_top_right.legend(handles=task_handles, title="Task", loc="lower left", bbox_to_anchor=(1.02, -0.63), borderaxespad=0.0, fontsize=OUTER_LEGEND_FONT_SIZE, title_fontsize=OUTER_LEGEND_FONT_SIZE + 2)
        # The model legend below replaces ax_top_right's current legend, so keep this one as an extra artist.
        ax_top_right.add_artist(legend1)

    if REMOVE_WEAK_MODELS:
        model_info_sorted = [model for model in MODEL_INFO_SORTED if model[2] not in WEAK_MODELS]
    else:
        model_info_sorted = MODEL_INFO_SORTED

    # Model legend
    model_handles = create_model_handles(model_info_sorted, MODEL_TO_MARKER_SIZE)
    model_legend_x = 1.06 if PLOT_TYPE == "model_strength_plot" else 1.02
    legend2 = ax_top_right.legend(handles=model_handles, title="Model", loc="upper left", bbox_to_anchor=(model_legend_x, 1.0), borderaxespad=0.0, fontsize=OUTER_LEGEND_FONT_SIZE, title_fontsize=OUTER_LEGEND_FONT_SIZE + 2)
    ax_top_right.add_artist(legend2)

    # Re-add the panel legends so they're not overridden by the legend() calls above
    if scatter_legend_top_right is not None:
        ax_top_right.add_artist(scatter_legend_top_right)
    if scatter_legend_bottom_right is not None:
        ax_bottom_right.add_artist(scatter_legend_bottom_right)
    if hardness_legend_top_right is not None:
        ax_top_right.add_artist(hardness_legend_top_right)
    if hardness_legend_bottom_right is not None:
        ax_bottom_right.add_artist(hardness_legend_bottom_right)

    for inset in (ax_top_right, ax_bottom_right):
        for label in inset.get_xticklabels() + inset.get_yticklabels():
            label.set_fontsize(TICK_LABEL_FONT_SIZE)

    # Axis labels for the main plot
    ax.set_xlabel(get_x_label(X_AXIS_TYPE), fontsize=XLABEL_FONT_SIZE)
    ax.set_ylabel(get_y_label(Y_AXIS_TYPE), fontsize=YLABEL_FONT_SIZE)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(TICK_LABEL_FONT_SIZE)

    # Put tick labels on top of the connector lines for all subplots
    for axis in (ax, ax_top_right, ax_bottom_right):
        for label in axis.get_xticklabels() + axis.get_yticklabels():
            label.set_zorder(10)

    # Save figure
    save_file_name = f"{Y_AXIS_TYPE}_vs_{X_AXIS_TYPE}_q={QUANTILE_PERC}_CI_{INCLUDE_ERROR_BARS}.pdf"
    if REMOVE_WEAK_MODELS:
        save_file_name = save_file_name.replace(".pdf", "_remove_weak_models.pdf")
    plt.tight_layout(pad=0.6, w_pad=0.2, h_pad=0.6)
    if PLOT_TYPE == "model_strength_plot":
        plt.subplots_adjust(wspace=0.18)
    elif PLOT_TYPE == "teaser":
        plt.subplots_adjust(wspace=0.12)

    # Pass the legends as extra artists so bbox_inches='tight' doesn't cut them off
    panel_legends = (
        scatter_legend_top_right,
        scatter_legend_bottom_right,
        hardness_legend_top_right,
        hardness_legend_bottom_right,
    )
    bbox_artists = [legend1, legend2] + [legend for legend in panel_legends if legend is not None]

    plt.savefig(
        os.path.join(FIGURES_DIR, save_file_name),
        bbox_inches='tight',
        bbox_extra_artists=tuple(bbox_artists)
    )

if __name__ == "__main__":
    # Script entry point: build the figure selected by PLOT_TYPE and save it under FIGURES_DIR.
    main()