from collections import defaultdict, Counter
from typing import Dict, List
import logging

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.helpers.utils import timing

log = logging.getLogger(__name__)

class UniqueFrequencyCoreset(Coreset):
    """Coreset that keeps each distinct answer once, ordered from most to least frequent."""

    def __init__(self, cfg):
        super().__init__(cfg)

    @timing
    def coreset(self, model_answers: Dict[int, List[int]]) -> Dict[int, List[int]]:
        """
        Compute the coreset for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        Returns a dictionary with the same format, but with the answers de-duplicated and
        sorted by descending frequency. Answers that occur equally often keep the order in
        which they first appear in the input list, so the ranking is reproducible across runs.
        Answers must be hashable.
        """
        # Counter.most_common() sorts by count (descending) and leaves equally frequent
        # elements in first-encountered order. Ranking a set() instead would break ties by
        # set iteration order, which for strings changes with hash randomization.
        return {
            example_id: [answer for answer, _ in Counter(answers).most_common()]
            for example_id, answers in model_answers.items()
        }

    @timing
    def pass_at_k(self, model_answers: Dict[int, List[int]], gt_answers: Dict[int, int]) -> Dict[int, float]:
        """
        Compute the pass@k for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        gt_answers maps every example_id to its ground-truth answer; a missing id raises KeyError.

        For each k in self.ks, an example counts as solved when at least one of the first k
        answers of its coreset (distinct answers, most frequent first) is accepted by
        judge_correct. Returns {k: mean solved rate over all examples}. The result is empty
        when there are no examples or no ks.

        NOTE(review): self.ks is not set in this class; presumably the Coreset base class
        provides it from cfg. Values are assumed to be non-negative integers.
        """
        ks = list(self.ks)
        # No answer ranked beyond the largest k can influence any result, so never judge it.
        max_k = max(ks, default=0)

        k_to_coreset_results = defaultdict(list)
        coreset_answers = self.coreset(model_answers)
        for example_id, answers in tqdm(coreset_answers.items(), desc="Computing pass@k ..."):
            gt = gt_answers[example_id]
            # Judge each ranked answer at most once and stop at the first correct one:
            # "any of the first k answers is correct" is the same as "the first correct
            # answer has rank < k", so one scan per example serves every k.
            first_correct_rank = next(
                (rank for rank, answer in enumerate(answers[:max_k]) if judge_correct(gt, answer)),
                None,
            )
            for k in ks:
                solved = first_correct_rank is not None and first_correct_rank < k
                k_to_coreset_results[k].append(float(solved))

        return {k: np.mean(v) for k, v in k_to_coreset_results.items()}
