import os
import logging
from collections import defaultdict
from typing import Dict, List       
import random

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm

from inference_rlhf.code.helpers.utils import timing, estimate_pass_at_k
from inference_rlhf.code.coreset.unique_random_coreset import UniqueRandomCoreset
from inference_rlhf.code.coreset.unique_frequency_coreset import UniqueFrequencyCoreset
from inference_rlhf.code.coreset.reward_model_coreset import RewardModelCoreset
from inference_rlhf.code.coreset.reward_model_unique_coreset import RewardModelUniqueCoreset
from inference_rlhf.code.coreset.elliptical_coreset import EllipticalCoreset
from inference_rlhf.code.coreset.log_probs import LogProbsCoreset
from inference_rlhf.code.coreset.llm_binary_quality_diversity_coreset import LLMBinaryQualityDiversityCoreset
from inference_rlhf.code.helpers.utils import load_response_data, maybe_filter_data
from inference_rlhf.code.helpers.utils import set_seeds
from inference_rlhf.code.coreset.gradient_norms import GradientNormsCoreset
from inference_rlhf.code.coreset.vanilla import VanillaCoreset
from inference_rlhf.code.tasks.math import extract_answer

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Shared matplotlib styling for the marker lines drawn in this file.
# NOTE: despite its name, FACE_COLOR is passed as `markeredgecolor`
# (the marker outline) at every call site in this file.
FACE_COLOR = "#F7F7FF"
MARKER = "o"
LINEWIDTH = 1.2
MARKERSIZE = 10
MARKEREDGEWIDTH = 1.0

@timing
def compute_pass_at_k(all_results: Dict[int, List[int]], max_k: int) -> Dict[int, List[float]]:
    """Collect per-example pass@k estimates for every power of two up to ``max_k``.

    For each ``k`` in ``1, 2, 4, ..., 2**int(log2(max_k))`` and for each entry
    of ``all_results``, ``estimate_pass_at_k`` is called with the number of
    recorded outcomes and their sum (both wrapped in one-element lists).

    Args:
        all_results: Mapping from example id to its list of per-response outcomes.
        max_k: Upper bound on ``k``; only powers of two not exceeding it are used.

    Returns:
        A ``defaultdict(list)`` mapping each ``k`` to the list of values
        returned by ``estimate_pass_at_k``, in the iteration order of
        ``all_results``.
    """
    highest_exponent = int(np.log2(max_k))
    per_k_estimates = defaultdict(list)

    for exponent in range(highest_exponent + 1):
        k = 2**exponent
        for outcomes in all_results.values():
            num_samples = len(outcomes)
            num_correct = sum(outcomes)
            per_k_estimates[k].append(estimate_pass_at_k([num_samples], [num_correct], k))

    return per_k_estimates

def avg_pass_at_k(all_pass_at_k: Dict[int, List[float]]) -> Dict[int, float]:
    """Average the per-example pass@k values for every ``k``.

    Args:
        all_pass_at_k: Mapping from ``k`` to the list of per-example values.

    Returns:
        A dict with the same keys (in the same order) whose values are
        ``np.mean`` of the corresponding lists.
    """
    averaged = {}
    for k, per_example_values in all_pass_at_k.items():
        averaged[k] = np.mean(per_example_values)
    return averaged

def plot_avg_reward_elliptical_coreset_at_k(avg_reward_at_k: Dict[int, float], policy: str, reward_percent_to_filter: int, sparse_dim: int) -> None:
    """Plot average reward against ``k`` for the elliptical coreset and save it as a PDF.

    The figure is written to
    ``figures/avg_reward_at_k_{policy}_{reward_percent_to_filter}_{sparse_dim}.pdf``;
    the ``figures`` directory is created if it does not exist yet.

    Args:
        avg_reward_at_k: Mapping from ``k`` to the (scalar) average reward at that ``k``.
        policy: Policy name, used only in the output file name.
        reward_percent_to_filter: Used only in the output file name.
        sparse_dim: Used only in the output file name.

    Raises:
        ValueError: If ``avg_reward_at_k`` is empty.
    """
    # Fail before touching matplotlib state; min()/max() below would raise the
    # same exception type anyway, but only after a figure had been drawn.
    if not avg_reward_at_k:
        raise ValueError("avg_reward_at_k must contain at least one entry")

    ks = list(avg_reward_at_k.keys())
    rewards = list(avg_reward_at_k.values())

    sns.set_theme(style="whitegrid")
    plt.plot(
        ks,
        rewards,
        marker=MARKER,
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR, label="Elliptical"
    )

    log.info(f"Avg reward@k: {avg_reward_at_k}")

    # Specify plot details
    plt.title("MATH: coresets")
    plt.xscale('log', base=2)
    plt.xlabel('k')
    plt.ylabel('Avg Reward@k')
    plt.legend(fontsize=10)
    plt.ylim(min(rewards) - 0.05, max(rewards) + 0.05)

    # One tick per power of two up to the largest k, labelled as 2^i.
    num_ticks = int(np.log2(max(ks))) + 1
    x_ticks = [2**i for i in range(num_ticks)]
    x_tick_labels = [f"$2^{{{i}}}$" for i in range(num_ticks)]
    plt.xticks(x_ticks, x_tick_labels)

    # savefig does not create missing parent directories.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(f"figures/avg_reward_at_k_{policy}_{reward_percent_to_filter}_{sparse_dim}.pdf")
    plt.close()

def construct_plot_save_path_name(
    policy_name: str, 
    coresets: str,
    reward_percent_to_filter: int,
    log_probs_percent_to_filter: int,
    sparse_dim: int,
    perform_sparse_projection: bool,
    judge_name: str,
    log_probs_percent_to_filter_vanilla: int,
    feature_name: str,
    feature_type: str,
    perform_pca: bool,
    pca_dim: int,
    scale_features_with_log_probs: bool,
    gradient_percent_to_filter_vanilla: int,
    use_gradients: bool,
    argmax: bool,
    lamb: float,
    alpha: float,
    use_weird_sampling: bool,
    temp: float,
    center_features: bool,
    use_weird_sampling2: bool,
    gradient_norms_mode: str,
    reparse_answers: bool
) -> str:
    """Build the output path of the unique-answers plot from the run settings.

    The file name starts with ``unique_answers_at_k_{policy_name}_{coresets}``
    and gains one suffix per setting that is relevant to the coresets named in
    ``coresets`` (tested with ``in``, so both substring and element membership
    work). The result always ends in ``.pdf`` and lives under ``figures``.

    Returns:
        ``os.path.join('figures', <file name>)``.
    """
    pieces = [f"unique_answers_at_k_{policy_name}_{coresets}"]

    # Settings that only matter when the elliptical coreset is plotted.
    if "elliptical" in coresets:
        pieces.append(f"_lamb_{lamb}")
        if perform_sparse_projection:
            pieces.append(f"_sparse_dim_{sparse_dim}")
        if perform_pca:
            pieces.append(f"_pca_dim_{pca_dim}")
        pieces.extend([
            f"_elliptical_r%_{reward_percent_to_filter}",
            f"_elliptical_lp%_{log_probs_percent_to_filter}",
            f"_elliptical_feature_{feature_name}_{feature_type}",
        ])
        if scale_features_with_log_probs:
            pieces.extend(["_scale_features_with_log_probs", f"_temp_{temp}"])
        if use_gradients:
            pieces.append("_use_gradients")
        pieces.append("_argmax" if argmax else "_sample")
        if center_features:
            pieces.append("_center_features")
        if alpha > 0.0:
            pieces.append(f"_alpha_{alpha}")
        if use_weird_sampling:
            pieces.append("_use_weird_sampling")
        if use_weird_sampling2:
            pieces.append("_use_weird_sampling2")

    # Settings for the remaining coreset families.
    if "llm_binary_quality_diversity" in coresets:
        pieces.append(f"_judge_{judge_name}")
    if "filtered_vanilla" in coresets:
        pieces.extend([
            f"_vanilla_lp%_{log_probs_percent_to_filter_vanilla}",
            f"_vanilla_g%_{gradient_percent_to_filter_vanilla}",
        ])
    if "gradient_norms" in coresets:
        pieces.append(f"_gradient_norms_mode_{gradient_norms_mode}")

    if reparse_answers:
        pieces.append("_reparse_answers")

    pieces.append(".pdf")
    return os.path.join('figures', "".join(pieces))

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Plot the mean number of unique answers versus ``k`` for the configured coresets.

    Steps: load the task data loader and the stored generations for
    ``cfg.policy.name``; optionally filter / subsample them and re-parse the
    answers; evaluate every supported coreset named in ``cfg.plot.coresets``
    ("elliptical" and "vanilla" are handled here); draw all curves plus the
    per-``k`` upper bound on unique answers; save the figure to the path
    returned by ``construct_plot_save_path_name``.

    Args:
        cfg: Hydra configuration composed from ``configs/master``.

    Raises:
        ValueError: If none of the names in ``cfg.plot.coresets`` is supported,
            so there would be nothing to plot.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    log.info(f"Loading {cfg.task.name} dataset ...")
    data_module = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}",  package='inference_rlhf.code')
    dl = data_module.DataLoader(cfg)
    log.info(f"Done loading {cfg.task.name} dataset.")

    # Create path to generation files
    load_path = cfg.io.load_root
    policy = cfg.policy.name
    load_path = os.path.join(load_path, 'data', cfg.task.name, policy, 'generations')

    # Load & extract response data
    response_data = load_response_data(
        load_path,
        cfg.reward.name,
        max_num_files=(5 if cfg.debug else None),
        load_reward_scores=cfg.plot.load_reward_scores,
        load_features=cfg.plot.load_features,
        load_gradients=cfg.plot.load_gradients,
        feature_name=cfg.coreset.elliptical.feature_name,
        feature_type=cfg.coreset.elliptical.feature_type
    )
    all_responses = {k: v["responses"] for k, v in response_data.items()}
    all_answers = {k: v["answers"] for k, v in response_data.items()}
    all_results = {k: v["results"] for k, v in response_data.items()}
    all_log_probs = {k: v["log_probs"] for k, v in response_data.items()}

    # Optional payloads stay None when they were not requested.
    all_gradients = None
    if cfg.plot.load_gradients:
        all_gradients = {k: v["gradients"][0] for k, v in response_data.items()}

    all_reward_scores = None
    if cfg.plot.load_reward_scores:
        all_reward_scores = {k: v["reward_scores"] for k, v in response_data.items()}

    all_features = None
    if cfg.plot.load_features:
        all_features = {k: v["features"][0] for k, v in response_data.items()}

    all_prompts = {k: dl.questions[k] for k in response_data}

    # Potentially filter data
    data = maybe_filter_data(
        all_responses,
        all_answers,
        all_results,
        all_reward_scores,
        all_log_probs,
        all_gradients,
        all_features,
        all_prompts,
        max_n=cfg.plot.max_n,
        subsample_size=cfg.plot.subsample_size
    )
    all_responses, all_answers, all_results, all_reward_scores, all_log_probs, all_gradients, all_features, all_prompts = data

    if cfg.plot.reparse_answers:
        # Parse every response exactly once, then keep only the positions that
        # yielded an answer.
        parsed_answers = {
            k: [extract_answer(response, strict=True) for response in v]
            for k, v in all_responses.items()
        }
        answer_idxs = {
            k: [i for i, answer in enumerate(answers) if answer is not None]
            for k, answers in parsed_answers.items()
        }
        reparsed_answers = {k: [parsed_answers[k][i] for i in idxs] for k, idxs in answer_idxs.items()}

        # filter out responses and features that have no answer
        all_responses = {k: [v[i] for i in answer_idxs[k]] for k, v in all_responses.items()}
        # Features / gradients may not have been loaded at all.
        # NOTE: torch.stack needs a non-empty list, so a prompt without any
        # parseable answer still fails here.
        if all_features is not None:
            all_features = {k: torch.stack([v[i] for i in answer_idxs[k]]) for k, v in all_features.items()}
        if all_gradients is not None:
            all_gradients = {k: torch.stack([v[i] for i in answer_idxs[k]]) for k, v in all_gradients.items()}
        all_log_probs = {k: [v[i] for i in answer_idxs[k]] for k, v in all_log_probs.items()}
        all_answers = reparsed_answers
        # NOTE(review): all_results and all_reward_scores are not re-indexed
        # here, so they may no longer line up with the filtered responses —
        # confirm whether EllipticalCoreset relies on that alignment.

    # Log basic statistics
    unique_counts = [len(set(answers)) for answers in all_answers.values()]
    avg_unique_answers = np.mean(unique_counts)
    max_unique_answers = max(unique_counts)
    min_unique_answers = min(unique_counts)
    log.info(f"Avg unique answers: {avg_unique_answers:.1f}, Max unique answers: {max_unique_answers:d}, Min unique answers: {min_unique_answers:d}")

    coresets_to_plot = []
    plot_labels = []

    for coreset in cfg.plot.coresets:
        if coreset == "elliptical":
            elliptical_coreset = EllipticalCoreset(cfg)
            elliptical_mean_unique_answers_at_k = elliptical_coreset.unique_answers_at_k(
                all_answers, all_features, all_gradients, all_reward_scores, all_log_probs, all_responses, all_prompts, dl.answers
            )
            coresets_to_plot.append(elliptical_mean_unique_answers_at_k)
            if cfg.coreset.elliptical.scale_features_with_log_probs:
                plot_labels.append(f"Elliptical ({cfg.coreset.elliptical.temp})")
            else:
                plot_labels.append("Elliptical")

        elif coreset == "vanilla":
            vanilla_coreset = VanillaCoreset(cfg)
            vanilla_mean_unique_answers_at_k = vanilla_coreset.unique_answers_at_k(
                all_answers, dl.answers
            )
            coresets_to_plot.append(vanilla_mean_unique_answers_at_k)
            plot_labels.append("Vanilla")

        else:
            # Previously skipped without any trace.
            log.warning(f"Coreset '{coreset}' is not supported by this script and will be skipped.")

    if not coresets_to_plot:
        raise ValueError(
            f"None of the requested coresets {list(cfg.plot.coresets)} is supported; nothing to plot."
        )

    # Upper bound on unique answers for each k, averaged over all prompts:
    # a selection of size k cannot contain more than k answers, nor more than
    # the number of distinct answers available for that prompt.
    k_to_max_unique = {
        k: np.mean([min(count, k) for count in unique_counts])
        for k in coresets_to_plot[0].keys()
    }

    # Plot all coresets
    sns.set_theme(style="whitegrid")
    for coreset, label in zip(coresets_to_plot, plot_labels):
        plt.plot(
            coreset.keys(),
            coreset.values(),
            marker=MARKER,
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            markeredgewidth=MARKEREDGEWIDTH,
            markeredgecolor=FACE_COLOR, label=label
        )

        log.info(f"{label}@k: {coreset}")

    # Plot maximum unique for each k
    plt.plot(
        k_to_max_unique.keys(),
        k_to_max_unique.values(),
        linestyle='--',
        color='black',
        label="Max unique"
    )

    # Specify plot details
    plt.title("MATH: unique answer growth")
    plt.xscale('log', base=2)
    plt.xlabel('k')
    plt.ylabel('Unique answers')
    plt.legend(fontsize=10)
    # Derive the ticks from the curves that were actually plotted, so the
    # script also works when the vanilla coreset is not among them.
    largest_k = max(max(curve.keys()) for curve in coresets_to_plot)
    num_ticks = int(np.log2(largest_k)) + 1
    x_ticks = [2**i for i in range(num_ticks)]
    x_tick_labels = [f"$2^{{{i}}}$" for i in range(num_ticks)]
    plt.xticks(x_ticks, x_tick_labels)

    plot_save_path = construct_plot_save_path_name(
        policy,
        cfg.plot.coresets,
        cfg.coreset.elliptical.reward_percent_to_filter,
        cfg.coreset.elliptical.log_probs_percent_to_filter,
        cfg.coreset.elliptical.sparse_dim,
        cfg.coreset.elliptical.perform_sparse_projection,
        cfg.coreset.llm_binary_quality_diversity.model_name,
        cfg.coreset.vanilla.log_probs_percent_to_filter,
        cfg.coreset.elliptical.feature_name,
        cfg.coreset.elliptical.feature_type,
        cfg.coreset.elliptical.perform_pca,
        cfg.coreset.elliptical.pca_dim,
        cfg.coreset.elliptical.scale_features_with_log_probs,
        cfg.coreset.vanilla.gradient_percent_to_filter,
        cfg.coreset.elliptical.use_gradients,
        cfg.coreset.elliptical.argmax,
        cfg.coreset.elliptical.lamb,
        cfg.coreset.elliptical.alpha,
        cfg.coreset.elliptical.use_weird_sampling,
        cfg.coreset.elliptical.temp,
        cfg.coreset.elliptical.center_features,
        cfg.coreset.elliptical.use_weird_sampling2,
        cfg.coreset.gradient_norms.mode,
        cfg.plot.reparse_answers
    )
    print(plot_save_path)
    # savefig does not create missing parent directories.
    os.makedirs(os.path.dirname(plot_save_path), exist_ok=True)
    plt.savefig(plot_save_path)
    plt.close()

# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments and receives its DictConfig from Hydra.
if __name__ == "__main__":
    main()