import argparse
import gym
import numpy as np
import os
import torch
import yaml
import d4rl
import cv2

# Project-local agent implementations, selectable by name via config['model']
from agents.diffusion_ql import Diffusion_QL
from agents.diffusion_bc import Diffusion_BC
from agents.flow_ql import Flow_QL
from agents.mmd_ql import MMD_QL
from agents.mmd_bc import MMD_BC

# Registry mapping the config's ``model`` value to the agent class that
# ``load_trained_agent`` instantiates.
available_agents = dict(
    dql=Diffusion_QL,
    mmd_ql=MMD_QL,
    mmd_bc=MMD_BC,
    diffusion_bc=Diffusion_BC,
    flow_ql=Flow_QL,
)

def load_trained_agent(model_dir, config_path):
    """Rebuild an agent from its YAML training config and load saved weights.

    The environment named by ``config['env_name']`` is created briefly, only
    to read the state/action dimensions and the action bound. It is then
    closed again, even if that inspection fails.

    Args:
        model_dir: Path handed to the agent's ``load_model``.
        config_path: Path to the YAML config. Keys read here:
            ``env_name``, ``model`` (a key of ``available_agents``), optional
            ``device`` (defaults to ``'cpu'``), and optional ``model_args``,
            which may be a mapping or a list of mappings merged in order.
            A ``gn`` entry in ``model_args`` is renamed to ``grad_norm``.

    Returns:
        ``(agent, config)``: the agent with weights loaded, and the parsed
        config dict.

    Raises:
        KeyError: if ``config['model']`` is not a registered agent name.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    temp_env = gym.make(config['env_name'])
    try:
        state_dim = temp_env.observation_space.shape[0]
        action_dim = temp_env.action_space.shape[0]
        # NOTE(review): assumes the same bound on every action dimension.
        max_action = float(temp_env.action_space.high[0])
    finally:
        temp_env.close()

    model_args = {}
    raw_model_args = config.get('model_args')
    if isinstance(raw_model_args, dict):
        model_args.update(raw_model_args)
    elif raw_model_args is not None:
        # A YAML list of single-key mappings, e.g. "- lr: 3e-4".
        for item in raw_model_args:
            model_args.update(item)

    # Environment-derived values and the device override anything in model_args.
    model_args.update({
        'state_dim': state_dim,
        'action_dim': action_dim,
        'max_action': max_action,
        'device': config.get('device', 'cpu'),
    })

    # The config abbreviates the constructor's ``grad_norm`` argument as ``gn``.
    if 'gn' in model_args:
        model_args['grad_norm'] = model_args.pop('gn')

    model_name = config['model']
    if model_name not in available_agents:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of "
            f"{sorted(available_agents)}"
        )

    agent = available_agents[model_name](**model_args)
    agent.load_model(model_dir)

    return agent, config

def _make_render_env(env_name):
    """Create ``env_name`` for off-screen RGB rendering across gym versions.

    Newer gym/gymnasium releases take ``render_mode`` in ``gym.make``. Older
    releases forward unknown keywords to the environment constructor, which
    typically rejects them with ``TypeError``. In that case the environment
    is created without the keyword, and ``_render_rgb_frame`` uses the legacy
    ``render(mode=...)`` call instead.
    """
    try:
        return gym.make(env_name, render_mode='rgb_array')
    except TypeError:
        return gym.make(env_name)


def _render_rgb_frame(env):
    """Render the current state of ``env`` as a C-contiguous RGB image array.

    Tries the legacy ``render(mode='rgb_array')`` signature first. If the
    environment rejects the keyword (gym >= 0.26 / gymnasium), it calls
    ``render()`` and relies on the render mode chosen in ``gym.make``.

    Raises:
        RuntimeError: if the environment produced no frame (e.g. it was
            created without an RGB render mode).
    """
    try:
        frame = env.render(mode='rgb_array')
    except TypeError:
        # The per-frame exception is negligible next to the render itself.
        frame = env.render()
    if frame is None:
        raise RuntimeError(
            "env.render() returned no frame; the environment does not appear "
            "to support off-screen 'rgb_array' rendering"
        )
    # Renderers may return non-contiguous (e.g. flipped) views; hand OpenCV a
    # contiguous buffer.
    return np.ascontiguousarray(frame)


def _unpack_step(step_result):
    """Normalize ``env.step()`` output to ``(next_state, reward, done)``.

    gym < 0.26 returns ``(obs, reward, done, info)``. gym >= 0.26 and
    gymnasium return ``(obs, reward, terminated, truncated, info)``, where
    the episode is over when either flag is set.
    """
    if len(step_result) == 5:
        state, reward, terminated, truncated, _info = step_result
        return state, reward, bool(terminated or truncated)
    state, reward, done = step_result[0], step_result[1], step_result[2]
    return state, reward, bool(done)


def _record_trajectory(agent, env_name, video_path, max_steps, fps):
    """Roll out ``agent`` for one episode and save it as an MP4 video.

    One frame is written for the initial state and one after every step, so
    the final (possibly terminal) state is included. The rollout stops after
    ``max_steps`` actions, or earlier if the environment reports the episode
    over. The video writer and environment are released in ``finally``, so a
    failed rollout still finalizes what was written and frees the environment.

    Returns:
        The number of frames written.

    Raises:
        RuntimeError: if OpenCV cannot open ``video_path`` for writing, or
            the environment renders no frame.
    """
    env = _make_render_env(env_name)
    writer = None
    n_frames = 0

    def write_frame():
        nonlocal writer, n_frames
        frame = _render_rgb_frame(env)
        if writer is None:
            # Size the writer from the first frame; OpenCV wants (width, height).
            height, width = frame.shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
            if not writer.isOpened():
                raise RuntimeError(f"Could not open video writer for {video_path}")
        # OpenCV expects BGR channel order; the environment renders RGB.
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        n_frames += 1

    try:
        reset_result = env.reset()
        # gym >= 0.26 returns (obs, info); older versions return obs only.
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        write_frame()
        for _ in range(max_steps):
            action = agent.sample_action(np.array(state))
            state, _reward, done = _unpack_step(env.step(action))
            write_frame()
            if done:
                break
    finally:
        if writer is not None:
            writer.release()
        env.close()
    return n_frames


def main():
    """Parse CLI arguments, load the trained agent, and record rollout videos.

    Each trajectory is written to ``<output_dir>/trajectory_<i>.mp4``, using
    a fresh environment for every trajectory.
    """
    parser = argparse.ArgumentParser(
        description="Record rollout videos of a trained agent."
    )
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--num_trajectories", type=int, default=5)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--output_dir", type=str, default="./videos")
    parser.add_argument("--fps", type=float, default=30.0)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading trained agent...")
    agent, config = load_trained_agent(args.model_dir, args.config_path)

    for i in range(args.num_trajectories):
        print(f"Generating trajectory {i+1}/{args.num_trajectories}...")
        video_path = os.path.join(args.output_dir, f"trajectory_{i+1}.mp4")
        _record_trajectory(
            agent, config['env_name'], video_path, args.max_steps, args.fps
        )
        print(f"Saved video for trajectory {i+1}")

    print(f"\nAll videos saved to {args.output_dir}")

# Script entry point: parse CLI arguments and record the rollout videos.
if __name__ == "__main__":
    main()
