import importlib

import numpy as np 
from typing import Any, Dict, List, Optional, Union

from inference_rlhf.code.query_builders.llama import LlamaQueryBuilder
from inference_rlhf.code.collators.policy import PolicyCollator


class LlamaCollator(PolicyCollator):
    """Policy collator whose prompts are built by a ``LlamaQueryBuilder``.

    The question/answer templates and separator handed to the query builder
    are read from the task module ``inference_rlhf.code.tasks.<cfg.task.name>``,
    which is imported dynamically at construction time.
    """

    def __init__(self, cfg):
        """Initialise the base collator and build the Llama query builder.

        Args:
            cfg: Run configuration. The whole object is passed to
                ``PolicyCollator.__init__``; this method additionally reads
                ``cfg.task.name``, ``cfg.task.TASK_DESC``, ``cfg.policy`` and
                ``cfg.shots``.

        Raises:
            ModuleNotFoundError: If no module ``inference_rlhf.code.tasks.<name>``
                exists for ``cfg.task.name``.
            AttributeError: If the task module does not define
                ``QUESTION_FORMAT``, ``ANSWER_FORMAT`` or ``SEP``.
        """
        super().__init__(cfg)
        # The module path is absolute, so no ``package`` anchor is passed:
        # importlib only consults ``package`` for relative (dot-prefixed) names.
        task_module = importlib.import_module(
            f"inference_rlhf.code.tasks.{cfg.task.name}"
        )

        # Formatting constants come from the task module; the task
        # description, policy config and shots come from ``cfg``.
        self.query_builder = LlamaQueryBuilder(
            cfg=cfg.policy,
            task_desc=cfg.task.TASK_DESC,
            shots=cfg.shots,
            question_format=task_module.QUESTION_FORMAT,
            answer_format=task_module.ANSWER_FORMAT,
            sep=task_module.SEP,
        )