
import torch
from transformers import AutoModelForSequenceClassification

from inference_rlhf.code.models.armo_rm_dt import LlamaForDecisionTreeRewardModel
from inference_rlhf.code.models.armo_rm import LlamaForRewardModelWithGating

class BaseRewardModel:
    """Wrap a pretrained reward model and expose a simple scoring call.

    The underlying network is stored on ``self.network`` and is chosen from
    ``cfg.reward.model``: two specific checkpoints are loaded with their
    dedicated model classes, anything else goes through
    ``AutoModelForSequenceClassification``.
    """

    def __init__(self, cfg, **kwargs):
        """Load the reward network named by ``cfg.reward.model``.

        Args:
            cfg: Configuration object; only ``cfg.reward.model`` is read here.
            **kwargs: Forwarded unchanged to ``from_pretrained``.
        """
        model_name = cfg.reward.model
        if model_name == "RLHFlow/Decision-Tree-Reward-Llama-3.1-8B":
            model_cls = LlamaForDecisionTreeRewardModel
        elif model_name == "RLHFlow/ArmoRM-Llama3-8B-v0.1":
            model_cls = LlamaForRewardModelWithGating
        else:
            model_cls = AutoModelForSequenceClassification

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to run locally; only point cfg.reward.model at trusted
        # repositories.
        self.network = model_cls.from_pretrained(
            model_name, trust_remote_code=True, **kwargs
        )

    @torch.no_grad()
    def get_reward(self, batch, return_hidden_states: bool = False):
        """Score a batch with the reward network (gradients disabled).

        Args:
            batch: Mapping of keyword arguments unpacked into the network's
                call (``self.network(**batch)``).
            return_hidden_states: When true, also return the hidden
                representation found on the model output.

        Returns:
            ``scores`` -- ``output.logits`` with a trailing size-1 dimension
            squeezed away, moved to CPU and converted with ``tolist()``.
            When ``return_hidden_states`` is true, the tuple
            ``(scores, hidden)`` instead.

        Raises:
            AttributeError: If ``return_hidden_states`` is true and the model
                output has neither ``hidden_state`` nor ``hidden_states``.
        """
        output = self.network(**batch)
        scores = output.logits.squeeze(-1).cpu().tolist()
        if not return_hidden_states:
            return scores

        # Output classes disagree on the field name: the original code read
        # ``hidden_state`` (singular), while standard transformers outputs
        # use ``hidden_states`` (plural). Prefer the singular name for
        # backward compatibility, then fall back to the plural one.
        found = [
            getattr(output, name)
            for name in ("hidden_state", "hidden_states")
            if hasattr(output, name)
        ]
        if not found:
            raise AttributeError(
                f"{type(output).__name__} has neither 'hidden_state' nor "
                "'hidden_states'"
            )
        # May still be None, e.g. if the model was not asked to emit hidden
        # states; that is passed through rather than treated as an error.
        hidden = next((value for value in found if value is not None), None)
        return scores, hidden
