from .base import BaseTask
from .oup import OUPTask
from .gaussian import GaussianTask
from .sir import SIRTask
from .cyro_em import CryoEMTask

# Registry of available tasks: lower-case task name -> task class.
# Keys must be lower-case because ``get_task`` lower-cases ``cfg.model_type``
# before looking it up here; add new tasks by adding an entry below.
# NOTE(review): CryoEMTask is imported from ``.cyro_em`` (spelled differently
# from the "cryo_em" key) — confirm the module name is intentional.
_TASK_REGISTRY = dict(
    oup=OUPTask,
    gaussian=GaussianTask,
    sir=SIRTask,
    cryo_em=CryoEMTask,
)

def get_task(cfg) -> BaseTask:
    """Instantiate the task class registered for ``cfg.model_type``.

    Matching is case-insensitive: ``cfg.model_type`` is lower-cased before it
    is looked up in ``_TASK_REGISTRY``.

    Args:
        cfg: Configuration object. It must expose a string attribute
            ``model_type``; the whole object is passed to the selected task
            class's constructor.

    Returns:
        ``TaskClass(cfg)`` for the matching registry entry. (Annotated as
        ``BaseTask`` — presumably every registered class subclasses it.)

    Raises:
        ValueError: If the lower-cased ``model_type`` is not a registry key;
            the message lists the available keys.
    """
    task_type = cfg.model_type.lower()
    # Single lookup instead of a membership test followed by indexing.
    task_cls = _TASK_REGISTRY.get(task_type)
    if task_cls is None:
        raise ValueError(
            f"Unknown task type: {task_type}. "
            f"Available: {list(_TASK_REGISTRY.keys())}"
        )
    return task_cls(cfg)