from collections import defaultdict, Counter
from typing import Dict, List
import logging

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.helpers.utils import timing

log = logging.getLogger(__name__)

class LogProbsCoreset(Coreset):
    """Coreset that ranks each example's candidate answers by model log-probability.

    Modes:
      * ``"argmax"`` -- every candidate, sorted by log-prob (highest first).
      * ``"sample"`` -- ``min(max(self.ks), n_candidates)`` candidates drawn
        without replacement with probability ``softmax(log_probs)``.

    NOTE(review): ``self.ks`` is read but not set in this class; presumably the
    ``Coreset`` base class populates it from ``cfg`` -- confirm.
    """

    _MODES = ("argmax", "sample")

    def __init__(self, cfg, mode: str = "argmax", num_trials: int = 20):
        super().__init__(cfg)

        self.mode = mode  # one of _MODES: "argmax" or "sample"
        # Number of independent sampling trials averaged by pass_at_k in "sample" mode.
        self.num_trials = num_trials

    def _check_mode(self) -> None:
        """Raise ValueError if ``self.mode`` is not a supported mode."""
        if self.mode not in self._MODES:
            raise ValueError(f"Unknown mode {self.mode!r}; expected one of {self._MODES}")

    @staticmethod
    def _gumbel_top_k(log_probs: np.ndarray, size: int) -> np.ndarray:
        """Return ``size`` indices drawn without replacement with P proportional to exp(log_probs).

        Gumbel-top-k trick: adding i.i.d. Gumbel(0, 1) noise to each log-prob and
        keeping the ``size`` largest is distributed like sequential sampling without
        replacement from ``softmax(log_probs)``. Working in log space means very
        negative log-probs cannot underflow to an all-zero (NaN) distribution, and
        near-zero-probability candidates are simply ranked last instead of making
        ``replace=False`` sampling fail. Uses the global ``np.random`` state, so
        ``np.random.seed`` still controls reproducibility.
        """
        keys = log_probs + np.random.gumbel(size=log_probs.shape)
        return np.argsort(-keys)[:size]

    @staticmethod
    def _first_correct_index(gt, answers) -> "int | None":
        """Return the index of the first answer ``judge_correct`` accepts, or None."""
        return next((i for i, answer in enumerate(answers) if judge_correct(gt, answer)), None)

    def _pass_at_k_single(self, coreset_answers, gt_answers) -> Dict[int, float]:
        """Compute pass@k for each k in ``self.ks`` over one coreset.

        pass@k for an example is 1 if any of its first k coreset answers is judged
        correct, else 0; the result is the mean over examples. The first correct
        position is found once per example (scanning at most ``max(self.ks)``
        answers), which gives the same values as re-checking every prefix per k.
        Returns an empty dict when there are no examples or no ks.
        """
        if not coreset_answers or not self.ks:
            return {}
        max_k = max(self.ks)

        first_correct = [
            self._first_correct_index(gt_answers[example_id], answers[:max_k])
            for example_id, answers in tqdm(coreset_answers.items(), desc="Computing pass@k ...")
        ]
        return {
            k: np.mean([float(idx is not None and idx < k) for idx in first_correct])
            for k in self.ks
        }

    @timing
    def coreset(self, model_answers: Dict[int, List[int]], model_log_probs: Dict[int, List[float]]) -> Dict[int, List[int]]:
        """
        Compute the coreset for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        and ``model_log_probs`` with one log-probability per answer under the same ids.

        Returns a dictionary with the same keys mapping to a list of answers:
          * "argmax": all answers sorted by log-prob, highest first (ties keep input order);
          * "sample": ``min(max(self.ks), len(answers))`` answers sampled without
            replacement from ``softmax(log_probs)``, in draw order.

        Raises:
            ValueError: if ``self.mode`` is unsupported, or an example's answer and
                log-prob lists differ in length.
        """
        self._check_mode()
        max_k = max(self.ks, default=0) if self.mode == "sample" else 0

        coreset = {}
        for example_id, answers in model_answers.items():
            scores = model_log_probs[example_id]
            if len(scores) != len(answers):
                raise ValueError(
                    f"Example {example_id!r}: {len(answers)} answers but {len(scores)} log-probs"
                )

            if self.mode == "argmax":
                # Python's sort is stable even with reverse=True, so ties keep input order.
                order = sorted(range(len(answers)), key=lambda i: scores[i], reverse=True)
            else:
                order = self._gumbel_top_k(np.asarray(scores, dtype=float), min(max_k, len(answers)))

            # Index into the original list so answers keep their Python types.
            coreset[example_id] = [answers[i] for i in order]
        return coreset

    @timing
    def pass_at_k(self, model_answers: Dict[int, List[int]], model_log_probs: Dict[int, List[float]], gt_answers: Dict[int, int]) -> Dict[int, float]:
        """
        Compute the pass@k for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        ``gt_answers`` maps each example_id to its ground-truth answer.

        Returns a dictionary mapping each k in ``self.ks`` to the fraction of examples
        whose first k coreset answers contain one accepted by ``judge_correct``.
        In "sample" mode this is averaged over ``self.num_trials`` independent draws;
        "argmax" is deterministic, so a single evaluation is used.

        Raises:
            ValueError: if ``self.mode`` is unsupported, or ``self.num_trials < 1``
                in "sample" mode.
        """
        self._check_mode()

        if self.mode == "argmax":
            coreset_answers = self.coreset(model_answers, model_log_probs)
            return self._pass_at_k_single(coreset_answers, gt_answers)

        if self.num_trials < 1:
            raise ValueError(f"num_trials must be >= 1, got {self.num_trials}")

        trial_results = []
        for trial in range(self.num_trials):
            log.info("Trial %d of %d ...", trial + 1, self.num_trials)
            coreset_answers = self.coreset(model_answers, model_log_probs)
            trial_results.append(self._pass_at_k_single(coreset_answers, gt_answers))

        return {k: np.mean([result[k] for result in trial_results]) for k in trial_results[0]}
