"""
训练 MLP World Model

这个脚本用于训练基于MLP的简单世界模型，
它根据当前state和单步action预测下一步state。
"""

from config import get_libero_config, get_hdf5_config
from mlp_workspace import DynamicsWorkspace
import argparse
from config import Config
import os

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='训练 MLP World Model')
    
    # 数据参数
    parser.add_argument('--data_dir', type=str, default='./libero_rlds_data',
                        help='数据目录')
    parser.add_argument('--num_demos', type=int, default=None,
                        help='使用的 demo 数量')
    parser.add_argument('--single_arm', action='store_true', help='是否使用单臂数据')
    
    # 实验参数
    parser.add_argument('--exp_name', type=str, default='mlp_world_model',
                        help='实验名称')
    parser.add_argument('--output_dir', type=str, default='./exp_output',
                        help='输出目录')
    
    # 训练参数
    parser.add_argument('--batch_size', type=int, default=256,
                        help='批大小（MLP模型可以使用更大的batch）')
    parser.add_argument('--num_train_steps', type=int, default=50000,
                        help='训练步数')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='学习率')
    parser.add_argument('--weight_decay', type=float, default=1e-5,
                        help='权重衰减')
    parser.add_argument('--grad_clip', type=float, default=100.0,
                        help='梯度裁剪阈值')
    
    # 模型参数
    parser.add_argument('--hidden_dims', type=int, nargs='+', default=[256, 256, 256],
                        help='MLP隐藏层维度列表，例如: --hidden_dims 256 256 256')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout概率')
    parser.add_argument('--use_var', action='store_true',
                        help='预测方差（默认不使用）')
    
    # 环境参数
    parser.add_argument('--env', type=str, default='libero', help='环境类型')
    parser.add_argument('--task', type=str, default='libero_spatial_no_noops', help='任务类型')
    
    # 其他参数
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='设备')
    parser.add_argument('--no_wandb', action='store_true',
                        help='不使用 wandb')
    parser.add_argument('--mode', type=str, default='online', help='wandb 模式')
    
    # 评估参数
    parser.add_argument('--eval_every_steps', type=int, default=1000,
                        help='评估间隔')
    parser.add_argument('--save_every_steps', type=int, default=5000,
                        help='保存间隔')
    parser.add_argument('--log_every_steps', type=int, default=100,
                        help='日志记录间隔')
    
    args = parser.parse_args()
    
    # 创建配置
    if args.env == 'libero':
        data_dir = args.data_dir
        # MLP模型不使用RGB，只用低维观测
        config = get_libero_config(data_dir=data_dir, dataset_names=[args.task], use_rgb=False)
    elif args.env == 'hdf5':
        data_dir = os.path.join(args.data_dir, args.task)
        config = get_hdf5_config(data_dir=data_dir)
    else:
        config = Config()

    # 更新数据配置
    config.data.num_demos = args.num_demos
    config.data.loader_kwargs["use_rgb"] = False  # MLP模型不使用RGB
    config.data.loader_kwargs["single_arm"] = args.single_arm
    
    # 更新实验配置
    config.log.exp_name = args.exp_name
    config.log.output_dir = os.path.join(args.output_dir, args.task)
    config.log.use_wandb = not args.no_wandb
    config.log.wandb_mode = args.mode
    config.log.wandb_project = "mlp_world_model"
    
    # 更新训练配置
    config.training.batch_size = args.batch_size
    config.training.num_train_steps = args.num_train_steps
    config.training.learning_rate = args.learning_rate
    config.training.weight_decay = args.weight_decay
    config.training.grad_clip = args.grad_clip
    config.training.seed = args.seed
    config.training.device = args.device
    config.training.eval_every_steps = args.eval_every_steps
    config.training.save_every_steps = args.save_every_steps
    config.training.log_every_steps = args.log_every_steps
    
    # 更新模型配置
    config.model.hidden_dims = args.hidden_dims
    config.model.dropout = args.dropout
    config.model.use_symlog = True
    config.model.use_var = args.use_var  # 默认False，除非指定--use_var
    config.model.use_residual = False
    config.model.use_pixels = False  # MLP模型不使用像素输入
    config.model.use_var = False
    config.model.framestack = 1  # MLP模型不使用framestack
    
    # 打印配置摘要
    print("\n" + "=" * 60)
    print("MLP World Model 训练配置")
    print("=" * 60)
    print(f"任务: {args.task}")
    print(f"实验名称: {args.exp_name}")
    print(f"批大小: {config.training.batch_size}")
    print(f"训练步数: {config.training.num_train_steps}")
    print(f"学习率: {config.training.learning_rate}")
    print(f"MLP隐藏层: {config.model.hidden_dims}")
    print(f"Dropout: {config.model.dropout}")
    print(f"使用symlog: {config.model.use_symlog}")
    print(f"预测方差: {config.model.use_var}")
    print(f"残差预测: {config.model.use_residual}")
    print(f"设备: {config.training.device}")
    print("=" * 60 + "\n")
    
    # 创建训练工作空间
    workspace = DynamicsWorkspace(config)
    
    # 开始训练
    workspace.train()


if __name__ == "__main__":
    main()

