from .skincon_datamodules import SkinConDataModule, CAVDataModule, RectifiedDataModule
from .CUB_datamodules import CUBDatamodules
from .animal_datamodules import AwA2DataModule
from .animal_datamodules_selected import AwA2DataModule_selected
def construct_data_module(config):
    """Build, prepare and set up the data module selected by ``config["dataset"]``.

    Args:
        config: Mapping of experiment settings. Every supported dataset reads
            ``"data_dir"``, ``"batch_size"``, ``"train_with_c_gt"`` and
            ``"concept_weight"``. ``"CUB"`` additionally reads ``"seed"`` and
            ``"arch"``, and ``"AwA2"`` additionally reads ``"arch"``. A missing
            key propagates the mapping's lookup error (``KeyError`` for a dict).

    Returns:
        The constructed data module, after ``prepare_data()`` and
        ``setup(stage=None)`` have been called on it.

    Raises:
        NotImplementedError: If ``config["dataset"]`` is not one of
            ``"skincon"``, ``"CUB"`` or ``"AwA2"``. The message names the
            rejected value.
    """
    dataset = config["dataset"]

    if dataset == "skincon":
        data_module = SkinConDataModule(
            data_dir=config["data_dir"],
            batch_size=config["batch_size"],
            train_with_c_gt=config["train_with_c_gt"],
            concept_weight=config["concept_weight"],
        )
    elif dataset == "CUB":
        data_module = CUBDatamodules(
            seed=config["seed"],
            data_dir=config["data_dir"],
            batch_size=config["batch_size"],
            train_with_c_gt=config["train_with_c_gt"],
            concept_weight=config["concept_weight"],
            arch=config["arch"],
        )
    elif dataset == "AwA2":
        data_module = AwA2DataModule(
            data_dir=config["data_dir"],
            batch_size=config["batch_size"],
            train_with_c_gt=config["train_with_c_gt"],
            concept_weight=config["concept_weight"],
            arch=config["arch"],
        )
    else:
        # NOTE(review): AwA2DataModule_selected is imported at module level but
        # has no branch here. Confirm whether a dataset entry for it is intended.
        raise NotImplementedError(
            f"Unsupported dataset {dataset!r}; "
            "expected one of 'skincon', 'CUB', 'AwA2'"
        )

    # Presumably these follow the PyTorch Lightning DataModule hooks, where
    # stage=None conventionally sets up all stages. TODO confirm for these modules.
    data_module.prepare_data()
    data_module.setup(stage=None)

    return data_module
