"""
示例：使用 MLP World Model 启动推理服务器

这个脚本展示了如何加载训练好的 MLP World Model 并启动推理服务器。
"""

from config import get_libero_config
from mlp_workspace import DynamicsWorkspace
import argparse
from config import Config
import os, json

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='启动 MLP World Model 推理服务器')
    
    # 数据参数
    parser.add_argument('--output_dir', type=str, default='./exp_output')
    parser.add_argument('--exp_name', type=str, default='mlp_large')

    parser.add_argument('--task', type=str, default='libero_spatial',
                        help='任务名称')
    parser.add_argument('--threshold', type=float, default=None)
    parser.add_argument('--framestack', type=int, default=1)
    parser.add_argument('--minimum_decay_steps', type=int, default=2)
    parser.add_argument('--port', type=int, default=8889)
    args = parser.parse_args()
    task = args.task
    
    # 加载配置
    dataset_name = task + '_no_noops'
    config_path = os.path.join(args.output_dir, dataset_name, args.exp_name, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    config = Config.from_dict(config)
    config.log.use_wandb = False
    config.model.framestack = args.framestack
    workspace = DynamicsWorkspace(config, train=False)
    
    # 加载模型
    workspace.load_snapshot()
    
    # 启动服务器
    workspace.start_server(
        host="0.0.0.0",
        port=args.port,
        downsample_rates=[1,2],
        minimum_decay_steps=args.minimum_decay_steps,
        threshold=args.threshold,
    )

if __name__ == "__main__":
    main()

