"""
示例：训练 Safe-IQL Policy

这个脚本展示了如何使用 Libero 数据训练 Safe-IQL Scheduling Policy。
"""

from config import get_libero_policy_config, get_hdf5_policy_config
from policy_workspace import PolicyWorkspace
import argparse
from config import PolicyConfig
import os

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='训练 Safe-IQL Policy')
    
    # 数据参数
    parser.add_argument('--data_dir', type=str, default='./libero_rlds_data',
                        help='数据目录')
    parser.add_argument('--num_demos', type=int, default=None,
                        help='使用的 demo 数量')
    parser.add_argument('--single_arm', action='store_true', help='是否使用单臂数据')
    
    # 实验参数
    parser.add_argument('--exp_name', type=str, default='safe_iql',
                        help='实验名称')
    parser.add_argument('--output_dir', type=str, default='./exp_output',
                        help='输出目录')
    
    # Safe-IQL 重要参数
    parser.add_argument('--k_low', type=int, default=1,
                        help='k 值的下限（降采样率的最小值）')
    parser.add_argument('--k_high', type=int, default=2,
                        help='k 值的上限（降采样率的最大值）')
    parser.add_argument('--epsilon', type=float, default=3.0,
                        help='安全阈值，用于判断 transition 是否安全')
    parser.add_argument('--seq_len', type=int, default=10,
                        help='序列长度')
    parser.add_argument('--control_mode', type=str, default='delta',
                        help='控制模式："delta" 或 "abs"')
    # Safe-IQL 训练超参数
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='学习率')
    parser.add_argument('--expectile', type=float, default=0.95,
                        help='Expectile 参数（用于 IQL 的 value function 更新）')
    parser.add_argument('--num_epochs', type=int, default=200,
                        help='训练轮数')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='批次大小')
    parser.add_argument('--gamma', type=float, default=0.9,
                        help='折扣因子')
    parser.add_argument('--tau', type=float, default=0.005,
                        help='软更新系数（用于 target network）')
    
    # 网络架构参数
    parser.add_argument('--hidden_dims', type=int, nargs='+', default=[512, 512],
                        help='MLP 隐藏层维度列表，例如: --hidden_dims 512 512')
    parser.add_argument('--rnn_hidden_size', type=int, default=256,
                        help='RNN 隐藏层大小')
    parser.add_argument('--num_rnn_layers', type=int, default=2,
                        help='RNN 层数')
    
    # Dynamics 模型路径
    parser.add_argument('--dynamics_snapshot_path', type=str, default=None,
                        help='Dynamics 模型快照路径（如果为 None，则使用默认路径）')
    
    # 环境参数
    parser.add_argument('--env', type=str, default='libero', help='环境类型')
    parser.add_argument('--task', type=str, default='libero_spatial_no_noops', 
                        help='任务类型')
    parser.add_argument('--use_rgb', action='store_true', help='是否使用 RGB')
    
    # 其他参数
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='设备')
    parser.add_argument('--no_wandb', action='store_true',
                        help='不使用 wandb')
    parser.add_argument('--mode', type=str, default='online', help='wandb 模式')
    
    # 训练流程参数
    parser.add_argument('--eval_only', action='store_true',
                        help='仅评估模式（不训练）')
    parser.add_argument('--data_path', type=str, default=None,
                        help='Safe-IQL 数据路径（如果为 None，则使用默认路径）')
    
    args = parser.parse_args()
    
    # 创建配置
    if args.env == 'libero':
        config = get_libero_policy_config(
            data_dir=args.data_dir,
            dataset_names=[args.task], 
            use_rgb=args.use_rgb,
            # Safe-IQL 重要参数
            k_low=args.k_low,
            k_high=args.k_high,
            epsilon=args.epsilon,
            # Safe-IQL 训练超参数
            learning_rate=args.learning_rate,
            expectile=args.expectile,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            gamma=args.gamma,
            tau=args.tau,
            # 网络架构参数
            hidden_dims=args.hidden_dims,
            rnn_hidden_size=args.rnn_hidden_size,
            num_rnn_layers=args.num_rnn_layers,
            # Dynamics 模型路径
            dynamics_snapshot_path=args.dynamics_snapshot_path,
        )
    elif args.env == 'hdf5':
        config = get_hdf5_policy_config(
            data_dir=os.path.join(args.data_dir, args.task),
            # Safe-IQL 重要参数
            k_low=args.k_low,
            k_high=args.k_high,
            epsilon=args.epsilon,
            # Safe-IQL 训练超参数
            learning_rate=args.learning_rate,
            expectile=args.expectile,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            gamma=args.gamma,
            tau=args.tau,
            # 网络架构参数
            hidden_dims=args.hidden_dims,
            rnn_hidden_size=args.rnn_hidden_size,
            num_rnn_layers=args.num_rnn_layers,
            # Dynamics 模型路径
            dynamics_snapshot_path=args.dynamics_snapshot_path,
        )
    else:
        config = PolicyConfig()
    
    # 更新配置
    config.data.num_demos = args.num_demos
    config.log.exp_name = args.exp_name
    config.log.output_dir = os.path.join(args.output_dir, args.task)
    config.training.seed = args.seed
    config.training.device = args.device
    config.log.use_wandb = not args.no_wandb
    config.log.wandb_mode = args.mode
    config.data.loader_kwargs["use_rgb"] = args.use_rgb
    config.safe_iql.seq_len = args.seq_len
    config.safe_iql.control_mode = args.control_mode
    config.data.loader_kwargs["single_arm"] = args.single_arm
    # 创建训练工作空间
    workspace = PolicyWorkspace(config)
    
    # 训练流程
    if not args.eval_only:
        # 步骤 1: 采样 Safe-IQL 数据
        print("\n=== 步骤 1: 采样 Safe-IQL 数据 ===")
        data_path = workspace._sample_safe_iql_data()
        print(f"数据已保存到: {data_path}")
        
        # 步骤 2: 训练 Safe-IQL Scheduling Policy
        print("\n=== 步骤 2: 训练 Safe-IQL Scheduling Policy ===")
        workspace._train_safe_iql_scheduling_policy(data_path=data_path, eval=False)
    else:
        # 仅评估模式
        print("\n=== 评估模式 ===")
        if args.data_path is not None:
            data_path = args.data_path
        else:
            data_path = workspace.work_dir / "safe_iql_data.pt"
        workspace._train_safe_iql_scheduling_policy(data_path=data_path, eval=True)
    
    print("\n=== 训练完成 ===")


if __name__ == "__main__":
    main()

