import random
from collections import OrderedDict, defaultdict
from typing import Dict, List
import logging

import numpy as np
from tqdm import tqdm

from inference_rlhf.code.coreset.coreset import Coreset
from inference_rlhf.code.tasks.math import judge_correct
from inference_rlhf.code.helpers.utils import timing

log = logging.getLogger(__name__)

class UniqueRandomCoreset(Coreset):
    """Coreset that keeps every distinct answer of an example, in random order.

    The coreset of an example is its answer list shuffled with the global
    ``random`` module and then de-duplicated (first occurrence wins). Because
    duplicates are removed *after* shuffling, answers that occur more often are
    more likely to land near the front of the coreset.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Number of independent random shuffles that ``pass_at_k`` averages over.
        self.num_avg_trials = 20 # TODO: read from cfg

    @timing
    def coreset(self, model_answers: Dict[int, List[int]]) -> Dict[int, List[int]]:
        """
        Compute the coreset for the given answers. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        Returns a dictionary with the same keys, where each answer list has been
        shuffled and then de-duplicated (keeping the first occurrence of each
        answer). Answers must be hashable. The input lists are not modified.

        NOTE(review): de-duplicating after the shuffle makes frequent answers more
        likely to appear early (the order is frequency-weighted, not uniform over
        distinct answers). If a uniform order over distinct answers is intended,
        de-duplicate before shuffling.
        """
        coreset = {}
        for example_id, answers in model_answers.items():
            # Copy so the caller's list is never shuffled in place; list() also
            # accepts tuples and other sequences, not only lists.
            shuffled = list(answers)
            random.shuffle(shuffled)
            # NOTE: fromkeys makes sure the order of appearance is preserved
            coreset[example_id] = list(OrderedDict.fromkeys(shuffled))
        return coreset

    @timing
    def pass_at_k(self, model_answers: Dict[int, List[int]], gt_answers: List[int]) -> Dict[int, float]:
        """
        Compute the pass@k for the given answers using a random coreset. Expects answers in the following format:
        {
            example_id: [answer1, answer2, answer3],
            example_id: [answer1, answer2, answer3],
            ...
        }
        ``gt_answers[example_id]`` must give the reference answer for each example.

        For each of ``self.num_avg_trials`` trials a fresh coreset is drawn via
        ``self.coreset``; an example passes at ``k`` when any of its first ``k``
        coreset answers is judged correct by ``judge_correct``. The per-trial
        fraction of passing examples is averaged over all trials.

        Returns a dict mapping each k in ``self.ks`` to its mean pass@k. Returns
        an empty dict when there are no examples, no ks, or no trials.
        """
        ks = list(self.ks)
        if not ks or self.num_avg_trials <= 0:
            return {}
        max_k = max(ks)

        # Memoized verdicts: example_id -> {answer: judge_correct(gt, answer)}.
        # Coreset answers are hashable (they were dict keys) and the same answer
        # recurs across ks and trials, so each distinct answer is judged at most
        # once per example instead of once per (trial, k).
        verdicts: Dict[int, dict] = defaultdict(dict)

        log.info(f"Computing pass@k for ks={ks} over {self.num_avg_trials} trials ...")
        single_trial_pass_at_k = []
        for _ in tqdm(range(self.num_avg_trials), desc="Computing pass@k ..."):
            coreset_answers = self.coreset(model_answers)

            # Per example: rank of the first correct answer among the first
            # max_k coreset answers, or None if none of them is correct.
            first_hits = []
            for example_id, answers in coreset_answers.items():
                gt = gt_answers[example_id]
                seen = verdicts[example_id]
                first_hit = None
                for rank, answer in enumerate(answers[:max_k]):
                    if answer not in seen:
                        seen[answer] = judge_correct(gt, answer)
                    if seen[answer]:
                        first_hit = rank
                        break
                first_hits.append(first_hit)

            if not first_hits:
                # No examples: there are no pass@k values to report.
                single_trial_pass_at_k.append({})
                continue

            # pass@k holds iff the first correct answer sits within the first k.
            single_trial_pass_at_k.append({
                k: np.mean([float(hit is not None and hit < k) for hit in first_hits])
                for k in ks
            })

        return {k: np.mean([d[k] for d in single_trial_pass_at_k]) for k in single_trial_pass_at_k[0].keys()}
