# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0


"""
模型注册与按需导入
原始实现会在 import byprot.models 时递归导入整个 models 目录，
从而强制依赖 openfold 等本任务不需要的包，导致推理脚本报错。

这里保留一个简单的 MODEL_REGISTRY 和 register_model 装饰器，
但不在模块加载时自动 import 全部子模块。
需要某个模型时，直接 from byprot.models.dplm.xxx import ... 即可。
"""

from typing import Dict, Type

MODEL_REGISTRY: Dict[str, Type] = {}


def register_model(name: str):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls

    return decorator

