from collections import defaultdict, Counter
from typing import Dict, List
import logging

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.helpers.utils import timing

# Logger named after this module so its records can be filtered by module path.
log = logging.getLogger(name=__name__)

class RewardModelUniqueCoreset(Coreset):
    """Coreset built from the unique answers of each example, ranked by reward-model score.

    Every unique answer is scored with the highest reward any of its samples received.
    The coreset for an example is then either

    * ``mode="argmax"``: all unique answers sorted by that score, best first, or
    * ``mode="sample"``: up to ``max(self.ks)`` unique answers drawn without replacement
      with probabilities proportional to ``exp(score)`` (a softmax over the scores),
      in the order they were drawn.

    ``self.ks`` is read but never set in this class; presumably it is provided by the
    ``Coreset`` base class from ``cfg`` — TODO confirm against ``Coreset.__init__``.
    """

    _MODES = ("argmax", "sample")

    def __init__(self, cfg, mode: str = "argmax", num_trials: int = 20):
        """
        Args:
            cfg: configuration forwarded unchanged to ``Coreset.__init__``.
            mode: ``"argmax"`` (deterministic ranking by max reward score) or
                ``"sample"`` (reward-weighted sampling without replacement).
            num_trials: number of independent sampling rounds that ``pass_at_k``
                averages over in ``"sample"`` mode; unused in ``"argmax"`` mode.

        Raises:
            ValueError: if ``mode`` is not supported or ``num_trials`` is < 1.
        """
        # Validate up front: an unknown mode used to make coreset() return {} and
        # pass_at_k() return None silently.
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")
        if num_trials < 1:
            raise ValueError(f"num_trials must be >= 1, got {num_trials}")
        super().__init__(cfg)

        self.num_trials = num_trials
        self.mode = mode  # "argmax" or "sample" (weighted by softmax of reward score)

    @staticmethod
    def _max_score_per_answer(answers: List[int], scores: List[float]) -> Dict[int, float]:
        """Map each unique answer to the highest score among its samples, in first-seen order."""
        # zip() stops at the shorter list; scores are assumed to be position-aligned with answers.
        best: Dict[int, float] = {}
        for answer, score in zip(answers, scores):
            if answer not in best or score > best[answer]:
                best[answer] = score
        return best

    def _sample_by_reward(self, best_score: Dict[int, float]) -> List[int]:
        """Draw up to ``max(self.ks)`` unique answers without replacement, weighted by softmax(score)."""
        unique_answers = list(best_score)
        if not unique_answers:
            # np.random.choice cannot take an empty probability vector (it cannot sum to 1).
            return []
        scores = np.asarray([best_score[a] for a in unique_answers], dtype=float)
        # Shift by the max before exponentiating: the softmax is unchanged, but exp() can no
        # longer overflow to inf (which would make the probabilities NaN).
        weights = np.exp(scores - scores.max())
        probs = weights / weights.sum()
        size = min(max(self.ks), len(unique_answers))
        # Sample positions rather than the answers themselves so the original answer objects
        # come back as a list (no numpy dtype coercion; non-scalar answers also work).
        picked = np.random.choice(len(unique_answers), size=size, p=probs, replace=False)
        return [unique_answers[i] for i in picked]

    @timing
    def coreset(self, model_answers: Dict[int, List[int]], model_reward_scores: Dict[int, List[float]]) -> Dict[int, List[int]]:
        """
        Compute the coreset for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        and ``model_reward_scores`` keyed the same way, with one score per answer
        (position-aligned with ``model_answers[example_id]``).

        Returns a dictionary with the same keys whose values are lists of unique answers:
        sorted by descending max reward score in ``"argmax"`` mode (ties keep first-seen
        order), or in drawn order in ``"sample"`` mode (at most ``max(self.ks)`` answers).

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        coreset = {}
        for example_id, answers in model_answers.items():
            best_score = self._max_score_per_answer(answers, model_reward_scores[example_id])
            if self.mode == "argmax":
                # sorted() is stable even with reverse=True, so equal scores keep first-seen order.
                coreset[example_id] = sorted(best_score, key=best_score.__getitem__, reverse=True)
            elif self.mode == "sample":
                coreset[example_id] = self._sample_by_reward(best_score)
            else:
                raise ValueError(f"Unknown mode {self.mode!r}; expected one of {self._MODES}")

        return coreset

    def _pass_at_k_once(self, coreset_answers: Dict[int, List[int]], gt_answers: Dict[int, int]) -> Dict[int, float]:
        """Fraction of examples whose first ``k`` coreset answers contain a correct one, for each k."""
        k_to_coreset_results = defaultdict(list)
        for k in tqdm(self.ks, desc="Computing pass@k ..."):
            for example_id, answers in coreset_answers.items():
                gt = gt_answers[example_id]
                k_to_coreset_results[k].append(float(any(judge_correct(gt, a) for a in answers[:k])))
        return {k: np.mean(v) for k, v in k_to_coreset_results.items()}

    @timing
    def pass_at_k(self, model_answers: Dict[int, List[int]], model_reward_scores: Dict[int, List[float]], gt_answers: Dict[int, int]) -> Dict[int, float]:
        """
        Compute the pass@k for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        Returns a dictionary with the estimated pass@k for each k. In ``"sample"`` mode the
        coreset is resampled ``self.num_trials`` times and the per-trial pass@k are averaged.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        if self.mode == "argmax":
            coreset_answers = self.coreset(model_answers, model_reward_scores)
            return self._pass_at_k_once(coreset_answers, gt_answers)

        if self.mode == "sample":
            single_trial_pass_at_k = []
            for trial in range(self.num_trials):
                log.info("Trial %d of %d ...", trial + 1, self.num_trials)
                coreset_answers = self.coreset(model_answers, model_reward_scores)
                single_trial_pass_at_k.append(self._pass_at_k_once(coreset_answers, gt_answers))
            return {k: np.mean([d[k] for d in single_trial_pass_at_k]) for k in single_trial_pass_at_k[0]}

        raise ValueError(f"Unknown mode {self.mode!r}; expected one of {self._MODES}")
