
from inference_rlhf.code.query_builders.base import BaseQueryBuilder
from inference_rlhf.code.collators.base import BaseCollator
from inference_rlhf.code.rewards.base import BaseRewardModel
from inference_rlhf.code.policy.base import BasePolicy
from inference_rlhf.code.tasks.base import BaseDataLoader
import logging

# Module-level logger named after this module; dataloader_factory reports
# dataset-loading progress through it.
log = logging.getLogger(__name__)

def construct_collator(collator_name: str, *args, **kwargs) -> BaseCollator:
    """Instantiate the collator whose family prefix begins ``collator_name``.

    Prefixes are tested in this order and the first match wins:
    ``"armo-rm"``, ``"llama"``, ``"mistral"``, ``"phi"``, ``"qwen"``.
    Matching uses ``str.startswith`` and is therefore case-sensitive.
    The implementation module is imported inside the matching branch, so this
    function only imports the module for the family that was selected.

    Args:
        collator_name: Name whose prefix selects the collator implementation.
        *args: Positional arguments forwarded unchanged to the collator class.
        **kwargs: Keyword arguments forwarded unchanged to the collator class.

    Returns:
        A freshly constructed instance of the selected collator class.

    Raises:
        ValueError: If ``collator_name`` starts with none of the known prefixes.
    """
    if collator_name.startswith("armo-rm"):
        from inference_rlhf.code.collators.armo_rm import ArmoRMCollator as collator_cls
    elif collator_name.startswith("llama"):
        from inference_rlhf.code.collators.llama import LlamaCollator as collator_cls
    elif collator_name.startswith("mistral"):
        from inference_rlhf.code.collators.mistral import MistralCollator as collator_cls
    elif collator_name.startswith("phi"):
        from inference_rlhf.code.collators.phi import PhiCollator as collator_cls
    elif collator_name.startswith("qwen"):
        from inference_rlhf.code.collators.qwen import QwenCollator as collator_cls
    else:
        raise ValueError(f"Collator {collator_name} not found")

    return collator_cls(*args, **kwargs)

def construct_reward_model(reward_model_name: str, *args, **kwargs) -> BaseRewardModel:
    """Instantiate the reward model selected by ``reward_model_name``.

    Only the ``"armo-rm"`` prefix is currently recognised. The check is a
    case-sensitive ``str.startswith``. The implementation module is imported
    only after the name has been accepted.

    Args:
        reward_model_name: Name whose prefix selects the reward model.
        *args: Positional arguments forwarded unchanged to the model class.
        **kwargs: Keyword arguments forwarded unchanged to the model class.

    Returns:
        A freshly constructed instance of the selected reward model class.

    Raises:
        ValueError: If ``reward_model_name`` does not start with ``"armo-rm"``.
    """
    # Guard clause: reject unknown names before touching any backend module.
    if not reward_model_name.startswith("armo-rm"):
        raise ValueError(f"Reward model {reward_model_name} not found")

    from inference_rlhf.code.rewards.armo_rm import ArmoRMRewardModel
    return ArmoRMRewardModel(*args, **kwargs)

def construct_policy_model(policy_name: str, *args, **kwargs) -> BasePolicy:
    """Instantiate the policy model whose family prefix begins ``policy_name``.

    Prefixes are tested in this order and the first match wins:
    ``"llama"``, ``"mistral"``, ``"phi"``, ``"qwen"``. Matching is a
    case-sensitive ``str.startswith``. Only the module for the matching
    family is imported by this function.

    Args:
        policy_name: Name whose prefix selects the policy implementation.
        *args: Positional arguments forwarded unchanged to the policy class.
        **kwargs: Keyword arguments forwarded unchanged to the policy class.

    Returns:
        A freshly constructed instance of the selected policy class.

    Raises:
        ValueError: If ``policy_name`` starts with none of the known prefixes.
    """
    if policy_name.startswith("llama"):
        from inference_rlhf.code.policy.llama import LlamaPolicy
        return LlamaPolicy(*args, **kwargs)
    if policy_name.startswith("mistral"):
        from inference_rlhf.code.policy.mistral import MistralPolicy
        return MistralPolicy(*args, **kwargs)
    if policy_name.startswith("phi"):
        from inference_rlhf.code.policy.phi import PhiPolicy
        return PhiPolicy(*args, **kwargs)
    if policy_name.startswith("qwen"):
        from inference_rlhf.code.policy.qwen import QwenPolicy
        return QwenPolicy(*args, **kwargs)

    # No family prefix matched.
    raise ValueError(f"Policy model {policy_name} not found")

def dataloader_factory(task_name: str, *args, **kwargs) -> BaseDataLoader:
    """Construct the DataLoader registered for the task named ``task_name``.

    The loader is chosen by prefix, using a case-sensitive ``str.startswith``.
    Prefixes are tested in this order and the first match wins: ``"math"``,
    ``"gsm8k"``, ``"aime_2025"``, ``"game24"``, ``"mbpp"``. Only the matching
    task module is imported. An INFO message is logged immediately before and
    immediately after the loader is constructed.

    Args:
        task_name (str): Name of the task; its prefix selects the loader class.
        *args: Positional arguments forwarded unchanged to the loader class.
        **kwargs: Keyword arguments forwarded unchanged to the loader class.

    Returns:
        BaseDataLoader: The constructed DataLoader.

    Raises:
        ValueError: If ``task_name`` matches none of the known prefixes.
    """
    if task_name.startswith("math"):
        from inference_rlhf.code.tasks.math import MATHDataLoader as loader_cls
    elif task_name.startswith("gsm8k"):
        from inference_rlhf.code.tasks.gsm8k import GSM8KDataLoader as loader_cls
    elif task_name.startswith("aime_2025"):
        from inference_rlhf.code.tasks.aime_2025 import AIME2025DataLoader as loader_cls
    elif task_name.startswith("game24"):
        from inference_rlhf.code.tasks.game24 import Game24DataLoader as loader_cls
    elif task_name.startswith("mbpp"):
        from inference_rlhf.code.tasks.mbpp import MBPPDataLoader as loader_cls
    else:
        raise ValueError(f"DataLoader for {task_name} not found")

    # Bracket construction with log lines so dataset loading is visible in logs.
    log.info(f"Loading {task_name} dataset ...")
    loader = loader_cls(*args, **kwargs)
    log.info(f"Done loading {task_name} dataset.")
    return loader