import os
import logging
from collections import defaultdict
from typing import Dict, List

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
import matplotlib.pyplot as plt
import seaborn as sns

from inference_rlhf.code.helpers.utils import timing, estimate_pass_at_k
from inference_rlhf.code.coreset.elliptical_coreset import EllipticalCoreset
from inference_rlhf.code.helpers.utils import load_pool_data
from inference_rlhf.code.helpers.utils import set_seeds
from inference_rlhf.code.helpers.constructors import dataloader_factory
from inference_rlhf.code.helpers.io import json_dump

log = logging.getLogger(__name__)

# Shared matplotlib styling for the pass@k curves drawn in ``main``.
# NOTE: despite its name, FACE_COLOR is passed to plt.plot as
# ``markeredgecolor`` (the marker outline colour), not as a face colour.
FACE_COLOR = "#F7F7FF"
MARKER = "o"  # matplotlib circle marker
LINEWIDTH = 1.2
MARKERSIZE = 10
MARKEREDGEWIDTH = 1.0

def construct_elliptical_data_name(
    task_name: str,
    policy_name: str,
    temp: float,
    top_p: float,
    min_p: float,
    sparse_dim: int,
    perform_sparse_projection: bool,
    feature_name: str,
    feature_type: str,
    perform_pca: bool,
    pca_dim: int,
    lamb: float,
    center_features: bool,
    debug: bool = False,
) -> str:
    """Return the JSON path for elliptical-coreset pass@k results.

    The file name encodes the sampling settings (temperature, top-p, min-p),
    the elliptical regulariser ``lamb`` and the feature pipeline.
    ``sparse_dim`` and ``pca_dim`` are only included when the matching
    projection is enabled. A ``debug_`` prefix is added when ``debug`` is
    true.

    Side effect: creates ``figures/<task_name>/<policy_name>`` if it does not
    exist yet.
    """
    pieces = [
        "elliptical_pass_at_k",
        f"_temp_{temp}",
        f"_top_p_{top_p}",
        f"_min_p_{min_p}",
        f"_lamb_{lamb}",
    ]
    if perform_sparse_projection:
        pieces.append(f"_sparse_dim_{sparse_dim}")
    if perform_pca:
        pieces.append(f"_pca_dim_{pca_dim}")
    pieces.append(f"_elliptical_feature_{feature_name}_{feature_type}")
    if center_features:
        pieces.append("_center_features")

    stem = "".join(pieces)
    prefix = "debug_" if debug else ""

    target_dir = os.path.join('figures', task_name, policy_name)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, prefix + stem + ".json")

def construct_vanilla_data_name(
    task_name: str,
    policy_name: str,
    temp: float,
    top_p: float,
    min_p: float,
    ref_policy_name: str = None,
    debug: bool = False,
) -> str:
    """Return the JSON path for vanilla (unselected pool) pass@k results.

    The file name encodes the optional reference policy and the sampling
    settings. A ``debug_`` prefix is added when ``debug`` is true.

    Side effect: creates ``figures/<task_name>/<policy_name>`` if it does not
    exist yet.
    """
    ref_part = "" if ref_policy_name is None else f"_ref_{ref_policy_name}"
    stem = "".join((
        "vanilla_pass_at_k",
        ref_part,
        f"_temp_{temp}",
        f"_top_p_{top_p}",
        f"_min_p_{min_p}",
    ))
    if debug:
        stem = "debug_" + stem

    target_dir = os.path.join('figures', task_name, policy_name)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, stem + ".json")

@timing
def compute_pass_at_k(all_results: Dict[int, List[int]], max_k: int) -> Dict[int, List[float]]:
    """Estimate per-problem pass@k for every power of two k up to ``max_k``.

    Args:
        all_results: maps each example id to its list of per-response results.
            ``len(responses)`` is used as the sample count and
            ``sum(responses)`` as the number of correct samples, so entries
            are presumably 0/1 correctness flags (TODO confirm with
            ``load_pool_data``).
        max_k: largest k to evaluate. k runs over 1, 2, 4, ..., the largest
            power of two that is <= ``max_k``.

    Returns:
        A ``defaultdict(list)`` mapping each k to a list of pass@k estimates,
        one per example in ``all_results`` iteration order. It is empty when
        ``max_k < 1`` or ``all_results`` is empty.
    """
    pass_at_k = defaultdict(list)

    max_k = int(max_k)
    if max_k < 1:
        # Previously np.log2(0) / np.log2(<0) made int() raise here.
        return pass_at_k

    # bit_length is exact for ints. int(np.log2(max_k)) can round *up* for
    # large values just below a power of two and yield a k above max_k.
    k_powers = [2**i for i in range(max_k.bit_length())]

    # (n, c) per example does not depend on k, so compute it once.
    counts = [(len(responses), sum(responses)) for responses in all_results.values()]

    for k in k_powers:
        for n, c in counts:
            pass_at_k[k].append(estimate_pass_at_k([n], [c], k)[0])
    return pass_at_k

def construct_plot_save_path_name(
    task_name: str,
    policy_name: str, 
    coresets: str,
    sparse_dim: int,
    perform_sparse_projection: bool,
    feature_name: str,
    feature_type: str,
    perform_pca: bool,
    pca_dim: int,
    argmax: bool,
    lamb: float,
    pool_type: str,
    center_features: bool,
) -> str:
    """Return the PDF path of the pass@k comparison figure for ``task_name``.

    Args:
        coresets: either a single string or a sequence of coreset names (as
            passed from ``cfg.plot.coresets``). A sequence is joined with "_"
            in the file name.
        The elliptical-specific arguments (``lamb``, ``sparse_dim``,
        ``pca_dim``, feature names, ``argmax``, ``center_features``) are only
        encoded when "elliptical" is among ``coresets``.

    Side effect: creates ``figures/<task_name>`` if it does not exist yet,
    matching the other path helpers in this module.
    """
    # Formatting a list directly would put "['elliptical', 'vanilla']"
    # (brackets, quotes, spaces) into the file name.
    if isinstance(coresets, str):
        coresets_tag = coresets
    else:
        coresets_tag = "_".join(str(name) for name in coresets)

    file_name = f"{pool_type}_pass_at_k_{policy_name}_{coresets_tag}"

    if "elliptical" in coresets:
        file_name += f"_lamb_{lamb}"
        if perform_sparse_projection:
            file_name += f"_sparse_dim_{sparse_dim}"
        if perform_pca:
            file_name += f"_pca_dim_{pca_dim}"
        file_name += f"_elliptical_feature_{feature_name}_{feature_type}"
        if argmax:
            file_name += "_argmax"
        else:
            file_name += "_sample"
        if center_features:
            file_name += "_center_features"

    file_name += ".pdf"

    out_dir = os.path.join('figures', task_name)
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, file_name)

# Maps ``cfg.plot.pool_type`` to the ``cfg.plot`` sub-config describing that
# sampling pool.
_POOL_CFG_KEYS = {
    "vanilla": "vanilla_pool",
    "low_temp": "low_temp_pool",
    "high_temp": "high_temp_pool",
    "min_p": "min_p_pool",
    "nucleus": "nucleus_pool",
}


def _select_pool_cfg(cfg: DictConfig) -> DictConfig:
    """Return the ``cfg.plot`` pool sub-config named by ``cfg.plot.pool_type``.

    Raises:
        ValueError: if ``cfg.plot.pool_type`` has no entry in ``_POOL_CFG_KEYS``.
    """
    pool_type = cfg.plot.pool_type
    if pool_type not in _POOL_CFG_KEYS:
        raise ValueError(
            f"Unknown plot.pool_type {pool_type!r}; "
            f"expected one of {sorted(_POOL_CFG_KEYS)}"
        )
    # Only the selected section is read, so other pool sections need not be
    # present in the config.
    return getattr(cfg.plot, _POOL_CFG_KEYS[pool_type])


def _style_axes(cfg: DictConfig, curves: List[Dict]) -> None:
    """Set title, log2 x-axis, labels, legend, y-limits and 2^i x-ticks from all curves."""
    # Keys are assumed to be numeric k values (required for the log x-axis).
    max_k = max(k for curve in curves for k in curve)
    min_value = min(v for curve in curves for v in curve.values())
    max_exp = int(np.log2(max_k))

    plt.title(f"{cfg.task.name}: coresets ({cfg.policy.name})")
    plt.xscale('log', base=2)
    plt.xlabel('k')
    plt.ylabel('Pass@k')
    plt.legend(fontsize=10)
    # Use the lowest value across *all* curves so that no curve is clipped.
    plt.ylim(min_value - 0.05, 1.0)
    plt.xticks(
        [2**i for i in range(max_exp + 1)],
        [f"$2^{{{i}}}$" for i in range(max_exp + 1)],
    )


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Compute, save and plot pass@k curves for the configured coresets.

    For each name in ``cfg.plot.coresets`` ("elliptical" and/or "vanilla"),
    the per-problem pass@k values are dumped to JSON via the
    ``construct_*_data_name`` helpers. Their mean across problems is then
    plotted against k on a log2 x-axis. The figure is saved to the path from
    ``construct_plot_save_path_name``.

    Raises:
        ValueError: if ``cfg.plot.pool_type`` is unknown, or if
            ``cfg.plot.coresets`` contains no recognised coreset name.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    dl = dataloader_factory(cfg.task.name, cfg)

    # Fail fast. Previously an unknown pool type left pool_cfg as None, and
    # the run crashed later on ``pool_cfg.sampling``.
    pool_cfg = _select_pool_cfg(cfg)

    pool_data = load_pool_data(cfg, pool_cfg, dl, remove_all_incorrect=False)
    all_results = pool_data["results"]
    log.info(f"Number of problems in pool: {len(all_results)}")

    # One mean curve (k -> mean pass@k over problems) and one legend label
    # per recognised coreset.
    curves_to_plot = []
    plot_labels = []

    for coreset in cfg.plot.coresets:
        if coreset == "elliptical":
            elliptical_coreset = EllipticalCoreset(cfg)
            elliptical_coreset_at_k = elliptical_coreset.pass_at_k(
                pool_data["answers"],
                all_results,
                pool_data["features"],
            )

            elliptical_data_name = construct_elliptical_data_name(
                task_name=cfg.task.name,
                policy_name=cfg.policy.name,
                temp=pool_cfg.sampling.temperature,
                top_p=pool_cfg.sampling.top_p,
                min_p=pool_cfg.sampling.min_p,
                sparse_dim=cfg.coreset.elliptical.sparse_dim,
                perform_sparse_projection=cfg.coreset.elliptical.perform_sparse_projection,
                feature_name=cfg.coreset.elliptical.feature_name,
                feature_type=cfg.coreset.elliptical.feature_type,
                perform_pca=cfg.coreset.elliptical.perform_pca,
                pca_dim=cfg.coreset.elliptical.pca_dim,
                lamb=cfg.coreset.elliptical.lamb,
                center_features=cfg.coreset.elliptical.center_features,
                debug=cfg.debug,
            )
            json_dump(elliptical_coreset_at_k, elliptical_data_name)

            curves_to_plot.append({k: np.mean(v) for k, v in elliptical_coreset_at_k.items()})
            plot_labels.append("Elliptical")

        elif coreset == "vanilla":
            all_pass_at_k = compute_pass_at_k(all_results, cfg.coreset.max_k)

            vanilla_data_name = construct_vanilla_data_name(
                task_name=cfg.task.name,
                policy_name=cfg.policy.name,
                temp=pool_cfg.sampling.temperature,
                top_p=pool_cfg.sampling.top_p,
                min_p=pool_cfg.sampling.min_p,
                debug=cfg.debug,
            )
            json_dump(all_pass_at_k, vanilla_data_name)

            curves_to_plot.append({k: np.mean(v) for k, v in all_pass_at_k.items()})
            plot_labels.append("Vanilla")

        else:
            log.warning(f"Unknown coreset {coreset!r} in plot.coresets; skipping it.")

    # Previously this case surfaced as a NameError on ``mean_pass_at_k``.
    if not curves_to_plot:
        raise ValueError(
            f"plot.coresets={list(cfg.plot.coresets)} contains no known coreset; "
            "expected 'elliptical' and/or 'vanilla'."
        )

    # Plot all coresets
    sns.set_theme(style="whitegrid")
    for curve, label in zip(curves_to_plot, plot_labels):
        # Plot sorted lists, not dict views, so that lines are drawn
        # left-to-right regardless of the producer's key order.
        ks = sorted(curve)
        plt.plot(
            ks,
            [curve[k] for k in ks],
            marker=MARKER,
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            markeredgewidth=MARKEREDGEWIDTH,
            markeredgecolor=FACE_COLOR,
            label=label,
        )

        log.info(f"{label}@k: {curve}")

    _style_axes(cfg, curves_to_plot)

    plot_save_path = construct_plot_save_path_name(
        cfg.task.name,
        cfg.policy.name,
        cfg.plot.coresets,
        cfg.coreset.elliptical.sparse_dim,
        cfg.coreset.elliptical.perform_sparse_projection,
        cfg.coreset.elliptical.feature_name,
        cfg.coreset.elliptical.feature_type,
        cfg.coreset.elliptical.perform_pca,
        cfg.coreset.elliptical.pca_dim,
        cfg.coreset.elliptical.argmax,
        cfg.coreset.elliptical.lamb,
        cfg.plot.pool_type,
        cfg.coreset.elliptical.center_features,
    )
    print(plot_save_path)
    # Do not rely on the data-name helpers having created the directory.
    os.makedirs(os.path.dirname(plot_save_path) or ".", exist_ok=True)
    plt.savefig(plot_save_path)
    plt.close()

if __name__ == "__main__":
    # main is wrapped by hydra.main, so it is called without arguments; Hydra
    # composes ``cfg`` from configs/master (config_path/config_name above).
    main()