import torch
from omegaconf import DictConfig

from inference_rlhf.code.policy.base import BasePolicy

class LlamaPolicy(BasePolicy):
    """``BasePolicy`` subclass that always requests ``torch.bfloat16``.

    The only customization over ``BasePolicy`` is the fixed ``torch_dtype``
    keyword passed to the base initializer; everything else is inherited.
    """

    def __init__(self, cfg: DictConfig):
        """Initialize the base policy with ``torch_dtype=torch.bfloat16``.

        Args:
            cfg: Configuration forwarded unchanged to ``BasePolicy.__init__``.
        """
        # NOTE(review): how BasePolicy uses ``torch_dtype`` (e.g. as the
        # precision for loaded model weights) is not visible here; confirm in
        # inference_rlhf.code.policy.base.
        precision = torch.bfloat16
        super().__init__(cfg, torch_dtype=precision)