import os
import pickle
import logging

import hydra
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm

from inference_rlhf.code.helpers.io import json_dump

log = logging.getLogger(__name__)

# Keys read from each per-prompt entry of the pickle, in the order they are zipped.
_COLUMN_KEYS = (
    "responses",
    "answers",
    "strict_answers",
    "results",
    "strict_results",
    "avg_logprobs",
    "sum_logprobs",
    "ground_truth_answers",
)


def _build_records(prompt_idx, data):
    """Turn one prompt's column-oriented results into a list of per-response dicts.

    Args:
        prompt_idx: Key of this prompt in the loaded pickle; copied into every record.
        data: Mapping holding parallel sequences under ``_COLUMN_KEYS``, presumably
            one entry per sampled response (TODO confirm against the producer).

    Returns:
        A list with one dict per zipped position. ``zip`` stops at the shortest
        column, so a length mismatch is logged instead of passing silently.
    """
    columns = [data[key] for key in _COLUMN_KEYS]

    # zip() truncates to the shortest input; surface that instead of silently dropping responses.
    lengths = {
        key: len(col)
        for key, col in zip(_COLUMN_KEYS, columns)
        if hasattr(col, "__len__")
    }
    if len(set(lengths.values())) > 1:
        log.warning(
            f"Prompt {prompt_idx}: column lengths differ {lengths}; "
            f"only the shortest length will be exported"
        )

    return [
        {
            "prompt_idx": prompt_idx,
            "response": response,
            "correct": bool(result),
            "strict_correct": bool(strict_result),
            "extracted_answer": answer,
            "strict_extracted_answer": strict_answer,
            "avg_logprob": avg_logprob,
            "sum_logprob": sum_logprob,
            "ground_truth": ground_truth,
        }
        for (
            response,
            answer,
            strict_answer,
            result,
            strict_result,
            avg_logprob,
            sum_logprob,
            ground_truth,
        ) in zip(*columns)
    ]


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Split a pickled dict of per-prompt generations into one JSON file per prompt.

    Reads ``<cfg.root>/anonymous_generations/<task>/<anonymous_model_name>_responses.pkl``.
    Writes one ``...--prompt-idx-<idx>-generations.json`` file per prompt under
    ``<cfg.io.save_root>/data/<task>/<policy>/generations``. Prompts whose output
    file already exists are skipped, so reruns only fill in missing files.

    Args:
        cfg: Hydra-composed config (``configs/master`` plus CLI overrides).
    """
    log.info(OmegaConf.to_yaml(cfg))

    model_name = cfg.policy.anonymous_model_name

    # Load responses
    log.info(f"Loading {model_name} responses ...")
    load_path = os.path.join(cfg.root, "anonymous_generations", cfg.task.name, f"{model_name}_responses.pkl")
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted, locally produced files.
    with open(load_path, "rb") as f:
        all_responses = pickle.load(f)
    log.info(f"Loaded {len(all_responses)} responses for {model_name}")

    # The output directory is the same for every prompt, so resolve and create it once.
    # exist_ok avoids the exists()/makedirs() race of a check-then-create.
    save_root = os.path.join(cfg.io.save_root, "data", cfg.task.name, cfg.policy.name, "generations")
    os.makedirs(save_root, exist_ok=True)

    for prompt_idx, data in tqdm(all_responses.items(), desc=f"Extracting {model_name} responses into json files ..."):
        save_path = os.path.join(save_root, f"{model_name}--shots-1--prompt-idx-{prompt_idx}-generations.json")

        # Skip if the file already exists
        if os.path.exists(save_path):
            log.info(f"Skipping {save_path} because it already exists")
            continue

        # save as json
        json_dump(_build_records(prompt_idx, data), save_path)
        log.info(f"Successfully saved {save_path}.")

if __name__ == "__main__":
    # main is wrapped by @hydra.main, which composes and supplies `cfg` itself,
    # hence the argument-free call.
    main()