import os
import logging

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import matplotlib.pyplot as plt
import seaborn as sns

from inference_rlhf.code.helpers.utils import estimate_pass_at_k
from inference_rlhf.code.helpers.utils import load_response_data, maybe_filter_data
from inference_rlhf.code.helpers.utils import set_seeds

log = logging.getLogger(__name__)

def _save_figure(path):
    """Save the active matplotlib figure to ``path`` and close it.

    The parent directory is created first, so a fresh checkout (or a policy
    name containing ``/``) does not make ``savefig`` fail with
    FileNotFoundError after all the data has already been loaded.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    plt.savefig(path)
    plt.close()


def _integer_bins(min_value, max_value):
    """Return histogram bin edges giving every integer in [min_value, max_value] its own bar.

    Edges sit on half-integers, so each bar is centred on its integer. Passing
    ``range(1, max + 1)`` as edges instead would merge the two largest values
    into one bin (NumPy's last bin is closed on both ends) and would produce
    no bins at all when ``max == 1``.
    """
    return np.arange(min_value - 0.5, max_value + 1.5, 1.0)


def _plot_unique_answers(unique_counts, policy):
    """Save a histogram and a normalised cumulative histogram of per-prompt unique-answer counts.

    Args:
        unique_counts: non-empty list with the number of distinct answers for each prompt.
        policy: policy name, used in the plot titles and output file names.
    """
    # The axis starts at 1 as before, but widens to 0 if some prompt has no answers,
    # so that prompt is not silently dropped from the plot.
    bins = _integer_bins(min(1, min(unique_counts)), max(unique_counts))

    plt.hist(unique_counts, bins=bins, edgecolor='black')
    plt.title(f"Histogram of unique answers ({policy})")
    plt.xlabel("Number of unique answers")
    plt.ylabel("Frequency")
    _save_figure(f"figures/analysis/unique_answers_histogram_{policy}.pdf")

    plt.hist(
        unique_counts,
        bins=bins,
        cumulative=True,
        density=True,
        color='#016EB7',
        edgecolor='black',
    )
    plt.title(f"Cumulative distribution of unique answers ({policy})")
    plt.xlabel("Number of unique answers")
    plt.ylabel("Cumulative density")
    _save_figure(f"figures/analysis/unique_answers_cumulative_distribution_{policy}.pdf")


def _plot_pass_at_1(all_results, policy):
    """Save a histogram of per-prompt pass@1 and return ``(pass_at_1, prompt_idx)`` pairs.

    Args:
        all_results: non-empty mapping from prompt index to that prompt's per-response
            results. They are summed to count successes, so they are presumably 0/1 or bool.
        policy: policy name, used in the plot title and output file name.

    Returns:
        List of ``(pass_at_1, prompt_idx)`` tuples in ``all_results`` iteration order.
    """
    prompt_idxs = list(all_results)
    pass_at_1 = [
        estimate_pass_at_k([len(results)], [sum(results)], 1)[0]
        for results in all_results.values()
    ]
    mean_pass_at_1 = np.mean(pass_at_1)

    plt.hist(pass_at_1, bins=np.linspace(0, 1, 11), edgecolor='black')
    plt.axvline(mean_pass_at_1, color='red', linestyle='--', label=f"mean = {mean_pass_at_1:.3f}")
    plt.legend()
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.title(f"Histogram of pass@1 ({policy})")
    plt.xlabel("Pass@1")
    plt.ylabel("Frequency")
    _save_figure(f"figures/analysis/pass_at_1_histogram_{policy}.pdf")

    return list(zip(pass_at_1, prompt_idxs))


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Load a policy's generations for one task and save diagnostic plots.

    Steps:
    - Load the task dataset via ``inference_rlhf.code.tasks.<cfg.task.name>.DataLoader``.
    - Read generations from ``<cfg.io.load_root>/data/<task>/<policy>/generations``
      (at most 5 files when ``cfg.debug``).
    - Run them through ``maybe_filter_data``.
    - Log unique-answer statistics.
    - Write three PDFs under ``figures/analysis/`` (relative to the current working
      directory): the unique-answer histogram, its cumulative distribution, and the
      per-prompt pass@1 histogram.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    log.info(f"Loading {cfg.task.name} dataset ...")
    data_module = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}", package='inference_rlhf.code')
    dl = data_module.DataLoader(cfg)
    log.info(f"Done loading {cfg.task.name} dataset.")

    # Create path to generation files
    policy = cfg.policy.name
    load_path = os.path.join(cfg.io.load_root, 'data', cfg.task.name, policy, 'generations')

    # Load & extract response data
    response_data = load_response_data(
        load_path,
        cfg.reward.name,
        max_num_files=(5 if cfg.debug else None),
        load_reward_scores=cfg.plot.load_reward_scores,
        load_features=cfg.plot.load_features,
        load_gradients=cfg.plot.load_gradients,
        feature_name=cfg.coreset.elliptical.feature_name,
        feature_type=cfg.coreset.elliptical.feature_type,
    )
    all_responses = {k: v["responses"] for k, v in response_data.items()}
    all_answers = {k: v["answers"] for k, v in response_data.items()}
    all_results = {k: v["results"] for k, v in response_data.items()}
    all_log_probs = {k: v["log_probs"] for k, v in response_data.items()}
    # Optional fields are passed to maybe_filter_data as None when they were not loaded.
    all_reward_scores = (
        {k: v["reward_scores"] for k, v in response_data.items()}
        if cfg.plot.load_reward_scores else None
    )
    all_gradients = (
        {k: v["gradients"][0] for k, v in response_data.items()}
        if cfg.plot.load_gradients else None
    )
    all_features = (
        {k: v["features"][0] for k, v in response_data.items()}
        if cfg.plot.load_features else None
    )
    all_prompts = {k: dl.questions[k] for k in response_data}

    # Potentially filter data
    data = maybe_filter_data(
        all_responses,
        all_answers,
        all_results,
        all_reward_scores,
        all_log_probs,
        all_gradients,
        all_features,
        all_prompts,
        max_n=cfg.plot.max_n,
        subsample_size=cfg.plot.subsample_size,
    )
    all_responses, all_answers, all_results, all_reward_scores, all_log_probs, all_gradients, all_features, all_prompts = data

    # max()/min() below would raise on empty input, so bail out with a clear message instead.
    if not all_answers or not all_results:
        log.warning(f"No response data to analyze for {policy} under {load_path}; skipping plots.")
        return

    # Log basic statistics
    unique_counts = [len(set(answers)) for answers in all_answers.values()]
    avg_unique_answers = np.mean(unique_counts)
    max_unique_answers = max(unique_counts)
    min_unique_answers = min(unique_counts)
    log.info(f"Avg unique answers: {avg_unique_answers:.1f}, Max unique answers: {max_unique_answers:d}, Min unique answers: {min_unique_answers:d}")

    sns.set_theme(style="whitegrid")
    _plot_unique_answers(unique_counts, policy)

    # Match pass@1 to prompt idx and report the hardest prompts.
    pass_at_1_with_prompt_idx = sorted(_plot_pass_at_1(all_results, policy), key=lambda pair: pair[0])
    hardest = ", ".join(f"{idx} ({p:.2f})" for p, idx in pass_at_1_with_prompt_idx[:5])
    log.info(f"Lowest pass@1 prompts, idx (pass@1): {hardest}")
    print('done!')

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() builds and supplies cfg.
    main()