import json
import logging
import random
import re
from collections import OrderedDict, defaultdict
from typing import Dict, List

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.helpers.utils import timing, rget_json_files_from_dir
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.tasks.math import extract_answer

log = logging.getLogger(name=__name__)  # module-scoped logger

class LLMDirectCoreset(Coreset):
    """Coreset whose members are read from precomputed JSON files on disk.

    For every requested prompt index, the JSON file under ``load_path`` whose
    path contains both ``coreset`` and ``prompt-idx-<idx>`` is loaded, and an
    answer is extracted from each stored response. The sampled answers passed
    in by the caller are not used for selection, only their prompt indices.
    """

    def __init__(self, cfg, load_path: str):
        super().__init__(cfg)

        # Root directory that is searched recursively (via
        # rget_json_files_from_dir) for the per-prompt coreset JSON files.
        self.load_path = load_path

    @staticmethod
    def _find_coreset_file(coreset_files: List[str], prompt_idx: int) -> str:
        """Return the file in ``coreset_files`` tagged ``prompt-idx-<prompt_idx>``.

        The index must not be followed by another digit. A plain substring
        test would let ``prompt-idx-1`` also match ``prompt-idx-12``.

        Raises:
            FileNotFoundError: if no candidate path carries the tag.
        """
        tag = f'prompt-idx-{prompt_idx}'
        pattern = re.compile(re.escape(tag) + r'(?!\d)')
        matches = [f for f in coreset_files if pattern.search(f)]
        if not matches:
            raise FileNotFoundError(
                f"No coreset JSON file matching '{tag}' found under {coreset_files and 'load_path' or 'load_path (no coreset files at all)'}"
            )
        if len(matches) > 1:
            # Keep the original "first match wins" choice, but make the
            # ambiguity visible instead of silently picking one.
            log.warning(
                "Found %d coreset files for %s; using %s", len(matches), tag, matches[0]
            )
        return matches[0]

    @timing
    def coreset(self, model_answers: Dict[int, List[int]]) -> Dict[int, List[int]]:
        """
        Load the precomputed coreset answers for every prompt in ``model_answers``.

        Only the keys of ``model_answers`` (the example/prompt ids) are used;
        the answers themselves are ignored. Expects the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        For each example id, the JSON file under ``self.load_path`` whose path
        contains ``coreset`` and ``prompt-idx-<example_id>`` is loaded. It is
        iterated as a sequence of records with a ``'response'`` field
        (presumably a list of dicts; confirm against the writer of these files).
        An answer is extracted from every response with
        ``extract_answer(..., self.cfg.policy.answer_patterns, strict=False)``,
        and file order is kept.

        Returns a mapping ``example_id -> [extracted answers]``. Every
        requested example id is present, with an empty list if its coreset
        file is empty, so downstream metrics count it instead of dropping it.

        Raises:
            FileNotFoundError: if no coreset file exists for an example id.
        """
        json_files = rget_json_files_from_dir(self.load_path)
        # Filter once rather than rescanning the whole file list per prompt.
        coreset_files = [f for f in json_files if 'coreset' in f]
        answer_patterns = self.cfg.policy.answer_patterns

        coreset = defaultdict(list)
        for prompt_idx in model_answers:
            path = self._find_coreset_file(coreset_files, prompt_idx)
            # JSON is UTF-8 by spec; don't depend on the platform default encoding.
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Assign explicitly so prompts with an empty coreset still get a key.
            coreset[prompt_idx] = [
                extract_answer(response['response'], answer_patterns, strict=False)
                for response in data
            ]

        return coreset

    @timing
    def pass_at_k(self, model_answers: Dict[int, List[int]], gt_answers: List[int]) -> Dict[int, float]:
        """
        Compute pass@k over the coreset loaded from disk by ``coreset``.

        ``model_answers`` is only forwarded to ``coreset``, which uses its keys.
        Expected format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        ``gt_answers`` is indexed by example id. For each k in ``self.ks`` and
        each example, the example scores 1.0 if any of its first k coreset
        answers is judged correct against the ground truth, else 0.0.

        Returns a mapping ``k -> mean score over examples``. If there are no
        examples, the mapping is empty.
        """
        k_to_coreset_results = defaultdict(list)
        coreset_answers = self.coreset(model_answers)
        for k in self.ks:
            log.info(f"Computing pass@k for k={k} ...")
            for example_id, answers in coreset_answers.items():
                gt = gt_answers[example_id]
                assert gt is not None, f"Missing ground-truth answer for example {example_id}"
                k_to_coreset_results[k].append(float(any(judge_correct(gt, a) for a in answers[:k])))
        return {k: np.mean(v) for k, v in k_to_coreset_results.items()}