import os
import logging
from collections import defaultdict
from typing import Dict, List       
import random
import copy

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from tqdm import tqdm

from inference_rlhf.code.helpers.utils import timing, estimate_pass_at_k
from inference_rlhf.code.coreset.unique_random_coreset import UniqueRandomCoreset
from inference_rlhf.code.coreset.unique_frequency_coreset import UniqueFrequencyCoreset
from inference_rlhf.code.coreset.reward_model_coreset import RewardModelCoreset
from inference_rlhf.code.coreset.reward_model_unique_coreset import RewardModelUniqueCoreset
from inference_rlhf.code.coreset.elliptical_coreset import EllipticalCoreset
from inference_rlhf.code.coreset.log_probs import LogProbsCoreset
from inference_rlhf.code.coreset.llm_binary_quality_diversity_coreset import LLMBinaryQualityDiversityCoreset
from inference_rlhf.code.coreset.llm_direct_coreset import LLMDirectCoreset
from inference_rlhf.code.helpers.utils import load_response_data, maybe_filter_data, dataloader_factory
from inference_rlhf.code.helpers.utils import set_seeds
from inference_rlhf.code.coreset.gradient_norms import GradientNormsCoreset

# Module-level logger used by the plotting helpers and main().
log = logging.getLogger(__name__)

# Shared matplotlib line styling applied to every curve plotted in this script.
FACE_COLOR = "#F7F7FF"  # passed as `markeredgecolor` (despite the name)
MARKER = "o"
LINEWIDTH = 1.2
MARKERSIZE = 10
MARKEREDGEWIDTH = 1.0

@timing
def compute_pass_at_k(all_results: Dict[int, List[int]], max_k: int) -> Dict[int, List[float]]:
    """Estimate per-example pass@k for every power of two k up to ``max_k``.

    Args:
        all_results: Mapping from example id to that example's per-response
            results; ``len`` gives the number of samples and ``sum`` the number
            counted as correct.
        max_k: Largest k to evaluate; k runs over 1, 2, 4, ... while
            ``k <= 2 ** floor(log2(max_k))``.

    Returns:
        A ``defaultdict(list)`` mapping each k to the list of per-example
        estimates, in the iteration order of ``all_results``.
    """
    largest_exponent = int(np.log2(max_k))
    estimates_by_k = defaultdict(list)
    for exponent in range(largest_exponent + 1):
        k = 2 ** exponent
        for results in all_results.values():
            num_samples = len(results)
            num_correct = sum(results)
            estimate = estimate_pass_at_k([num_samples], [num_correct], k)[0]
            estimates_by_k[k].append(estimate)
    return estimates_by_k

def avg_pass_at_k(all_pass_at_k: Dict[int, List[float]]) -> Dict[int, float]:
    """Average the per-example pass@k estimates for each k.

    Args:
        all_pass_at_k: Mapping from k to a list of per-example estimates.

    Returns:
        Mapping from k to the mean estimate (``np.mean``), preserving key order.
    """
    averaged = {}
    for k, estimates in all_pass_at_k.items():
        averaged[k] = np.mean(estimates)
    return averaged

def plot_avg_reward_elliptical_coreset_at_k(avg_reward_at_k: Dict[int, float], policy: str, reward_percent_to_filter: int, sparse_dim: int) -> None:
    """Plot the elliptical coreset's average reward@k and save it as a PDF.

    Args:
        avg_reward_at_k: Mapping from k to a single average reward. Each value
            is used directly as a y-coordinate and in the y-limits, so it must
            be a scalar, not a list.
        policy: Policy name, embedded in the output file name.
        reward_percent_to_filter: Reward filter percentage, embedded in the
            output file name.
        sparse_dim: Sparse projection dimension, embedded in the output file name.

    The figure is written to
    ``figures/avg_reward_at_k_{policy}_{reward_percent_to_filter}_{sparse_dim}.pdf``.
    The ``figures`` directory is created if missing. An empty mapping is logged
    and skipped instead of crashing in ``min()``.
    """
    if not avg_reward_at_k:
        log.warning("No avg reward@k values to plot; skipping.")
        return

    # Materialise the dict views so matplotlib receives plain sequences.
    ks = list(avg_reward_at_k.keys())
    rewards = list(avg_reward_at_k.values())

    sns.set_theme(style="whitegrid")
    plt.plot(
        ks,
        rewards,
        marker=MARKER,
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR, label="Elliptical"
    )

    log.info(f"Avg reward@k: {avg_reward_at_k}")

    # Specify plot details
    plt.title("MATH: coresets")
    plt.xscale('log', base=2)
    plt.xlabel('k')
    plt.ylabel('Avg Reward@k')
    plt.legend(fontsize=10)
    plt.ylim(min(rewards) - 0.05, max(rewards) + 0.05)
    # One tick per power of two up to the largest k, labelled as 2^i.
    num_ticks = int(np.log2(max(ks))) + 1
    x_ticks = [2**i for i in range(num_ticks)]
    x_tick_labels = [f"$2^{{{i}}}$" for i in range(num_ticks)]
    plt.xticks(x_ticks, x_tick_labels)

    # savefig does not create directories; make sure the target exists.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(f"figures/avg_reward_at_k_{policy}_{reward_percent_to_filter}_{sparse_dim}.pdf")
    plt.close()

def construct_plot_save_path_name(
    task_name: str,
    policy_name: str,
    coresets: str,
    reward_percent_to_filter: int,
    log_probs_percent_to_filter: int,
    sparse_dim: int,
    perform_sparse_projection: bool,
    judge_name: str,
    log_probs_percent_to_filter_vanilla: int,
    feature_name: str,
    feature_type: str,
    perform_pca: bool,
    pca_dim: int,
    scale_features_with_log_probs: bool,
    gradient_percent_to_filter_vanilla: int,
    use_gradients: bool,
    argmax: bool,
    lamb: float,
    alpha: float,
    use_weird_sampling: bool,
    temp: float,
    center_features: bool,
    use_weird_sampling2: bool,
    gradient_norms_mode: str
) -> str:
    """Build ``figures/<task_name>/pass_at_k_....pdf`` encoding the run's settings.

    ``coresets`` is interpolated verbatim into the name and queried with ``in``.
    The caller in this file passes the configured list of coreset names, so the
    checks are exact element matches and the list's repr appears in the name.
    A plain string would give substring matches instead.

    Suffixes are appended only for the coreset families present:
      * "elliptical": lamb, optional sparse/PCA dims, reward/log-prob filter
        percentages, feature name/type, optional log-prob scaling (+ temp),
        gradients flag, argmax vs sample, centering, alpha (if > 0) and the
        two weird-sampling flags.
      * "llm_binary_quality_diversity": judge name.
      * "filtered_vanilla": vanilla log-prob and gradient filter percentages.
      * "gradient_norms": gradient-norms mode.
    """
    parts = [f"pass_at_k_{policy_name}_{coresets}"]

    if "elliptical" in coresets:
        parts.append(f"_lamb_{lamb}")
        if perform_sparse_projection:
            parts.append(f"_sparse_dim_{sparse_dim}")
        if perform_pca:
            parts.append(f"_pca_dim_{pca_dim}")
        parts.extend([
            f"_elliptical_r%_{reward_percent_to_filter}",
            f"_elliptical_lp%_{log_probs_percent_to_filter}",
            f"_elliptical_feature_{feature_name}_{feature_type}",
        ])
        if scale_features_with_log_probs:
            parts.extend(["_scale_features_with_log_probs", f"_temp_{temp}"])
        if use_gradients:
            parts.append("_use_gradients")
        parts.append("_argmax" if argmax else "_sample")
        if center_features:
            parts.append("_center_features")
        if alpha > 0.0:
            parts.append(f"_alpha_{alpha}")
        if use_weird_sampling:
            parts.append("_use_weird_sampling")
        if use_weird_sampling2:
            parts.append("_use_weird_sampling2")

    if "llm_binary_quality_diversity" in coresets:
        parts.append(f"_judge_{judge_name}")

    if "filtered_vanilla" in coresets:
        parts.extend([
            f"_vanilla_lp%_{log_probs_percent_to_filter_vanilla}",
            f"_vanilla_g%_{gradient_percent_to_filter_vanilla}",
        ])

    if "gradient_norms" in coresets:
        parts.append(f"_gradient_norms_mode_{gradient_norms_mode}")

    parts.append(".pdf")
    return os.path.join('figures', task_name, "".join(parts))

def _drop_indices(results: List[int], ignore_idxs) -> List[int]:
    """Return ``results`` without the entries at the positions in ``ignore_idxs``.

    Args:
        results: Per-response results for a single example.
        ignore_idxs: Iterable of positions to drop.

    Returns:
        A new list containing the kept entries in their original order.
    """
    # Convert once so each membership test is O(1) instead of a list scan.
    ignore = set(ignore_idxs)
    return [result for i, result in enumerate(results) if i not in ignore]


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Evaluate every coreset in ``cfg.plot.coresets`` and save one pass@k plot.

    Loads the generated responses for ``cfg.task.name`` / ``cfg.policy.name``.
    Reward scores, features and gradients are loaded only when the matching
    ``cfg.plot.load_*`` flag is set. The data is passed through
    ``maybe_filter_data`` and a pass@k curve is computed per requested coreset.
    The curves are drawn on a single figure, saved to the path built by
    ``construct_plot_save_path_name``. Unknown coreset names are logged and skipped.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    dl = dataloader_factory(cfg.task.name, cfg)

    # Create path to generation files
    load_path = cfg.io.load_root
    policy = cfg.policy.name
    load_path = os.path.join(load_path, 'data', cfg.task.name, policy)

    # Load & extract response data
    response_data = load_response_data(
        policy_name=policy,
        load_path=load_path,
        reward_name=cfg.reward.name,
        max_num_files=(5 if cfg.debug else None),
        load_reward_scores=cfg.plot.load_reward_scores,
        load_features=cfg.plot.load_features,
        load_gradients=cfg.plot.load_gradients,
        feature_name=cfg.coreset.elliptical.feature_name,
        feature_type=cfg.coreset.elliptical.feature_type,
        strict_results=True,
        answer_patterns=cfg.policy.answer_patterns if 'code_contests' not in load_path and 'mbpp' not in load_path else None,
        task_name=cfg.task.name,
    )
    all_responses = {k: v["responses"] for k, v in response_data.items()}
    all_answers = {k: v["answers"] for k, v in response_data.items()}
    all_results = {k: v["results"] for k, v in response_data.items()}

    print('Number of files: ', len(all_results))

    # Optional modalities are None when not loaded; maybe_filter_data is given None for them.
    all_gradients = (
        {k: v["gradients"][0] for k, v in response_data.items()}
        if cfg.plot.load_gradients else None
    )
    all_reward_scores = (
        {k: v["reward_scores"] for k, v in response_data.items()}
        if cfg.plot.load_reward_scores else None
    )
    all_log_probs = {k: v["log_probs"] for k, v in response_data.items()}
    all_features = (
        {k: v["features"][0] for k, v in response_data.items()}
        if cfg.plot.load_features else None
    )
    all_prompts = {k: dl.questions[k] for k in response_data}

    # Potentially filter data
    data = maybe_filter_data(
        all_responses,
        all_answers,
        all_results,
        all_reward_scores,
        all_log_probs,
        all_gradients,
        all_features,
        all_prompts,
        max_n=cfg.plot.max_n,
        subsample_size=cfg.plot.subsample_size
    )
    all_responses, all_answers, all_results, all_reward_scores, all_log_probs, all_gradients, all_features, all_prompts = data

    # Log basic statistics
    unique_counts = [len(set(answers)) for answers in all_answers.values()]
    avg_unique_answers = np.mean(unique_counts)
    max_unique_answers = max(unique_counts)
    min_unique_answers = min(unique_counts)
    log.info(f"Avg unique answers: {avg_unique_answers:.1f}, Max unique answers: {max_unique_answers:d}, Min unique answers: {min_unique_answers:d}")

    # Parallel lists: every branch below appends exactly one curve AND one label.
    coresets_to_plot = []
    plot_labels = []

    for coreset in cfg.plot.coresets:
    # for temp in [-1, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]:
        # coreset = "vanilla" if temp == -1 else "elliptical"
        # cfg.coreset.elliptical.temp = temp
        if coreset == "unique_random":
            unique_random_coreset = UniqueRandomCoreset(cfg)
            mean_random_coreset_at_k = unique_random_coreset.pass_at_k(all_answers, dl.answers)
            coresets_to_plot.append(mean_random_coreset_at_k)
            plot_labels.append("Unique Random")

        elif coreset == "unique_frequency":
            unique_frequency_coreset = UniqueFrequencyCoreset(cfg)
            mean_frequency_coreset_at_k = unique_frequency_coreset.pass_at_k(all_answers, dl.answers)
            coresets_to_plot.append(mean_frequency_coreset_at_k)
            plot_labels.append("Unique Frequency")

        elif coreset == "reward_model":
            reward_model_coreset = RewardModelCoreset(cfg)
            mean_reward_model_coreset_at_k = reward_model_coreset.pass_at_k(all_answers, all_reward_scores, dl.answers)
            coresets_to_plot.append(mean_reward_model_coreset_at_k)
            plot_labels.append("Reward Model")

        elif coreset == "log_probs":
            log_probs_coreset = LogProbsCoreset(cfg)
            mean_log_probs_coreset_at_k = log_probs_coreset.pass_at_k(all_answers, all_log_probs, dl.answers)
            coresets_to_plot.append(mean_log_probs_coreset_at_k)
            plot_labels.append("Log Probs")

        elif coreset == "gradient_norms":
            gradient_norms_coreset = GradientNormsCoreset(cfg)
            mean_gradient_norms_coreset_at_k = gradient_norms_coreset.pass_at_k(all_answers, all_gradients, dl.answers)
            coresets_to_plot.append(mean_gradient_norms_coreset_at_k)
            plot_labels.append("Gradient Norms")

        elif coreset == "reward_model_unique":
            reward_model_unique_coreset = RewardModelUniqueCoreset(cfg)
            mean_reward_model_unique_coreset_at_k = reward_model_unique_coreset.pass_at_k(all_answers, all_reward_scores, dl.answers)
            coresets_to_plot.append(mean_reward_model_unique_coreset_at_k)
            plot_labels.append("Reward Model Unique")

        elif coreset == "elliptical":
            elliptical_coreset = EllipticalCoreset(cfg)
            mean_elliptical_coreset_at_k = elliptical_coreset.pass_at_k(all_answers, all_results, all_features, all_gradients, all_reward_scores, all_log_probs, all_responses, all_prompts, dl.answers)
            # avg_reward_elliptical_coreset_at_k = elliptical_coreset.avg_reward_at_k(all_reward_scores)
            # plot_avg_reward_elliptical_coreset_at_k(avg_reward_elliptical_coreset_at_k, policy, cfg.coreset.elliptical.reward_percent_to_filter, cfg.coreset.elliptical.sparse_dim)
            coresets_to_plot.append(mean_elliptical_coreset_at_k)
            if cfg.coreset.elliptical.scale_features_with_log_probs:
                plot_labels.append(f"Elliptical ({cfg.coreset.elliptical.temp})")
            elif cfg.coreset.elliptical.log_probs_percent_to_filter > 0:
                plot_labels.append(f"Elliptical ({cfg.coreset.elliptical.log_probs_percent_to_filter}%)")
            else:
                plot_labels.append("Elliptical")

        elif coreset == "filtered_random_vanilla":
            # Baseline: drop a uniformly random subset of responses per example.
            filtered_all_results = copy.deepcopy(all_results)
            if cfg.coreset.vanilla.log_probs_percent_to_filter > 0:
                for example_id, log_probs in all_log_probs.items():
                    num_to_filter = int(len(log_probs) * cfg.coreset.vanilla.log_probs_percent_to_filter / 100)
                    ignore_idxs = random.sample(range(len(log_probs)), num_to_filter)
                    filtered_all_results[example_id] = _drop_indices(filtered_all_results[example_id], ignore_idxs)

            filtered_random_vanilla_pass_at_k = compute_pass_at_k(filtered_all_results, cfg.coreset.max_k)
            mean_filtered_random_vanilla_pass_at_k = avg_pass_at_k(filtered_random_vanilla_pass_at_k)
            coresets_to_plot.append(mean_filtered_random_vanilla_pass_at_k)
            plot_labels.append("Filtered Random Vanilla")

        elif coreset == "filtered_vanilla":
            log_probs_perc = cfg.coreset.vanilla.log_probs_percent_to_filter
            gradient_perc = cfg.coreset.vanilla.gradient_percent_to_filter
            filtered_all_results = copy.deepcopy(all_results)
            if log_probs_perc > 0:
                # for each example, find indices with the lowest log probs
                for example_id, log_probs in all_log_probs.items():
                    sorted_indices = sorted(range(len(log_probs)), key=lambda i: log_probs[i])
                    num_to_filter = int(len(log_probs) * log_probs_perc / 100)
                    ignore_idxs = sorted_indices[:num_to_filter]
                    filtered_all_results[example_id] = _drop_indices(filtered_all_results[example_id], ignore_idxs)

            elif gradient_perc > 0:
                # for each example, drop the responses with the largest L1 gradient norm
                for example_id, gradients in tqdm(all_gradients.items(), desc="Filtering gradients ..."):
                    sorted_indices = sorted(range(len(gradients)), key=lambda i: torch.norm(gradients[i], p=1), reverse=True)
                    num_to_filter = int(len(gradients) * gradient_perc / 100)
                    ignore_idxs = sorted_indices[:num_to_filter]
                    filtered_all_results[example_id] = _drop_indices(filtered_all_results[example_id], ignore_idxs)

            filtered_vanilla_pass_at_k = compute_pass_at_k(filtered_all_results, cfg.coreset.max_k)
            mean_filtered_vanilla_pass_at_k = avg_pass_at_k(filtered_vanilla_pass_at_k)
            coresets_to_plot.append(mean_filtered_vanilla_pass_at_k)
            if log_probs_perc > 0:
                plot_labels.append(f"Filtered Log probs Vanilla ({log_probs_perc}%)")
            elif gradient_perc > 0:
                plot_labels.append(f"Filtered Gradient Vanilla ({gradient_perc}%)")
            else:
                # Nothing was filtered, but a label is still required so that
                # plot_labels stays aligned with coresets_to_plot in the zip below.
                plot_labels.append("Filtered Vanilla")

        elif coreset == "vanilla":
            all_pass_at_k = compute_pass_at_k(all_results, cfg.coreset.max_k)
            mean_pass_at_k = avg_pass_at_k(all_pass_at_k)
            coresets_to_plot.append(mean_pass_at_k)
            plot_labels.append("Vanilla")

        elif coreset == "llm_binary_quality_diversity":
            llm_binary_quality_diversity_coreset = LLMBinaryQualityDiversityCoreset(cfg)
            mean_llm_binary_quality_diversity_coreset_at_k = llm_binary_quality_diversity_coreset.pass_at_k(dl.questions, all_responses, all_answers, dl.answers, all_reward_scores)
            coresets_to_plot.append(mean_llm_binary_quality_diversity_coreset_at_k)
            plot_labels.append("LLM Binary Quality Diversity")

        elif coreset == "llm_direct_coreset":
            llm_direct_coreset = LLMDirectCoreset(cfg, load_path)
            mean_llm_direct_coreset_at_k = llm_direct_coreset.pass_at_k(all_responses, dl.answers)
            coresets_to_plot.append(mean_llm_direct_coreset_at_k)
            plot_labels.append("LLM Direct Coreset")

        else:
            log.warning(f"Unknown coreset '{coreset}' in cfg.plot.coresets; skipping.")

    # Plot all coresets
    sns.set_theme(style="whitegrid")
    for coreset, label in zip(coresets_to_plot, plot_labels):
        plt.plot(
            list(coreset.keys()),
            list(coreset.values()),
            marker=MARKER,
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            markeredgewidth=MARKEREDGEWIDTH,
            markeredgecolor=FACE_COLOR, label=label
        )

        log.info(f"{label}@k: {coreset}")

    # Specify plot details
    plt.title(f"{cfg.task.name}: coresets ({cfg.policy.name})")
    plt.xscale('log', base=2)
    plt.xlabel('k')
    plt.ylabel('Pass@k')
    plt.legend(fontsize=10)

    # Derive the axes from every plotted curve (not just "vanilla"), so this
    # works when vanilla is not requested and no curve is clipped below the axis.
    plotted_ks = [k for curve in coresets_to_plot for k in curve.keys()]
    plotted_values = [value for curve in coresets_to_plot for value in curve.values()]
    if plotted_values:
        plt.ylim(min(plotted_values) - 0.05, 1.0)
        num_ticks = int(np.log2(max(plotted_ks))) + 1
        x_ticks = [2**i for i in range(num_ticks)]
        x_tick_labels = [f"$2^{{{i}}}$" for i in range(num_ticks)]
        plt.xticks(x_ticks, x_tick_labels)
    else:
        log.warning("No coreset curves were produced; saving an empty figure.")

    plot_save_path = construct_plot_save_path_name(
        cfg.task.name,
        policy,
        cfg.plot.coresets,
        cfg.coreset.elliptical.reward_percent_to_filter,
        cfg.coreset.elliptical.log_probs_percent_to_filter,
        cfg.coreset.elliptical.sparse_dim,
        cfg.coreset.elliptical.perform_sparse_projection,
        cfg.coreset.llm_binary_quality_diversity.model_name,
        cfg.coreset.vanilla.log_probs_percent_to_filter,
        cfg.coreset.elliptical.feature_name,
        cfg.coreset.elliptical.feature_type,
        cfg.coreset.elliptical.perform_pca,
        cfg.coreset.elliptical.pca_dim,
        cfg.coreset.elliptical.scale_features_with_log_probs,
        cfg.coreset.vanilla.gradient_percent_to_filter,
        cfg.coreset.elliptical.use_gradients,
        cfg.coreset.elliptical.argmax,
        cfg.coreset.elliptical.lamb,
        cfg.coreset.elliptical.alpha,
        cfg.coreset.elliptical.use_weird_sampling,
        cfg.coreset.elliptical.temp,
        cfg.coreset.elliptical.center_features,
        cfg.coreset.elliptical.use_weird_sampling2,
        cfg.coreset.gradient_norms.mode
    )
    print(plot_save_path)
    # savefig does not create directories; figures/<task>/ may not exist yet.
    os.makedirs(os.path.dirname(plot_save_path), exist_ok=True)
    plt.savefig(plot_save_path)
    plt.close()

# Script entry point. main() takes no explicit arguments here; `cfg` is
# supplied by its @hydra.main decorator (configs under ../../configs, "master").
if __name__ == "__main__":
    main()