import os
import logging
from typing import Dict, List, Tuple
from collections import Counter
import json

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from inference_rlhf.code.helpers.utils import timing, estimate_pass_at_k
from inference_rlhf.code.coreset.elliptical_coreset import EllipticalCoreset
from inference_rlhf.code.helpers.utils import load_response_data, maybe_filter_data, load_ref_data
from inference_rlhf.code.helpers.utils import set_seeds, load_pool_data
from inference_rlhf.code.helpers.constructors import dataloader_factory
from inference_rlhf.code.tasks.base import BaseDataLoader
from inference_rlhf.code.helpers.io import json_dump

# Module-level logger; Hydra configures its handlers when main() runs under @hydra.main.
log = logging.getLogger(__name__)

# Matplotlib styling constants. LINEWIDTH is applied to the line plots drawn in main();
# FACE_COLOR, MARKER, MARKERSIZE and MARKEREDGEWIDTH are not passed to any active
# plotting call in this file (presumably kept for optional marker styling).
FACE_COLOR = "#F7F7FF"
MARKER = "o"
LINEWIDTH = 1.2
MARKERSIZE = 10
MARKEREDGEWIDTH = 1.0

@timing
def compute_pass_at_1(all_results: Dict[int, List[int]]) -> Dict[int, float]:
    """Return the pass@1 estimate for every example.

    Args:
        all_results: Maps example id -> list of per-response results. ``len`` of the
            list is used as the number of samples and ``sum`` as the number correct.

    Returns:
        Maps example id -> ``estimate_pass_at_k(n, c, k=1)`` for that example.
    """
    return {
        example_key: estimate_pass_at_k([len(outcomes)], [sum(outcomes)], 1)[0]
        for example_key, outcomes in all_results.items()
    }

def construct_vanilla_data_name(
    task_name: str,
    policy_name: str,
    temp: float,
    top_p: float,
    min_p: float,
    ref_policy_name: str = None,
    debug: bool = False,
) -> str:
    """Build the JSON path for vanilla samples-to-correct data.

    The file lives in ``figures/<task_name>/<policy_name>/``. Its name encodes the
    optional reference policy and the sampling parameters, e.g.
    ``vanilla_samples_to_get_correct_ref_<ref>_temp_<t>_top_p_<p>_min_p_<m>.json``.
    A ``debug_`` prefix is added when ``debug`` is true.

    Side effect: creates ``figures/<task_name>/<policy_name>`` if it does not exist.
    """
    segments = ["vanilla_samples_to_get_correct"]
    if ref_policy_name is not None:
        segments.append(f"_ref_{ref_policy_name}")
    segments.extend([f"_temp_{temp}", f"_top_p_{top_p}", f"_min_p_{min_p}"])

    stem = "".join(segments)
    if debug:
        stem = f"debug_{stem}"

    output_dir = os.path.join('figures', task_name, policy_name)
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, stem + ".json")

def construct_elliptical_data_name(
    task_name: str,
    policy_name: str,
    temp: float,
    top_p: float,
    min_p: float,
    sparse_dim: int,
    perform_sparse_projection: bool,
    feature_name: str,
    feature_type: str,
    perform_pca: bool,
    pca_dim: int,
    lamb: float,
    center_features: bool,
    debug: bool = False,
) -> str:
    """Build the JSON path for elliptical samples-to-correct data.

    The file lives in ``figures/<task_name>/<policy_name>/``. Its name encodes, in order:
      - the sampling parameters (temp, top_p, min_p);
      - ``lamb``;
      - ``sparse_dim`` (only if ``perform_sparse_projection``);
      - ``pca_dim`` (only if ``perform_pca``);
      - the feature name/type;
      - a ``_center_features`` flag.

    A ``debug_`` prefix is added when ``debug`` is true.

    Side effect: creates ``figures/<task_name>/<policy_name>`` if it does not exist.
    """
    pieces = [
        "elliptical_samples_to_get_correct",
        f"_temp_{temp}",
        f"_top_p_{top_p}",
        f"_min_p_{min_p}",
        f"_lamb_{lamb}",
    ]
    if perform_sparse_projection:
        pieces.append(f"_sparse_dim_{sparse_dim}")
    if perform_pca:
        pieces.append(f"_pca_dim_{pca_dim}")
    pieces.append(f"_elliptical_feature_{feature_name}_{feature_type}")
    if center_features:
        pieces.append("_center_features")

    stem = "".join(pieces)
    if debug:
        stem = f"debug_{stem}"

    target_dir = os.path.join('figures', task_name, policy_name)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, stem + ".json")

def construct_plot_save_path_name(
    task_name: str,
    policy_name: str,
    coresets: List[str],
    pool_cfg_list: "List[DictConfig]",
    hardness_style: str,
    ref_policy_name: str,
    debug: bool
) -> str:
    """Return the PDF path for a hardness plot, creating its directory if needed.

    The file name encodes:
      - the hardness style, the policy and the reference policy;
      - the coreset names joined by underscores;
      - for every (coreset, pool config) pair:
          - the elliptical hyperparameters (elliptical coresets only);
          - then that pool's sampling parameters.

    Args:
        task_name: Task name; first directory level under ``figures/``.
        policy_name: Policy name; second directory level and part of the file name.
        coresets: Coreset names, one per entry of ``pool_cfg_list``
            (e.g. ``["elliptical", "vanilla"]``). A single name given as a plain string
            is treated as a one-element list.
        pool_cfg_list: Per-pool configs. Each must expose ``sampling.temperature``,
            ``sampling.top_p`` and ``sampling.min_p``. Elliptical coresets also need
            the ``elliptical.*`` fields read below.
        hardness_style: Hardness style label placed at the front of the file name.
        ref_policy_name: Reference policy name placed in the file name.
        debug: If true, the file name is prefixed with ``debug_``.

    Returns:
        ``figures/<task_name>/<policy_name>/<file name>.pdf``.
    """
    if isinstance(coresets, str):
        coresets = [coresets]

    # Join the names explicitly. Interpolating the sequence itself would embed its repr
    # (e.g. "['elliptical', 'vanilla']", with brackets, quotes and spaces) in the file name.
    coreset_tag = "_".join(str(name) for name in coresets)
    file_name = f"hardness_{hardness_style}_{policy_name}_ref_{ref_policy_name}_{coreset_tag}"

    if debug:
        file_name = f"debug_{file_name}"

    for coreset, pool_cfg in zip(coresets, pool_cfg_list):
        if coreset == "elliptical":
            ell = pool_cfg.elliptical
            file_name += f"_lamb_{ell.lamb}"
            if ell.perform_sparse_projection:
                file_name += f"_sparse_dim_{ell.sparse_dim}"
            if ell.perform_pca:
                file_name += f"_pca_dim_{ell.pca_dim}"
            file_name += f"_elliptical_feature_{ell.feature_name}_{ell.feature_type}"
            if ell.center_features:
                file_name += "_center_features"

        # Add sampling params (for every pool, regardless of coreset type)
        file_name += f"_temp_{pool_cfg.sampling.temperature}"
        file_name += f"_top_p_{pool_cfg.sampling.top_p}"
        file_name += f"_min_p_{pool_cfg.sampling.min_p}"

    file_name += ".pdf"

    out_dir = os.path.join('figures', task_name, policy_name)
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, file_name)

def _ref_samples_to_get_correct(results: List[int]) -> float:
    """Return the smoothed samples-to-correct estimate ``(n + 1) / (c + 1)``.

    ``n`` is ``len(results)`` and ``c`` is ``sum(results)``. Entries are presumably
    0/1 correctness flags (TODO confirm against the data loader).
    """
    return (len(results) + 1) / (sum(results) + 1)


def get_minimum_hardness_to_avg_samples_to_get_correct(coreset: Dict[int, float], all_results_ref: Dict[int, List[int]], hardness_style: str, delta: float = 0.05) -> Tuple[Dict[float, float], Dict[float, float]]:
    """Aggregate a coreset's samples-to-correct by the reference policy's hardness.

    Args:
        coreset: Maps prompt id -> samples-to-correct of the method being plotted.
        all_results_ref: Reference-policy data keyed by prompt id. Values are
            per-prompt result lists (``len`` = samples, ``sum`` = correct). For
            ``"quantile_num_to_correct"`` only, values may instead be precomputed
            samples-to-correct numbers.
        hardness_style: One of:
            * ``"num_to_correct"``: for every distinct reference samples-to-correct
              ``s``, aggregate the coreset prompts whose reference value is ``>= s``.
              Keyed by ``s``.
            * ``"one_minus_pass@1"``: for every distinct reference pass@1 ``p``,
              aggregate the coreset prompts with pass@1 ``<= p``. Keyed by ``1 - p``.
            * ``"quantile_num_to_correct"``: sort the coreset prompts from easiest to
              hardest reference samples-to-correct and aggregate consecutive bins
              holding a ``delta`` fraction of the prompts each. Keyed by each bin's
              lower quantile.
        delta: Bin width (fraction of prompts) for ``"quantile_num_to_correct"``.

    Returns:
        ``(mean, sem)``: two dicts with identical keys. They map each hardness key to
        the mean and the standard error of the mean (``scipy.stats.sem``) of the
        selected coreset values. A selection with no prompts yields a NaN mean.

    Raises:
        ValueError: If ``hardness_style`` is not one of the styles above.
        KeyError: For the threshold styles, if a coreset prompt id is missing from
            the reference data.
    """
    avg_by_hardness = dict()
    sem_by_hardness = dict()

    def _record(key, values):
        # Store mean and SEM of one selection under the same hardness key
        avg_by_hardness[key] = np.mean(values)
        sem_by_hardness[key] = stats.sem(values)

    if hardness_style == "num_to_correct":
        ref_hardness = {pid: _ref_samples_to_get_correct(res) for pid, res in all_results_ref.items()}
        for s in np.unique(list(ref_hardness.values())):
            _record(s, [v for k, v in coreset.items() if ref_hardness[k] >= s])

    elif hardness_style == "quantile_num_to_correct":
        first_entry = next(iter(all_results_ref.values()), None)
        if isinstance(first_entry, list):
            ref_hardness = {pid: _ref_samples_to_get_correct(res) for pid, res in all_results_ref.items()}
        else:
            # Values are already samples-to-correct estimates
            ref_hardness = all_results_ref

        # Coreset prompt ids ordered from easiest to hardest under the reference policy
        prompt_idxs = sorted(coreset.keys(), key=ref_hardness.get)
        n_prompts = len(prompt_idxs)

        # Derive every bin edge once from the same quantile grid and close the last bin at
        # n_prompts, so consecutive bins share boundaries exactly. Computing each bin's end
        # as int((q + delta) * n) separately can disagree with the next bin's start through
        # float rounding, which drops or double-counts a prompt (including the hardest one).
        quantiles = np.arange(0, 1.0, delta)
        edges = [int(q * n_prompts) for q in quantiles] + [n_prompts]
        for b, q in enumerate(quantiles):
            _record(q, [coreset[k] for k in prompt_idxs[edges[b]:edges[b + 1]]])

    elif hardness_style == "one_minus_pass@1":
        pass_at_1 = compute_pass_at_1(all_results_ref)
        for p in np.unique(list(pass_at_1.values())):
            _record(1 - p, [v for k, v in coreset.items() if pass_at_1[k] <= p])

    else:
        raise ValueError(
            f"Unknown hardness_style {hardness_style!r}; expected 'num_to_correct', "
            f"'quantile_num_to_correct' or 'one_minus_pass@1'"
        )

    return avg_by_hardness, sem_by_hardness

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Plot samples-to-correct of two response pools against reference-policy hardness.

    Driven entirely by the Hydra config ``cfg``:
      1. Load the pools described by ``cfg.plot.pool_1`` / ``cfg.plot.pool_2`` and
         restrict both to the problem ids they share.
      2. Load the reference policy's results and dump its vanilla samples-to-correct
         to JSON.
      3. For each (coreset, pool) pair from ``cfg.plot.coresets``, compute per-problem
         samples-to-correct and dump it to JSON.
      4. For every hardness style, plot each coreset's mean +/- SEM samples-to-correct
         and save the figure as a PDF under ``figures/``.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    # Build dataloader
    dl = dataloader_factory(cfg.task.name, cfg)

    # Load pool data
    pool_1_data = load_pool_data(cfg, cfg.plot.pool_1, dl, remove_all_incorrect=True)
    pool_2_data = load_pool_data(cfg, cfg.plot.pool_2, dl, remove_all_incorrect=True)
    pool_data_list = [pool_1_data, pool_2_data]
    pool_cfg_list = [cfg.plot.pool_1, cfg.plot.pool_2]

    # Warn if one pool solves more problems than the other
    n_pool_1 = len(pool_1_data["results"])
    n_pool_2 = len(pool_2_data["results"])
    if n_pool_1 > n_pool_2:
        log.warning(f"Pool 1 solves more problems than pool 2: {n_pool_1} > {n_pool_2}")
    elif n_pool_1 < n_pool_2:
        log.warning(f"Pool 1 solves fewer problems than pool 2: {n_pool_1} < {n_pool_2}")

    # Only keep the overlapping problems in both pools. Every non-None entry of a pool
    # must be a mapping keyed by problem id for this filter to apply.
    common_ids = set(pool_1_data["results"]).intersection(pool_2_data["results"])
    for pool_data in pool_data_list:
        for key, value in list(pool_data.items()):
            if value is not None:
                pool_data[key] = {k: v for k, v in value.items() if k in common_ids}

    log.info(f"Number of problems in pool 1: {len(pool_1_data['results'])}")
    log.info(f"Number of problems in pool 2: {len(pool_2_data['results'])}")

    # Load reference data
    ref_policy = cfg.ref_policy.name
    ref_load_path = os.path.join(cfg.io.load_root, 'data', cfg.task.name, ref_policy)
    ref_response_data = load_ref_data(load_path=ref_load_path)
    all_results_ref = {k: v["results_ref"] for k, v in ref_response_data.items()}

    # Compute vanilla reference samples-to-correct with the smoothed estimate (n + 1) / (c + 1)
    vanilla_ref_samples_to_get_correct = {
        prompt_id: (len(results) + 1) / (sum(results) + 1)
        for prompt_id, results in all_results_ref.items()
    }

    # Save all ref results to file
    vanilla_ref_data_name = construct_vanilla_data_name(
        task_name=cfg.task.name,
        policy_name=cfg.policy.name,
        temp=1.0, # ref policy always has temp 1.0
        top_p=1.0, # ref policy always has top_p 1.0
        min_p=0.0, # ref policy always has min_p 0.0
        ref_policy_name=cfg.ref_policy.name,
        debug=cfg.debug,
    )
    json_dump(vanilla_ref_samples_to_get_correct, vanilla_ref_data_name)

    coresets_to_plot = []
    plot_labels = []

    # zip truncates to the shortest input, so at most two coresets (one per pool) are used
    for coreset, pool_data, pool_cfg in zip(cfg.plot.coresets, pool_data_list, pool_cfg_list):
        # Extract data
        all_responses = pool_data["responses"]
        all_answers = pool_data["answers"]
        all_results = pool_data["results"]
        all_features = pool_data["features"]
        all_prompts = pool_data["prompts"]

        if coreset == "elliptical":
            # NOTE(review): the coreset is built from the global `cfg`, while the output
            # file name below uses `pool_cfg.elliptical` -- confirm these always agree.
            elliptical_coreset = EllipticalCoreset(cfg)
            elliptical_samples_to_get_correct, elliptical_samples_to_get_correct_all = elliptical_coreset.samples_to_get_correct(
                all_answers,
                all_results,
                all_features,
                all_responses,
                all_prompts
            )
            coresets_to_plot.append(elliptical_samples_to_get_correct)
            plot_labels.append("Elliptical")

            # save elliptical_samples_to_get_correct to file
            elliptical_data_name = construct_elliptical_data_name(
                task_name=cfg.task.name,
                policy_name=cfg.policy.name,
                temp=pool_cfg.sampling.temperature,
                top_p=pool_cfg.sampling.top_p,
                min_p=pool_cfg.sampling.min_p,
                sparse_dim=pool_cfg.elliptical.sparse_dim,
                perform_sparse_projection=pool_cfg.elliptical.perform_sparse_projection,
                feature_name=pool_cfg.elliptical.feature_name,
                feature_type=pool_cfg.elliptical.feature_type,
                perform_pca=pool_cfg.elliptical.perform_pca,
                pca_dim=pool_cfg.elliptical.pca_dim,
                lamb=pool_cfg.elliptical.lamb,
                center_features=pool_cfg.elliptical.center_features,
                debug=cfg.debug,
            )
            with open(elliptical_data_name, "w") as f:
                json.dump(elliptical_samples_to_get_correct_all, f)

        elif coreset == "vanilla":
            # NOTE: (n + 1) / (c + 1)
            vanilla_samples_to_get_correct = {
                example_id: (len(responses) + 1) / (sum(responses) + 1)
                for example_id, responses in all_results.items()
            }
            coresets_to_plot.append(vanilla_samples_to_get_correct)
            plot_labels.append("Vanilla")

            # save vanilla_samples_to_get_correct to file
            vanilla_data_name = construct_vanilla_data_name(
                task_name=cfg.task.name,
                policy_name=cfg.policy.name,
                temp=pool_cfg.sampling.temperature,
                top_p=pool_cfg.sampling.top_p,
                min_p=pool_cfg.sampling.min_p,
                debug=cfg.debug,
            )
            with open(vanilla_data_name, "w") as f:
                json.dump(vanilla_samples_to_get_correct, f)

        else:
            log.warning(f"Unknown coreset '{coreset}'; it will not be plotted")

    # Log each coreset once (it does not depend on the hardness style)
    for coreset, label in zip(coresets_to_plot, plot_labels):
        log.info(f"{label}: {coreset}")

    # Plot all coresets: one figure per hardness style
    sns.set_theme(style="whitegrid")
    for hardness_style in ["num_to_correct", "one_minus_pass@1", "quantile_num_to_correct"]:
        for coreset, label in zip(coresets_to_plot, plot_labels):

            minimum_hardness_to_avg_samples_to_get_correct, minimum_hardness_to_sem_samples_to_get_correct = get_minimum_hardness_to_avg_samples_to_get_correct(coreset, all_results_ref, hardness_style=hardness_style)

            xs = np.array(list(minimum_hardness_to_avg_samples_to_get_correct.keys()))
            ys = np.array([minimum_hardness_to_avg_samples_to_get_correct[x] for x in xs])
            sems = np.array([minimum_hardness_to_sem_samples_to_get_correct[x] for x in xs])
            plt.plot(
                xs,
                ys,
                # marker=MARKER,
                # markersize=MARKERSIZE,
                linewidth=LINEWIDTH,
                # markeredgewidth=MARKEREDGEWIDTH,
                # markeredgecolor=FACE_COLOR,
                label=label
            )
            # Shaded band: mean +/- one standard error
            plt.fill_between(
                xs,
                ys - sems,
                ys + sems,
                alpha=0.2
            )

        # Specify plot details
        plt.title(f"{cfg.task.name}: {hardness_style} hardness ({cfg.policy.name}, ref: {ref_policy})")
        if hardness_style == "num_to_correct":
            plt.xscale('log', base=10)
        plt.xlabel('Minimum hardness' if 'quantile' not in hardness_style else 'Hardness quantile')
        plt.ylabel('Samples-to-correct')
        plt.legend(fontsize=10)

        if 'quantile' in hardness_style:
            # Tick labels assume the default quantile bin width of 0.05
            plt.xticks(np.arange(0, 1.05, 0.05), ['0.00', '', '0.10', '', '0.20', '', '0.30', '', '0.40', '', '0.50', '', '0.60', '', '0.70', '', '0.80', '', '0.90', '', '1.00'])

        plot_save_path = construct_plot_save_path_name(
            task_name=cfg.task.name,
            policy_name=cfg.policy.name,
            coresets=cfg.plot.coresets,
            pool_cfg_list=pool_cfg_list,
            hardness_style=hardness_style,
            ref_policy_name=ref_policy,
            debug=cfg.debug
        )
        print(plot_save_path)
        plt.savefig(plot_save_path)
        plt.close()

if __name__ == "__main__":
    main()