
import torch 

from inference_rlhf.code.rewards.base import BaseRewardModel

class ArmoRMRewardModel(BaseRewardModel):
    """ArmoRM reward model.

    A thin subclass of ``BaseRewardModel``. Its only customization is the
    dtype it hands to the base constructor: ``torch.bfloat16``. All other
    behavior is inherited unchanged.
    """

    def __init__(self, cfg):
        """Build the reward model from ``cfg`` with a bfloat16 ``torch_dtype``.

        Args:
            cfg: Configuration object, forwarded as-is to
                ``BaseRewardModel.__init__``. Its expected schema is defined
                by the base class and is not shown here.
        """
        # The dtype is fixed for this subclass rather than read from cfg.
        # NOTE(review): how BaseRewardModel uses torch_dtype (e.g. for weight
        # loading) is not visible here; confirm against the base class.
        model_dtype = torch.bfloat16
        super().__init__(
            cfg,
            torch_dtype=model_dtype,
        )