import os
import logging
from collections import defaultdict

import numpy as np
import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import matplotlib.pyplot as plt
import seaborn as sns

from inference_rlhf.code.helpers.utils import load_response_data, maybe_filter_data
from inference_rlhf.code.helpers.utils import set_seeds

# Module-scoped logger.
log = logging.getLogger(__name__)

# Plot styling constants consumed by main().
MARKER = "o"
MARKERSIZE = 10
MARKEREDGEWIDTH = 1.0
LINEWIDTH = 1.2
# Near-white hex colour; main() applies it as the marker edge colour.
FACE_COLOR = "#F7F7FF"

def _best_of_n_accuracy(results, log_probs, ns, num_shuffles=10):
    """Estimate, for one prompt, how often the highest-log-prob response is correct.

    For each of ``num_shuffles`` random orderings of the (result, log_prob)
    pairs, the first ``n`` pairs are taken, the pair with the largest log prob
    is selected, and its result is counted as correct if truthy.

    Args:
        results: Per-response correctness values for a single prompt.
        log_probs: Per-response scores for the same prompt, compared via ``max``.
        ns: Subset sizes to evaluate; each should not exceed the number of pairs.
        num_shuffles: Number of random orderings averaged over.

    Returns:
        Dict mapping each ``n`` in ``ns`` to the fraction of shuffles in which
        the highest-log-prob response among the first ``n`` was correct.
    """
    pairs = list(zip(results, log_probs))
    hits = defaultdict(list)
    for _ in range(num_shuffles):
        # Copy so every repetition shuffles the pairs starting from the same order.
        shuffled = list(pairs)
        np.random.shuffle(shuffled)
        for n in ns:
            # max() keeps the first maximal element, so ties go to the earlier
            # response in the shuffled order.
            best_result, _ = max(shuffled[:n], key=lambda pair: pair[1])
            hits[n].append(1 if best_result else 0)
    return {n: float(np.mean(hits[n])) for n in ns}


def _plot_accuracy(acc_at_n, out_path):
    """Plot accuracy against N on a log2 x-axis and save the figure to ``out_path``.

    Args:
        acc_at_n: Non-empty dict mapping N (powers of two starting at 1) to accuracy.
        out_path: Destination file; its parent directory is created if missing.
    """
    ns = sorted(acc_at_n)
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots()
    # Pass explicit lists: matplotlib does not reliably accept dict views.
    ax.plot(
        ns,
        [acc_at_n[n] for n in ns],
        marker=MARKER,
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR,
    )
    ax.set_title("MATH: sharpening")
    ax.set_xscale('log', base=2)
    ax.set_xlabel('N')
    ax.set_ylabel('Accuracy')
    # One tick per power of two up to the largest N, labelled as 2^k.
    max_exp = int(np.log2(ns[-1]))
    ax.set_xticks([2**i for i in range(max_exp + 1)])
    ax.set_xticklabels([f"$2^{{{i}}}$" for i in range(max_exp + 1)])

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_path)
    plt.close(fig)


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Plot best-of-N accuracy when selecting the response with the highest log prob.

    Loads generations from ``<io.load_root>/data/<task>/<policy>/generations``,
    optionally filters/subsamples them, estimates per-prompt accuracy for
    N = 1, 2, 4, ... via random shuffles, averages over prompts, and saves the
    curve to ``figures/sharpening_test.pdf``.
    """
    print(OmegaConf.to_yaml(cfg))

    set_seeds(cfg.seed)

    log.info(f"Loading {cfg.task.name} dataset ...")
    data_module = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}",  package='inference_rlhf.code')
    dl = data_module.DataLoader(cfg)
    log.info(f"Done loading {cfg.task.name} dataset.")

    # Create path to generation files
    load_path = os.path.join(cfg.io.load_root, 'data', cfg.task.name, cfg.policy.name, 'generations')

    # Load & extract response data
    response_data = load_response_data(load_path, cfg.reward.name, max_num_files=(5 if cfg.debug else None))
    all_responses = {k: v["responses"] for k, v in response_data.items()}
    all_answers = {k: v["answers"] for k, v in response_data.items()}
    all_results = {k: v["results"] for k, v in response_data.items()}
    all_log_probs = {k: v["log_probs"] for k, v in response_data.items()}
    all_prompts = {k: dl.questions[k] for k in response_data}

    # Potentially filter data
    data = maybe_filter_data(
        all_responses,
        all_answers,
        all_results,
        None,
        all_log_probs,
        None,
        all_prompts,
        max_n=cfg.plot.max_n,
        subsample_size=cfg.plot.subsample_size
    )
    _responses, _answers, all_results, _reward_scores, all_log_probs, _features, _prompts = data

    # Number of usable (result, log_prob) pairs per prompt; zip() truncates to
    # the shorter of the two lists.
    sample_counts = {
        k: min(len(all_results[k]), len(all_log_probs[k])) for k in all_results
    }
    usable = [k for k, count in sample_counts.items() if count > 0]
    num_empty = len(sample_counts) - len(usable)
    if num_empty:
        # max() over an empty candidate set is undefined, so skip these prompts.
        log.warning("Skipping %d prompt(s) with no responses.", num_empty)
    if not usable:
        log.warning("No responses to evaluate; nothing to plot.")
        return

    # Only evaluate N values that every prompt can supply; a larger N would
    # silently reuse the full response set and be mislabelled on the plot.
    candidate_ns = [2**i for i in range(11)]  # 1 .. 1024
    min_samples = min(sample_counts[k] for k in usable)
    ns = [n for n in candidate_ns if n <= min_samples]
    if len(ns) < len(candidate_ns):
        log.warning(
            "Capping N at %d: the smallest prompt has only %d responses.",
            ns[-1], min_samples,
        )

    num_shuffles = 10
    per_prompt = [
        _best_of_n_accuracy(all_results[k], all_log_probs[k], ns, num_shuffles=num_shuffles)
        for k in usable
    ]

    # Average over questions
    acc_at_n = {n: float(np.mean([acc[n] for acc in per_prompt])) for n in ns}

    _plot_accuracy(acc_at_n, "figures/sharpening_test.pdf")

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() supplies cfg.
    main()