# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0


"""
datamodules 注册与按需导入

原始实现会在 import byprot.datamodules 时递归导入整个目录，
这会强制依赖 openfold 等与当前 MS De Novo 任务无关的包，
导致简单训练脚本 (train_ms_denovo.py) 出现 ModuleNotFoundError。

当前任务只需要我们手写的 `novobench_ms` datamodule，
因此这里不再自动 import 目录下所有模块，而是：
- 提供一个 DATAMODULE_REGISTRY 和 register_datamodule 装饰器，供其他代码使用；
- 需要哪个 datamodule，就显式 `from byprot.datamodules.novobench_ms import ...`。
"""

from typing import Dict, Type

DATAMODULE_REGISTRY: Dict[str, Type] = {}


def register_datamodule(name: str):
    def decorator(cls):
        DATAMODULE_REGISTRY[name] = cls
        return cls

    return decorator

