"""
评估 Safe-IQL Policy

这个脚本展示了如何使用训练好的 Safe-IQL Policy 评估 LIBERO 任务。
"""

from config import PolicyConfig
from policy_workspace import PolicyWorkspace
import argparse
import os
import json
from pathlib import Path


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='评估 Safe-IQL Policy')
    
    # 实验参数
    parser.add_argument('--output_dir', type=str, default='./exp_output',
                        help='输出目录')
    parser.add_argument('--exp_name', type=str, default='safe_iql',
                        help='实验名称')
    
    # 评估参数
    parser.add_argument('--task', type=str, default='libero_spatial_no_noops')
    
    # 策略服务器参数
    parser.add_argument('--policy_host', type=str, default='0.0.0.0',
                        help='策略服务器地址')
    parser.add_argument('--policy_port', type=int, default=8000,
                        help='策略服务器端口')
    
    # 配置路径（可选，如果提供则从文件加载）
    parser.add_argument('--config_path', type=str, default=None,
                        help='配置文件路径（如果为 None，则从工作目录加载）')
    
    # 数据路径（可选）
    parser.add_argument('--data_path', type=str, default=None,
                        help='Safe-IQL 数据路径（如果为 None，则使用默认路径）')
    
    parser.add_argument('--client_port', type=int, default=None)
    args = parser.parse_args()
    
    # 确定工作目录
    task = args.task
    work_dir = Path(args.output_dir) / task / args.exp_name
    
    # 加载配置
    if args.config_path is not None:
        config_path = Path(args.config_path)
    else:
        config_path = work_dir / "config.json"
    
    if config_path.exists():
        print(f"从 {config_path} 加载配置...")
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        config = PolicyConfig.from_dict(config_dict)
    else:
        raise FileNotFoundError(f"配置文件 {config_path} 不存在")
    
    # 更新配置（禁用 wandb 等）
    config.log.use_wandb = False
    config.log.use_tensorboard = False
    
    # 创建评估工作空间
    print(f"\n创建工作空间: {work_dir}")
    workspace = PolicyWorkspace(config, work_dir=str(work_dir), train=False)
    
    # 执行评估
    print(f"\n=== 开始评估 Safe-IQL Policy ===")
    print(f"任务套件: {args.task}")
    print(f"工作目录: {work_dir}")
    print("=" * 60)
    
    # 调用训练函数（eval=True 模式）
    if args.data_path is not None:
        data_path = args.data_path
    else:
        data_path = workspace.work_dir / "safe_iql_data.pt"
    
    workspace._train_safe_iql_scheduling_policy(
        data_path=data_path,
        eval=True,
        as_client=args.client_port is not None,
        client_port=args.client_port
    )
    

if __name__ == "__main__":
    main()

