"""
示例：使用多任务 Libero 数据训练 World Model

这个脚本展示了如何使用多个 Libero task suite 的数据训练 World Model。
"""

from config import get_libero_config
from dynamics_workspace import DynamicsWorkspace
import argparse
from config import Config
import os, json

def main():
    # 配置多任务 Libero 数据
    # 可以指定特定的 task suites，或者使用 None 加载所有默认的 task suites
    """主函数"""
    parser = argparse.ArgumentParser(description='训练 World Model')
    
    # 数据参数
    parser.add_argument('--output_dir', type=str, default='./exp_output')
    parser.add_argument('--exp_name', type=str, default='world_model')

    parser.add_argument('--task', type=str, default='libero_spatial',
                        help='任务名称')
    parser.add_argument('--threshold', type=float, default=None)
    parser.add_argument('--framestack', type=int, default=1)
    parser.add_argument('--minimum_decay_steps', type=int, default=2)
    parser.add_argument('--port', type=int, default=8888)
    args = parser.parse_args()
    task = args.task
    
    # for task in tasks:
    dataset_name = task + '_no_noops'
    config_path = os.path.join(args.output_dir, dataset_name, args.exp_name, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    config = Config.from_dict(config)
    config.log.use_wandb = False
    config.model.framestack = args.framestack
    workspace = DynamicsWorkspace(config, train=False)
    
    # 开始评估
    workspace.load_snapshot()
    # workspace.collect_transitions_with_different_ds(
    #     task_suite_name=task, 
    #     num_episodes=5, 
    #     downsample_rates=[1,2], 
    #     dataset_name=dataset_name
    # )
    # workspace.eval_world_model_on_transitions(dataset_name, downsample_rates=[1,2])
    workspace.start_server(
        host="0.0.0.0",
        port=args.port,
        downsample_rates=[1,2],
        minimum_decay_steps=args.minimum_decay_steps,
        threshold=args.threshold,
    )

if __name__ == "__main__":
    main()

