import os
import pickle
import logging
import json
import hydra
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm

from inference_rlhf.code.helpers.io import json_dump

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Maps a model-type identifier (an entry of cfg.model_types, also used as the
# prefix of the "<model_type>_responses.pkl" filename) to the policy directory
# name under which that model's generation files are stored.
MODEL_TYPE_TO_POLICY = {
    "llama32medium": "llama-3-3b",
    "mistral7b": "mistral-7b"
}

def _backfill_logprobs(json_data, all_responses):
    """Copy per-response log-probabilities from the pickle into JSON entries.

    Each entry of ``json_data`` is updated in place with ``avg_logprobs`` and
    ``sum_logprobs`` taken from ``all_responses[entry["prompt_idx"]]``.

    Args:
        json_data: List of dicts loaded from a generations JSON file. Each
            dict must provide ``prompt_idx``, ``response``, ``correct`` and
            ``extracted_answer``.
        all_responses: Container indexed by ``prompt_idx`` whose items expose
            ``avg_logprobs``, ``sum_logprobs``, ``responses``, ``results`` and
            ``answers`` sequences.

    Raises:
        AssertionError: If an entry's response, correctness flag or extracted
            answer disagrees with the pickled data at the same position.
    """
    # (JSON field, pickle field) pairs that must agree before we trust the
    # pickled log-probabilities for an entry.
    consistency_checks = (
        ("response", "responses"),
        ("correct", "results"),
        ("extracted_answer", "answers"),
    )

    # NOTE(review): ``idx`` is the entry's position within the JSON file and is
    # also used as the index into the per-prompt lists of the pickle. This
    # presumably relies on each file holding one prompt's responses in the
    # same order as the pickle -- confirm. The checks below fail loudly if not.
    for idx, entry in enumerate(json_data):
        prompt_responses = all_responses[entry["prompt_idx"]]

        # Validate before mutating. Raised explicitly (instead of ``assert``)
        # so the check is not stripped when Python runs with -O.
        for json_key, pickle_key in consistency_checks:
            if entry[json_key] != prompt_responses[pickle_key][idx]:
                raise AssertionError(
                    f"Entry {idx} (prompt_idx={entry['prompt_idx']}): JSON field "
                    f"{json_key!r} does not match pickled {pickle_key!r}."
                )

        entry["avg_logprobs"] = prompt_responses["avg_logprobs"][idx]
        entry["sum_logprobs"] = prompt_responses["sum_logprobs"][idx]


@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """Backfill log-probabilities into existing "armo-rm" generation JSON files.

    For every model type in ``cfg.model_types``, loads
    ``<cfg.root>/amlt/<model_type>_responses.pkl`` and, for each ``.json`` file
    whose name contains ``"armo-rm"`` under
    ``<cfg.io.save_root>/data/<cfg.task.name>/<policy>/generations``, adds
    ``avg_logprobs`` and ``sum_logprobs`` to every entry and rewrites the file
    in place via ``json_dump``.

    Args:
        cfg: Hydra configuration. Reads ``model_types``, ``root``,
            ``io.save_root`` and ``task.name``.

    Raises:
        KeyError: If a model type has no entry in ``MODEL_TYPE_TO_POLICY``.
        AssertionError: If a JSON entry is inconsistent with the pickled
            responses; the offending file is left unmodified on disk.
    """
    log.info(OmegaConf.to_yaml(cfg))

    for model_type in cfg.model_types:
        # Resolve the policy name first so an unknown model type fails before
        # the (potentially large) pickle is loaded.
        policy = MODEL_TYPE_TO_POLICY[model_type]

        # Load responses
        log.info(f"Loading {model_type} responses ...")
        responses_path = os.path.join(cfg.root, "amlt", f"{model_type}_responses.pkl")
        with open(responses_path, "rb") as f:
            # SECURITY: pickle can execute arbitrary code on load; only run
            # this against response files from a trusted source.
            all_responses = pickle.load(f)
        log.info(f"Loaded {len(all_responses)} responses for {model_type}")

        save_root = os.path.join(cfg.io.save_root, "data", cfg.task.name, policy, "generations")
        # loop over all files in the save_root
        for file in tqdm(os.listdir(save_root), desc=f"Backfilling {model_type} responses"):
            if not (file.endswith(".json") and "armo-rm" in file):
                continue

            file_path = os.path.join(save_root, file)
            with open(file_path, "r") as f:
                json_data = json.load(f)

            # backfill logprobs (raises before anything is written on mismatch)
            _backfill_logprobs(json_data, all_responses)

            # save as json
            json_dump(json_data, file_path)
            log.info(f"Succesfully saved {file_path}.")

# Script entry point: main() is called without arguments; the @hydra.main
# decorator on main() is expected to build and pass the cfg argument.
if __name__ == "__main__":
    main()