import os
import logging
from collections import defaultdict
from typing import Dict, List

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D

from inference_rlhf.code.helpers.utils import timing, estimate_pass_at_k
from inference_rlhf.code.coreset.elliptical_coreset import EllipticalCoreset
from inference_rlhf.code.helpers.utils import load_pool_data
from inference_rlhf.code.helpers.utils import set_seeds
from inference_rlhf.code.helpers.constructors import dataloader_factory
from inference_rlhf.code.helpers.io import json_dump, json_load

log = logging.getLogger(__name__)

# Typography and tick spacing shared by every panel.
LEGEND_FONT_SIZE = 6
LABEL_FONT_SIZE = 7.5
TICK_LABEL_FONT_SIZE = 6.4
TICK_LABEL_PAD = 0

# All precomputed pass@k result files live in one directory
# (the path suggests Qwen-2.5-7B evaluated on MATH).
_RESULTS_DIR = "anonymous/anonymous/inference-rlhf/figures/math/qwen-25-7b"
# Extra file-name component carried by every elliptical-selection result file.
_ELLIPTICAL_TAG = "_lamb_1.0_sparse_dim_512_elliptical_feature_qwen-25-7b_mean_hidden_state_center_features"

# Sampling configuration encoded in each result file name, one per data pool.
_VANILLA_SAMPLING = "temp_1.0_top_p_1.0_min_p_0.0"
_LOW_TEMP_SAMPLING = "temp_0.6_top_p_1.0_min_p_0.0"
_HIGH_TEMP_SAMPLING = "temp_1.5_top_p_1.0_min_p_0.0"
_MIN_P_SAMPLING = "temp_1.5_top_p_1.0_min_p_0.05"
_NUCLEUS_SAMPLING = "temp_1.0_top_p_0.9_min_p_0.0"


def _result_path(method: str, sampling: str, tag: str = "") -> str:
    """Return ``<results dir>/<method>_pass_at_k_<sampling><tag>.json``."""
    return f"{_RESULTS_DIR}/{method}_pass_at_k_{sampling}{tag}.json"


# Baseline results (shown as "Random" in the figure legend).
RANDOM_VANILLA_PATH = _result_path("vanilla", _VANILLA_SAMPLING)
RANDOM_LOW_TEMP_PATH = _result_path("vanilla", _LOW_TEMP_SAMPLING)
RANDOM_HIGH_TEMP_PATH = _result_path("vanilla", _HIGH_TEMP_SAMPLING)
RANDOM_MIN_P_PATH = _result_path("vanilla", _MIN_P_SAMPLING)
RANDOM_NUCLEUS_PATH = _result_path("vanilla", _NUCLEUS_SAMPLING)

# Elliptical selection results (shown as "RepExp" in the figure legend).
ELLIPTICAL_VANILLA_PATH = _result_path("elliptical", _VANILLA_SAMPLING, _ELLIPTICAL_TAG)
ELLIPTICAL_LOW_TEMP_PATH = _result_path("elliptical", _LOW_TEMP_SAMPLING, _ELLIPTICAL_TAG)
ELLIPTICAL_HIGH_TEMP_PATH = _result_path("elliptical", _HIGH_TEMP_SAMPLING, _ELLIPTICAL_TAG)
ELLIPTICAL_MIN_P_PATH = _result_path("elliptical", _MIN_P_SAMPLING, _ELLIPTICAL_TAG)
ELLIPTICAL_NUCLEUS_PATH = _result_path("elliptical", _NUCLEUS_SAMPLING, _ELLIPTICAL_TAG)

# Marker and line styling shared by both curves in every panel.
FACE_COLOR = "#F7F7FF"
MARKEREDGEWIDTH = 0.6
MARKERSIZE = 4.5
LINEWIDTH = 1.2

# One color per data pool (sampling configuration).
POOL_TYPE_TO_COLOR = {
    "vanilla": "#E55A5A",
    "low_temp": "#FF8C00",
    "high_temp": "#7D3C98",
    "min_p": "#4CAF7A",
    "nucleus": "#4A90E2",
}

def plot_pool(random_data, elliptical_data, ax, pool_type: str, y_min: float, set_xlabel: bool = True, set_ylabel: bool = True, mark_arrow: bool = False, marker_delta: float = 0.0):
    """Plot mean pass@k against k for one data pool, random baseline vs. elliptical.

    Args:
        random_data: Mapping from k (as a string key, e.g. ``"16"``) to a sequence
            of pass@k values; the mean of each sequence is plotted (faded triangles).
        elliptical_data: Same layout for the elliptical method (solid circles).
            It must contain every k present in ``random_data``.
        ax: Matplotlib axes to draw on.
        pool_type: Key into ``POOL_TYPE_TO_COLOR``. ``"high_temp"`` additionally
            gets a 0.4 y tick.
        y_min: Lower y-axis limit, and the bottom of the optional dashed guides.
        set_xlabel: Whether to label the x axis ``k``.
        set_ylabel: Whether to label the y axis ``Pass@k``.
        mark_arrow: If True, annotate the pass@k level where the two curves are
            furthest apart in log2(k) with a double-headed arrow and the k ratio.
        marker_delta: Vertical offset applied to that arrow and its label. When
            non-zero, dashed vertical guides are drawn at both k values.

    Raises:
        ValueError: If ``random_data`` is empty.
        KeyError: If ``elliptical_data`` lacks a k present in ``random_data``.
    """
    if not random_data:
        raise ValueError("random_data is empty; nothing to plot")

    color = POOL_TYPE_TO_COLOR[pool_type]

    # JSON object keys are strings in whatever order the writer emitted them
    # (e.g. "1", "1024", "128", ... under sort_keys). Sort numerically so the
    # lines run left-to-right and, assuming pass@k is non-decreasing in k, the
    # y series are increasing as np.interp requires in the arrow annotation.
    ks = sorted(int(key) for key in random_data)
    missing = [k for k in ks if str(k) not in elliptical_data]
    if missing:
        raise KeyError(f"elliptical_data is missing k values: {missing}")

    random_ys = [np.mean(random_data[str(k)]) for k in ks]
    elliptical_ys = [np.mean(elliptical_data[str(k)]) for k in ks]

    # Both curves share color and marker styling; only marker shape and alpha differ.
    line_style = dict(
        color=color,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR,
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
    )
    ax.plot(ks, random_ys, marker='^', alpha=0.5, **line_style)
    ax.plot(ks, elliptical_ys, marker='o', **line_style)

    ax.set_xscale('log', base=2)
    # Tick every other power of two (2^0, 2^2, 2^4, ...) up to the largest k.
    exponents = range(0, int(np.log2(ks[-1])) + 1, 2)
    ax.set_xticks([2**i for i in exponents])
    ax.set_xticklabels([f"$2^{{{i}}}$" for i in exponents], fontsize=TICK_LABEL_FONT_SIZE)
    ax.set_ylim(bottom=y_min)
    # Bring tick labels closer to the axes.
    ax.tick_params(axis='both', which='both', pad=TICK_LABEL_PAD)

    y_ticks = [0.4, 0.5, 0.6, 0.7, 0.8] if pool_type == "high_temp" else [0.5, 0.6, 0.7, 0.8]
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f"{tick:.1f}" for tick in y_ticks], fontsize=TICK_LABEL_FONT_SIZE)

    if set_xlabel:
        ax.set_xlabel('k', fontsize=LABEL_FONT_SIZE)
    if set_ylabel:
        ax.set_ylabel('Pass@k', fontsize=LABEL_FONT_SIZE)

    if mark_arrow:
        _annotate_speedup(ax, ks, random_ys, elliptical_ys, color, y_min, marker_delta)


def _annotate_speedup(ax, ks, random_ys, elliptical_ys, color, y_min, marker_delta):
    """Mark the pass@k level where the two curves differ most in log2(k).

    Each curve is inverted (pass@k -> k) by linear interpolation over the y
    range both curves reach. At the level with the largest log2(k) gap, this
    draws a double-headed arrow between the two k values. The arrow is labelled
    with the ratio k_random / k_elliptical.
    """
    xs = np.asarray(ks, dtype=float)
    random_ys = np.asarray(random_ys, dtype=float)
    elliptical_ys = np.asarray(elliptical_ys, dtype=float)

    # Only compare pass@k levels that both curves actually reach.
    y_lo = max(random_ys.min(), elliptical_ys.min())
    y_hi = min(random_ys.max(), elliptical_ys.max())
    if y_lo > y_hi:
        log.warning("pass@k ranges of the two curves do not overlap; skipping arrow annotation")
        return

    y_grid = np.linspace(y_lo, y_hi, 400)
    # np.interp expects increasing xp, i.e. pass@k increasing with k.
    k_random = np.interp(y_grid, random_ys, xs)
    k_elliptical = np.interp(y_grid, elliptical_ys, xs)
    # Measure the gap in log2(k) space to match the log-scaled x axis.
    gap = np.abs(np.log2(k_elliptical) - np.log2(k_random))
    best = int(np.argmax(gap))
    y_star = y_grid[best]
    k_random_star = k_random[best]
    k_elliptical_star = k_elliptical[best]
    y_arrow = y_star + marker_delta

    ax.annotate(
        '',
        xy=(k_random_star, y_arrow),
        xytext=(k_elliptical_star, y_arrow),
        arrowprops=dict(arrowstyle='<->,head_length=0.17,head_width=0.17', color=color, lw=1.2),
    )
    # The geometric mean of the endpoints is the visual midpoint on a log-x axis.
    ax.text(
        np.sqrt(k_random_star * k_elliptical_star),
        y_arrow + 0.008,  # nudge the label just above the arrow
        r"$\mathbf{\times" + f"{(k_random_star / k_elliptical_star):.1f}" + "}$",
        color=color,
        ha='center',
        va='bottom',
        fontsize=6.5,
    )

    # When the arrow is shifted off y_star, dashed guides connect it back to the curves.
    if marker_delta != 0:
        for k_star in (k_random_star, k_elliptical_star):
            ax.vlines(k_star, ymin=y_min, ymax=y_star, color=color, linestyle='dashed', linewidth=1.0)

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Render the 2x3 pass@k comparison figure (random vs. RepExp) for five data pools.

    Loads the precomputed pass@k JSON results for each pool and draws one panel
    per pool. The legends go in the otherwise empty bottom-right panel, and the
    figure is saved as ``figures/math/pass_at_k_alternative_log.pdf``. ``cfg`` is
    only printed; none of its values are used here.
    """
    print(OmegaConf.to_yaml(cfg))

    sns.set_theme(style="whitegrid")

    # pool type -> (random baseline results path, elliptical results path)
    result_paths = {
        "vanilla": (RANDOM_VANILLA_PATH, ELLIPTICAL_VANILLA_PATH),
        "low_temp": (RANDOM_LOW_TEMP_PATH, ELLIPTICAL_LOW_TEMP_PATH),
        "high_temp": (RANDOM_HIGH_TEMP_PATH, ELLIPTICAL_HIGH_TEMP_PATH),
        "min_p": (RANDOM_MIN_P_PATH, ELLIPTICAL_MIN_P_PATH),
        "nucleus": (RANDOM_NUCLEUS_PATH, ELLIPTICAL_NUCLEUS_PATH),
    }
    results = {
        pool: (json_load(random_path), json_load(elliptical_path))
        for pool, (random_path, elliptical_path) in result_paths.items()
    }

    fig, axs = plt.subplots(2, 3, figsize=(6, 2.5))

    # (pool type, subplot position, y_min, extra plot_pool options). The
    # marker_delta values look hand-tuned per panel to place the arrows.
    panels = [
        ("vanilla", (0, 0), 0.45, dict(set_xlabel=False, mark_arrow=True, marker_delta=-0.18)),
        ("low_temp", (0, 1), 0.45, dict(set_xlabel=False, set_ylabel=False, mark_arrow=True, marker_delta=-0.15)),
        ("high_temp", (0, 2), 0.3, dict(set_xlabel=False, set_ylabel=False, mark_arrow=True)),
        ("min_p", (1, 0), 0.43, dict(mark_arrow=True, marker_delta=-0.18)),
        ("nucleus", (1, 1), 0.45, dict(set_ylabel=False, mark_arrow=True, marker_delta=-0.165)),
    ]
    for pool, position, y_min, options in panels:
        random_data, elliptical_data = results[pool]
        plot_pool(random_data, elliptical_data, axs[position], pool, y_min, **options)

    # The last panel holds only the legends.
    legend_ax = axs[1, 2]
    legend_ax.axis('off')

    pool_labels = {
        "vanilla": "Vanilla",
        "low_temp": "Low temp",
        "high_temp": "High temp",
        "min_p": "Min-p",
        "nucleus": "Nucleus",
    }
    pool_handles = [
        Line2D([0], [0], color=POOL_TYPE_TO_COLOR[pool], label=label)
        for pool, label in pool_labels.items()
    ]
    method_handles = [
        Line2D([0], [0], color='grey', marker='^', label='Random', alpha=0.5, markersize=MARKERSIZE),
        Line2D([0], [0], color='grey', marker='o', label='RepExp', markersize=MARKERSIZE),
    ]
    legend_style = dict(borderaxespad=0.0, fontsize=LEGEND_FONT_SIZE, title_fontsize=LEGEND_FONT_SIZE + 0.5)

    # "Data pool" legend at the upper left, "Method" legend at the upper right.
    legend1 = legend_ax.legend(handles=pool_handles, title="Data pool", loc="upper left", bbox_to_anchor=(-0.06, 1), **legend_style)
    # A second ax.legend() call replaces ax.legend_, so keep the first one as a plain artist.
    legend_ax.add_artist(legend1)
    # legend2 becomes ax.legend_ and is drawn automatically; also adding it as an
    # artist (as before) would draw it twice.
    legend2 = legend_ax.legend(handles=method_handles, title="Method", loc="upper right", bbox_to_anchor=(1, 1), **legend_style)

    out_dir = os.path.join("figures", "math")
    os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories
    fig.savefig(
        os.path.join(out_dir, "pass_at_k_alternative_log.pdf"),
        bbox_inches='tight',
        bbox_extra_artists=(legend1, legend2),  # keep the legends inside the tight bbox
    )
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator on main() supplies cfg from
    # the "master" config in ../../configs (plus any command-line overrides).
    main()