from typing import Callable, List, Optional

from transformers import PreTrainedTokenizer

from models.prm import load_prm
from models.generative_verifier import GenerativeVerifier
from models.pref_verifier import PreferenceVerifier
from models.rm_server import RMServerVerifier
from models.generative_rm_server import GenerativeRMServerVerifier


def get_prm_verifier(args):
    """Build a scoring closure backed by a process reward model (PRM).

    The PRM is loaded once via ``load_prm(args.prm_path,
    prm_batch_size=args.prm_batch_size)`` and captured by the returned closure.

    If ``args.model_path`` contains the substring ``"r1"`` (case-insensitive),
    the policy is treated as a "thinking" model and every output is reduced
    before scoring:

    * if it contains ``"</think>"``, only the text after the last such tag is
      kept;
    * otherwise only its last 10000 characters are kept.

    Non-thinking outputs are passed to the PRM unchanged.

    Args:
        args: Config object; reads ``model_path``, ``prm_path`` and
            ``prm_batch_size``.

    Returns:
        A function ``scorer(question, outputs)`` that scores the strings in
        ``outputs`` against ``question`` as a single-question batch and returns
        element ``[0]`` of ``prm.score``'s result (the scores for that one
        question).
    """
    # NOTE(review): loose substring heuristic -- any path containing "r1"
    # (e.g. ".../user1/...") is treated as a thinking model. Confirm intent.
    is_thinking_model = "r1" in args.model_path.lower()
    prm = load_prm(args.prm_path, prm_batch_size=args.prm_batch_size)

    think_end_tag = "</think>"
    # Fallback for outputs whose reasoning never closed; presumably bounds the
    # input to the PRM's context budget -- TODO confirm the chosen size.
    max_unclosed_chars = 10000

    def _strip_reasoning(output):
        """Return the portion of a thinking-model output that the PRM should score."""
        if think_end_tag in output:
            # Keep only what follows the last closing tag (same as split(...)[-1]).
            return output.rpartition(think_end_tag)[2]
        return output[-max_unclosed_chars:]

    def scorer(question, outputs):
        if is_thinking_model:
            processed_outputs = [_strip_reasoning(output) for output in outputs]
        else:
            processed_outputs = outputs
        # prm.score is batched over questions: one question with its outputs.
        return prm.score([question], [processed_outputs])[0]

    return scorer


def load_verifier(
    args, tokenizer: PreTrainedTokenizer
) -> Optional[Callable[[str, List[str]], List[List[float]]]]:
    """Construct the verifier selected by ``args.type``.

    Supported values of ``args.type``:

    * ``"none"``                 -> ``None`` (no verification)
    * ``"model"``                -> PRM scoring closure from ``get_prm_verifier``
    * ``"generative_verifier"``  -> ``GenerativeVerifier(args)``
    * ``"pref_verifier"``        -> ``PreferenceVerifier(args, tokenizer)``
    * ``"rm_server"``            -> ``RMServerVerifier(args, tokenizer)``
    * ``"generative_rm_server"`` -> ``GenerativeRMServerVerifier(args, tokenizer)``

    Args:
        args: Config object; ``args.type`` selects the verifier and the whole
            object is forwarded to the chosen constructor.
        tokenizer: Forwarded to the verifiers that take one (see above).

    Returns:
        ``None`` for ``"none"``. Otherwise returns the verifier object. It is
        presumably callable as ``verifier(question, outputs)`` per the
        annotation; confirm this for the class-based verifiers.

    Raises:
        ValueError: If ``args.type`` is not one of the supported values.
    """
    verifier_type = args.type
    if verifier_type == "none":
        return None
    if verifier_type == "model":
        return get_prm_verifier(args)
    if verifier_type == "generative_verifier":
        return GenerativeVerifier(args)
    if verifier_type == "pref_verifier":
        return PreferenceVerifier(args, tokenizer)
    if verifier_type == "rm_server":
        return RMServerVerifier(args, tokenizer)
    if verifier_type == "generative_rm_server":
        return GenerativeRMServerVerifier(args, tokenizer)
    # Report the attribute actually dispatched on (args.type). Reading
    # args.verifier.type here would raise AttributeError or report a wrong value.
    raise ValueError(f"Invalid verifier type: {verifier_type}")
