import argparse

# 导入原始的Transolver模型和我们新设计的自适应版本
from model.Transolver_Structured_Mesh_2D import Model as Transolver2D_Base
# (您可以根据需要，导入Irregular和3D版本)
from model.adaptive_transolver_slice_temporal import IrregularAdaptiveTransolver
# b) (核心修改) 导入我们新设计的自适应模型
from model.adaptive_transolver import StructuredAdaptiveTransolver
from model.adaptive_transolver_pipe import PipeAdaptiveTransolver
from model.adaptive_transolver_pipe_final import TwoStreamCNNRouterPipeTransolver
from model.Transolver_adapt import SelfRoutingAdaptiveTransolver
def get_model(args):
    """
    一个统一的模型工厂，能够创建Transolver基准和我们的自适应版本。
    """
    model_dict = {
        # 注册原始模型
        'Transolver_Structured_Mesh_2D': Transolver2D_Base,
        'PipeAdaptiveTransolver': PipeAdaptiveTransolver,
        # --- 核心修改：注册我们最终的、最先进的自适应模型 ---
        'StructuredAdaptiveTransolver': StructuredAdaptiveTransolver,
        'IrregularAdaptiveTransolver':IrregularAdaptiveTransolver,
        'TwoStreamCNNRouterPipeTransolver': TwoStreamCNNRouterPipeTransolver,
        'SelfRoutingAdaptiveTransolver': SelfRoutingAdaptiveTransolver,
    }
    
    if args.model not in model_dict:
        raise NotImplementedError(f"模型 '{args.model}' 未在 model_dict_adaptive.py 中定义。")

    # .Model 是因为原始 Transolver 的训练脚本是这样包装的
    # 为了保持兼容，我们也这样包装
    class ModelWrapper:
        def __init__(self, ModelClass):
            self.Model = ModelClass
    
    return ModelWrapper(model_dict[args.model])