from typing import Dict, Type
from sal.models.prm import (
    PRM,
    MathShepherd,
    RLHFFlow,
    BiasDetectionPRM,
    LoraBiasDetectionPRM,
    OutcomeDetectionPRM,
    UntrainedBiasPRM,
)

# Lookup table from model identifier to the PRM class that wraps it.
# Several identifiers can share one class. Insertion order below is preserved
# in the resulting dict. Entries can be added at runtime via ``register_prm``.
# NOTE(review): keys look like placeholder identifiers ("PRM_PATH_<n>") rather
# than real model paths; confirm how callers produce these keys.
# NOTE(review): MathShepherd, RLHFFlow and LoraBiasDetectionPRM are imported
# by this module but not registered here — confirm whether that is intended.
PRM_REGISTRY: Dict[str, Type[PRM]] = dict(
    PRM_PATH_3=BiasDetectionPRM,
    PRM_PATH_4=BiasDetectionPRM,
    PRM_PATH_5=BiasDetectionPRM,
    PRM_PATH_6=OutcomeDetectionPRM,
    PRM_PATH_7=OutcomeDetectionPRM,
    PRM_PATH_8=BiasDetectionPRM,
    PRM_PATH_9=BiasDetectionPRM,
    PRM_PATH_10=BiasDetectionPRM,
    PRM_PATH_12=UntrainedBiasPRM,
)

def register_prm(model_id: str, prm_class: Type[PRM]):
    """Add (or replace) the registry entry for ``model_id``.

    Args:
        model_id: Key under which ``prm_class`` is stored in ``PRM_REGISTRY``.
        prm_class: PRM class to associate with ``model_id``.

    An existing entry for the same ``model_id`` is overwritten without warning.
    """
    PRM_REGISTRY.update({model_id: prm_class})

def get_prm_class(model_id: str) -> Type[PRM]:
    """Return the PRM class registered for ``model_id``.

    Args:
        model_id: Key to look up in ``PRM_REGISTRY``.

    Returns:
        The class stored in ``PRM_REGISTRY`` under ``model_id``.

    Raises:
        ValueError: If ``model_id`` has not been registered. The message
            starts with "No PRM class found for model <model_id>" and then
            lists the currently registered IDs to aid debugging.
    """
    # EAFP: a single dict lookup instead of a membership test plus a lookup.
    try:
        return PRM_REGISTRY[model_id]
    except KeyError:
        # Keep insertion order and repr() each key so the listing works even
        # if a non-str key was registered at runtime.
        registered = ", ".join(repr(key) for key in PRM_REGISTRY) or "<none>"
        # `from None` hides the internal KeyError so callers see only the
        # ValueError, as with the previous membership-check implementation.
        raise ValueError(
            f"No PRM class found for model {model_id}. "
            f"Registered model IDs: {registered}"
        ) from None