import importlib
from .base_task import BaseTask

# Registry of task names, grouped by the task class that handles them.
# Each key is the name of a class that `get_task` looks up in a sibling module
# whose filename is the lowercased key (e.g. "VideoVQA" -> ".videovqa").
# Key order matters: lookups scan categories in insertion order and the first
# category listing a task name wins.
TASK_CATEGORIES = {
    "PlantVillage": [
        "Apple",
        "Corn",
        "Grape",
        "Potato",
    ],
    "CUB": [
        # Noted in the original source as "3 class" subsets.
        "hummingbird",
        "albatross",
        "bunting",
        "jay",
        "cuckoo",
        "cormorant",
        # Noted in the original source as "4 class" subsets.
        "swallow",
        "blackbird",
        "auklet",
        "grosbeak",
        "oriole",
        "grebe",
    ],
    "VQA": [
        "rsvqa",
        "drivingvqa",
        # SLAKE subsets
        "MRI",
        "CT",
        "X-Ray",
    ],
    "Video": [
        "driveact",
    ],
    "VideoVQA": [
        "vanebench",
    ],
    "MoleculeClassification": [
        "pampa",
        "hia",
        "pgp",
        "bioavailability",
        "bbb",
        "cyp2c19",
        "cyp3a4",
        "cyp2d6",
        "cyp1a2",
        "cyp2c9",
        "dili",
        "herg",
        "carcinogen",
        "ames",
        "sarscov2vitro",
        "sarscov23clpro",
        "cyp2c9substrate",
        "cyp2d6substrate",
        "cyp3a4substrate",
    ],
}


def get_task(task_name):
    """Return the task class registered for *task_name*.

    Categories in ``TASK_CATEGORIES`` are scanned in insertion order. The first
    category whose list contains *task_name* determines the result. The module
    ``.<category lowercased>`` is imported relative to this package, and the
    attribute named exactly like the category is returned from it. The module
    is imported lazily, so only the requested task's module is loaded.

    Args:
        task_name: A task name as listed in ``TASK_CATEGORIES`` (case-sensitive).

    Returns:
        The attribute ``<category>`` of the imported module (presumably a
        ``BaseTask`` subclass — not checked here).

    Raises:
        ValueError: If the category's module cannot be imported
            (``ModuleNotFoundError``), or if *task_name* is in no category.
            The original import error is chained as ``__cause__``. This
            distinguishes a missing task module from a missing dependency of
            that module.
        AttributeError: If the imported module does not define the class.
    """
    for class_name, tasks in TASK_CATEGORIES.items():
        if task_name not in tasks:
            continue
        # Keep the try narrow: only the import itself is translated into a
        # ValueError; any error from the attribute lookup propagates unchanged.
        try:
            module = importlib.import_module(
                f".{class_name.lower()}", package=__package__
            )
        except ModuleNotFoundError as err:
            # Chain the cause so the actually-missing module name (err.name)
            # stays visible in the traceback.
            raise ValueError(
                f"Module for task '{task_name}' could not be found."
            ) from err
        return getattr(module, class_name)
    raise ValueError(f"{task_name} is not a recognized task")
