import importlib

from inference_rlhf.code.query_builders.mistral import MistralQueryBuilder
from inference_rlhf.code.collators.policy import PolicyCollator


class MistralCollator(PolicyCollator):
    """``PolicyCollator`` subclass that sets up a ``MistralQueryBuilder``.

    Only ``__init__`` is defined here; everything else is inherited from
    ``PolicyCollator``. The query builder is configured from the run config
    plus the prompt-format constants of the task module selected by
    ``cfg.task.name``.
    """

    def __init__(self, cfg):
        """Initialise the base collator and build ``self.query_builder``.

        Args:
            cfg: Run configuration. It is passed unchanged to
                ``PolicyCollator.__init__``. This method also reads:
                ``cfg.task.name`` (selects the module
                ``inference_rlhf.code.tasks.<name>``), ``cfg.task.TASK_DESC``,
                ``cfg.shots`` and ``cfg.policy``.

        Raises:
            ModuleNotFoundError: if no task module exists for
                ``cfg.task.name``.
            AttributeError: if the task module does not define
                ``QUESTION_FORMAT``, ``ANSWER_FORMAT`` or ``SEP``.
        """
        super().__init__(cfg)

        # The module path is absolute, so no ``package`` anchor is passed:
        # importlib only consults it for relative names starting with ".".
        task_module = importlib.import_module(
            f"inference_rlhf.code.tasks.{cfg.task.name}"
        )

        # NOTE(review): the task description is taken from the config while
        # the prompt formats come from the task module; confirm this split
        # is intentional.
        self.query_builder = MistralQueryBuilder(
            cfg=cfg.policy,
            task_desc=cfg.task.TASK_DESC,
            shots=cfg.shots,
            question_format=task_module.QUESTION_FORMAT,
            answer_format=task_module.ANSWER_FORMAT,
            sep=task_module.SEP,
        )