import numpy as np
import sglang as sgl
from typing import Callable, List
from transformers import PreTrainedTokenizer
from sglang.lang.interpreter import ProgramState

from search.decoding import simple_decoding
from search.utils import aggregate_score


@sgl.function
def best_of_n(s: ProgramState, prompt, sample, verifier, tokenizer: PreTrainedTokenizer, args):
    """Sample ``args.n_samples`` completions for *prompt* and keep the best one.

    All completions are generated in one ``simple_decoding.run_batch`` call
    using ``args.temperature`` and ``args.max_tokens``. When a *verifier* is
    given, it scores every completion; the scores are reduced with
    ``aggregate_score(scores, args.agg_strategy)``, and the completion with the
    highest aggregated score wins (ties go to the earliest, per ``np.argmax``).
    Without a verifier every completion gets a constant score of ``1.0`` and
    the first completion is selected.

    Args:
        s: sglang program state; results are written into it (see below).
        prompt: Prompt forwarded to ``simple_decoding`` for every sample.
        sample: Passed unchanged as the first argument to *verifier*.
        verifier: Callable invoked as ``verifier(sample, answers)``, or
            ``None`` to skip verification. Presumably returns one score
            sequence per answer (each entry is averaged with ``np.mean``
            below) -- confirm against the verifier implementations.
        tokenizer: Not used by this function; kept for interface compatibility.
        args: Namespace providing ``temperature``, ``max_tokens``,
            ``n_samples`` and, when a verifier is given, ``agg_strategy``.

    Side effects (keys set on *s*):
        ``prediction``: the selected completion.
        ``outputs``: all completions, in generation order.
        ``scores``: per-completion scores.
        ``agg_scores``: one aggregated scalar per completion.
        ``metric``: ``{"rewards_mean": ...}``, the mean over completions of
        each completion's mean score.

    Raises:
        ValueError: If ``args.n_samples`` is less than 1.
    """
    # Fail fast with a clear message instead of an IndexError / empty-argmax
    # ValueError further down (and before spending a batch call on nothing).
    if args.n_samples < 1:
        raise ValueError(f"best_of_n requires args.n_samples >= 1, got {args.n_samples!r}")

    sampling_params = dict(temperature=args.temperature, max_tokens=args.max_tokens)

    outputs = simple_decoding.run_batch([dict(prompt=prompt, **sampling_params) for _ in range(args.n_samples)])
    answers = [output["result"] for output in outputs]

    if verifier is None:
        # One independent list per answer: ``[[1.0]] * n`` would alias a single
        # inner list across every slot of the stored ``scores``.
        scores = [[1.0] for _ in answers]
        agg_scores = [score[0] for score in scores]
        # All scores are equal, so this matches argmax's first-index choice.
        best_output = answers[0]
    else:
        scores = verifier(sample, answers)
        agg_scores = aggregate_score(scores, args.agg_strategy)
        best_output = answers[np.argmax(agg_scores)]

    s["prediction"] = best_output
    s["outputs"] = answers
    s["scores"] = scores
    s["agg_scores"] = agg_scores
    # Mean of per-completion mean scores (each completion may carry several).
    s["metric"] = {"rewards_mean": np.mean([np.mean(sc) for sc in scores])}
