from config import get_libero_config, get_hdf5_config
from dynamics_workspace import DynamicsWorkspace
import argparse
from config import Config
import os

def main():
    parser = argparse.ArgumentParser(description='训练 World Model')
    
    parser.add_argument('--data_dir', type=str, default='./libero_rlds_data',
                        help='数据目录')
    parser.add_argument('--num_demos', type=int, default=None,
                        help='使用的 demo 数量')
    parser.add_argument('--single_arm', action='store_true', help='是否使用单臂数据')
    parser.add_argument('--exp_name', type=str, default='world_model',
                        help='实验名称')
    parser.add_argument('--output_dir', type=str, default='./exp_output',
                        help='输出目录')
    
    parser.add_argument('--batch_size', type=int, default=128,
                        help='批大小')
    parser.add_argument('--seq_len', type=int, default=10,
                        help='序列长度')
    parser.add_argument('--framestack', type=int, default=1,
                        help='帧堆叠数量')
    parser.add_argument('--num_train_steps', type=int, default=30000,
                        help='训练步数')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='学习率')
    
    parser.add_argument('--env', type=str, default='libero', help='环境类型')
    parser.add_argument('--task', type=str, default='libero_spatial_no_noops', help='任务类型')
    parser.add_argument('--use_rgb', action='store_true', help='是否使用 RGB')

    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='设备')
    parser.add_argument('--no_wandb', action='store_true',
                        help='不使用 wandb')
    parser.add_argument('--mode', type=str, default='online', help='wandb 模式')
    args = parser.parse_args()
    
    if args.env == 'libero':
        data_dir = args.data_dir
        config = get_libero_config(data_dir=data_dir, dataset_names=[args.task], use_rgb=args.use_rgb)
    elif args.env == 'hdf5':
        data_dir = os.path.join(args.data_dir, args.task)
        config = get_hdf5_config(data_dir=data_dir)
    else:
        config = Config()

    config.data.num_demos = args.num_demos
    config.log.exp_name = args.exp_name
    config.log.output_dir = os.path.join(args.output_dir, args.task)
    config.training.batch_size = args.batch_size
    config.training.seq_len = args.seq_len
    config.training.num_train_steps = args.num_train_steps
    config.training.learning_rate = args.learning_rate
    config.training.seed = args.seed
    config.training.device = args.device
    config.log.use_wandb = not args.no_wandb
    config.log.wandb_mode = args.mode
    config.model.use_pixels = args.use_rgb
    config.model.framestack = args.framestack
    config.data.loader_kwargs["use_rgb"] = args.use_rgb
    config.data.loader_kwargs["single_arm"] = args.single_arm
    workspace = DynamicsWorkspace(config)
    
    workspace.train()


if __name__ == "__main__":
    main()

