import os
import pickle
from collections import defaultdict
from typing import Dict, List, Tuple
import logging
from tqdm import tqdm

import hydra
from omegaconf import OmegaConf, DictConfig
from importlib import import_module
import numpy as np

from inference_rlhf.code.helpers.utils import timing

# Module-level logger; main() uses it for config dumps and progress messages.
log = logging.getLogger(__name__)

def extract_logprobs(logprobs: List[List[float]]) -> Tuple[List[float], List[float]]:
    """
    Reduce each sequence of log-probabilities to its mean and its sum.

    Args:
        logprobs: an iterable of sequences; each inner sequence is converted with
            ``np.array`` and reduced independently.

    Returns:
        ``(avg_logprobs, sum_logprobs)``: two lists of equal length where element ``i``
        is the mean / sum of ``logprobs[i]`` (NumPy scalars). An empty inner sequence
        yields ``nan`` for the mean (NumPy emits a RuntimeWarning) and ``0.0`` for the sum.
    """
    arrays = [np.array(seq) for seq in logprobs]
    means = [arr.mean() for arr in arrays]
    totals = [arr.sum() for arr in arrays]
    return means, totals

@timing
def preprocess_all_responses(cfg: DictConfig, anonymous_model_name: str, dl) -> Dict[int, Dict[str, list]]:
    """
    Merge every generation pickle for one (task, model) pair into a per-example mapping.

    Reads files in ``<cfg.root>/anonymous_generations`` whose names start with
    ``f"{cfg.task.name}_{anonymous_model_name}"``. Files are processed in sorted filename
    order, so the order in which responses are concatenated is reproducible across runs
    and filesystems.

    Args:
        cfg: Hydra config; reads ``cfg.root``, ``cfg.task.name`` and
            ``cfg.policy.answer_patterns``.
        anonymous_model_name: model identifier embedded in the generation filenames.
        dl: task DataLoader; ``dl.answers[example_id]`` is used as the ground-truth answer.

    Returns:
        A ``defaultdict`` mapping each ``example_id`` to a ``defaultdict(list)`` with keys
        ``'responses'``, ``'answers'``, ``'strict_answers'``, ``'results'``,
        ``'strict_results'``, ``'avg_logprobs'``, ``'sum_logprobs'`` and
        ``'ground_truth_answers'``. The lists are presumably aligned per response, assuming
        ``extract_answers`` returns one answer per response (confirm).
        The outer factory is a lambda, so convert the outer mapping with ``dict()``
        before pickling it.
    """
    generations_dir = os.path.join(cfg.root, "anonymous_generations")
    # NOTE(review): plain prefix matching means model "foo" also picks up files of a model
    # named e.g. "foo_v2"; tighten if such model names can coexist.
    prefix = f"{cfg.task.name}_{anonymous_model_name}"
    # os.listdir returns entries in arbitrary order; sort so the merged output is deterministic.
    relevant_files = sorted(name for name in os.listdir(generations_dir) if name.startswith(prefix))

    all_responses = defaultdict(lambda: defaultdict(list))
    # NOTE(review): extract_answers / extract_results are not imported or defined in this
    # module as shown; they must be brought into scope or this loop raises NameError.
    for file in tqdm(relevant_files, total=len(relevant_files)):
        # pickle.load can execute arbitrary code; cfg.root must point at trusted outputs only.
        with open(os.path.join(generations_dir, file), "rb") as f:
            data = pickle.load(f)

        for problem in data:
            example_id = problem["example_id"]
            entry = all_responses[example_id]
            ground_truth = dl.answers[example_id]
            responses = problem["responses"]

            answers = extract_answers(responses, cfg.policy.answer_patterns, cfg.task.name, strict=False)
            strict_answers = extract_answers(responses, cfg.policy.answer_patterns, cfg.task.name, strict=True)
            results = extract_results([ground_truth] * len(answers), answers, cfg.task.name)
            strict_results = extract_results([ground_truth] * len(strict_answers), strict_answers, cfg.task.name)
            avg_logprobs, sum_logprobs = extract_logprobs(problem["logprobs"])

            entry["responses"].extend(responses)
            entry["answers"].extend(answers)
            entry["strict_answers"].extend(strict_answers)
            entry["results"].extend(results)
            entry["strict_results"].extend(strict_results)
            entry["avg_logprobs"].extend(avg_logprobs)
            entry["sum_logprobs"].extend(sum_logprobs)
            entry["ground_truth_answers"].extend([ground_truth] * len(answers))

    return all_responses

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """
    Merge all generation pickles for ``cfg.policy.anonymous_model_name`` on ``cfg.task.name``.

    The result is written to
    ``<cfg.root>/anonymous_generations/<cfg.task.name>/<model>_responses.pkl``. If that
    file already exists, the function returns without doing any work.

    Two safeguards protect the existence check:
    - The pickle is written to a temporary file and atomically moved into place. An
      interrupted run therefore cannot leave a truncated pickle that later runs would skip.
    - Nothing is written when no responses were found. An empty result is therefore never
      cached.
    """
    log.info(OmegaConf.to_yaml(cfg))

    model_name = cfg.policy.anonymous_model_name
    dir_to_save = os.path.join(cfg.root, "anonymous_generations", cfg.task.name)
    os.makedirs(dir_to_save, exist_ok=True)
    output_path = os.path.join(dir_to_save, f"{model_name}_responses.pkl")
    if os.path.exists(output_path):
        log.info(f"Skipping {model_name} because it already exists")
        return

    # Load dataset
    log.info(f"Loading {cfg.task.name} dataset ...")
    data_module = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}", package='inference_rlhf.code')
    dl = data_module.DataLoader(cfg)
    log.info(f"Done loading {cfg.task.name} dataset.")

    # Merge responses; keys are example ids, so len() counts examples, not responses.
    all_responses = preprocess_all_responses(cfg, model_name, dl)
    log.info(f"Loaded responses for {len(all_responses)} examples for {model_name}")
    if not all_responses:
        # Saving an empty file would make every future run skip this model.
        log.warning(f"No responses found for {model_name} on {cfg.task.name}; nothing saved")
        return

    # The outer defaultdict has a lambda factory, which pickle cannot serialize.
    all_responses = dict(all_responses)

    # Save responses: write to a temp file, then atomically replace the final path.
    log.info(f"Saving responses for {len(all_responses)} examples for {model_name}")
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(all_responses, f)
    os.replace(tmp_path, output_path)
    log.info(f"Saved responses for {len(all_responses)} examples to {output_path}")


if __name__ == "__main__":
    main()