#!/usr/bin/env python
#
# Copyright (c) 2024, Flow Matching Research Team
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#  this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#  this list of conditions and the following disclaimer in the documentation
#  and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#  contributors may be used to endorse or promote products derived from
#  this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Unified Flow Matching experiment framework with Hydra support

import sys
import torch
import hydra
import numpy as np
from omegaconf import DictConfig

# Disable .pyc generation. This runs after torch/hydra/numpy were imported
# above, so it only affects modules imported from this point onwards.
sys.dont_write_bytecode = True

# Make the local 'utils' and 'experiments' directories importable. The paths
# are relative, i.e. resolved against the current working directory, so the
# script is presumably meant to be launched from the repository root.
sys.path.append('utils')
sys.path.append('experiments')

from experiments.kitchen_experiment_hydra import KitchenExperimentHydra
from experiments.pusht_experiment_hydra import PushTExperimentHydra
from experiments.mimic_experiment_hydra import MimicExperimentHydra


def set_random_seed(seed: int):
    """Seed every random number generator this script can rely on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally because the module header does not import ``random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are documented as safe when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def create_experiment_from_config(cfg: DictConfig):
    """Build the experiment object selected by ``cfg.env.name``.

    Args:
        cfg: Hydra configuration; ``cfg.env.name`` picks the experiment.

    Returns:
        A ``PushTExperimentHydra``, ``KitchenExperimentHydra`` or
        ``MimicExperimentHydra`` instance constructed with ``cfg``.

    Raises:
        ValueError: If ``cfg.env.name`` matches none of the known names.
    """
    requested = cfg.env.name

    # Constructors are wrapped in lambdas so that only the matching
    # experiment class is looked up and instantiated.
    builders = (
        ("pusht", lambda: PushTExperimentHydra(cfg)),
        ("kitchen", lambda: KitchenExperimentHydra(cfg)),
        ("mimic", lambda: MimicExperimentHydra(cfg)),
    )
    for known_name, build in builders:
        if requested == known_name:
            return build()

    raise ValueError(f"Unknown environment: {requested}")


def run_testing(experiment, cfg: DictConfig):
    """Load a checkpoint, roll out the policy and print summary statistics.

    Args:
        experiment: Experiment object providing ``setup_model()`` and
            ``setup_environment()``.
        cfg: Hydra configuration. Reads ``env.name``,
            ``checkpoint.load_checkpoint``, ``device``,
            ``execution.testing.{start_seed,episodes,runs_per_episode}`` and
            ``env.environment.max_steps``.

    Raises:
        ValueError: If ``cfg.checkpoint.load_checkpoint`` is empty.
    """
    print(f"Starting testing for {cfg.env.name} experiment...")

    checkpoint_path = cfg.checkpoint.load_checkpoint
    if not checkpoint_path:
        raise ValueError("Checkpoint path must be specified for testing")

    # Build the model and restore its weights.
    model = experiment.setup_model()
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=cfg.device, weights_only=False)

    if 'state_dict' in checkpoint:
        # PyTorch Lightning style checkpoint: strip one leading wrapper
        # prefix ('model.' or 'nets.') from every parameter name.
        state_dict = {}
        for key, value in checkpoint['state_dict'].items():
            for prefix in ('model.', 'nets.'):
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    break
            state_dict[key] = value
        model.load_state_dict(state_dict)
    elif 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'model' in checkpoint:
        model.load_state_dict(checkpoint['model'])
    else:
        # The checkpoint is treated as a bare state dict.
        model.load_state_dict(checkpoint)

    model.eval()

    env = experiment.setup_environment()

    testing_cfg = cfg.execution.testing
    test_start_seed = getattr(testing_cfg, 'start_seed', 1000)
    test_episodes = getattr(testing_cfg, 'episodes', 10)
    test_runs_per_episode = getattr(testing_cfg, 'runs_per_episode', 10)
    # Look the key up on the ``environment`` node. The previous code called
    # getattr() on the max_steps value itself, which always yielded 280.
    max_steps = getattr(cfg.env.environment, 'max_steps', 280)
    if max_steps is None:
        max_steps = 280

    print(f"Test configuration: seed_start={test_start_seed}, episodes={test_episodes}, runs_per_episode={test_runs_per_episode}, max_steps={max_steps}")

    env_name = cfg.env.name.lower()
    testers = {
        'kitchen': _test_kitchen,
        'pusht': _test_pusht,
        'mimic': _test_mimic,
    }
    tester = testers.get(env_name)
    if tester is None:
        print(f"Unknown environment type: {env_name}, using generic testing")
        tester = _test_generic

    success_rates, total_rewards_list, reward_rates = tester(
        model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg
    )

    print("\n" + "="*50)
    print("TEST RESULTS SUMMARY")
    print("="*50)
    if success_rates:
        print(f"Success Rate - Mean: {np.mean(success_rates):.2%}, Std: {np.std(success_rates):.2%}")
    if total_rewards_list:
        print(f"Total Rewards - Mean: {np.mean(total_rewards_list):.2f}, Std: {np.std(total_rewards_list):.2f}")
    else:
        # np.mean([]) would emit a RuntimeWarning and print nan.
        print("Total Rewards - no episodes were evaluated")
    if reward_rates:
        print(f"Reward Rate - Mean: {np.mean(reward_rates):.2%}, Std: {np.std(reward_rates):.2%}")

    print("Testing completed successfully!")


def _test_kitchen(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the policy in the Kitchen environment (after flow_kitchen.py).

    For each of ``test_episodes`` seeds the environment is seeded once and
    rolled out ``test_runs_per_episode`` times. A run counts as a success when
    its summed reward equals 4. Rendered frames are written to
    ``./rendered_frames``.

    Args:
        model: Mapping with a ``'noise_pred_net'`` entry, called as
            ``net(x, timestep, global_cond=obs_cond)``.
        env: Environment exposing ``seed``, ``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)`` and ``render(mode=...)``.
        test_start_seed: Seed of the first epoch; epoch ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeds (epochs) to evaluate.
        test_runs_per_episode: Rollouts per seed.
        max_steps: A rollout is stopped once its step count exceeds this.
        cfg: Configuration; reads ``cfg.device`` and ``cfg.env.dataset``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)`` with one
        entry per evaluated epoch.
    """
    import collections
    import os
    import cv2
    from datetime import datetime
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    dataset_cfg = cfg.env.dataset
    obs_horizon = getattr(dataset_cfg, 'obs_horizon', 1)
    pred_horizon = getattr(dataset_cfg, 'pred_horizon', 16)
    action_horizon = getattr(dataset_cfg, 'action_horizon', 8)
    action_dim = getattr(dataset_cfg, 'action_dim', 9)

    # Summed reward at which a run is counted as solved.
    success_reward = 4

    # Rendering configuration: 'rgb_array' returns frames that can be saved.
    render_mode = 'rgb_array'
    render_frequency = 1  # render every N environment steps

    # Frame dump configuration.
    save_frames = True
    output_dir = "./rendered_frames"
    frame_format = "png"  # one of: png, jpg, jpeg

    print(f"Kitchen test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")
    print(f"Rendering enabled: mode={render_mode}, frequency={render_frequency}")
    print(f"Frame saving: {save_frames}, output_dir={output_dir}, format={frame_format}")

    if save_frames:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    def write_frame(frame, epoch, run, step):
        """Write one rendered frame into ``output_dir`` and report the outcome."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"frame_epoch{epoch}_run{run}_step{step:04d}_{timestamp}.{frame_format}"
        filepath = os.path.join(output_dir, filename)

        if frame_format.lower() not in ['png', 'jpg', 'jpeg']:
            print(f"Unsupported frame format: {frame_format}")
            return

        # The frame is assumed to be RGB; OpenCV writes BGR, so swap channels.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure through its return value, not by raising.
        if cv2.imwrite(filepath, frame):
            print(f"Saved frame: {filename}")
        else:
            print(f"Warning: Failed to write frame: {filepath}")

    if test_runs_per_episode <= 0:
        # Returning early avoids dividing by zero in the per-epoch summary.
        print(f"No runs requested (runs_per_episode={test_runs_per_episode}); nothing to evaluate.")
        return success_rates, total_rewards_list, reward_rates

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        env.seed(seed)

        success_count = 0
        total_episodes = 0
        total_rewards = 0

        for pp in range(test_runs_per_episode):
            obs = env.reset()

            # Initial render; any failure disables rendering for the rest of
            # the evaluation.
            if render_mode:
                try:
                    env.render(mode=render_mode)
                except Exception as e:
                    print(f"Warning: Failed to render environment: {e}")
                    print("Continuing without rendering...")
                    render_mode = None

            # Observation history, pre-filled with the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            episode_reward = 0
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval Kitchen Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    obs_seq = np.stack(list(obs_deque))
                    obs_seq = torch.from_numpy(obs_seq).to(cfg.device, dtype=torch.float32)

                    with torch.no_grad():
                        # Add a batch axis, then flatten the history into one
                        # conditioning vector per batch element.
                        obs_cond = obs_seq.unsqueeze(0).flatten(start_dim=1)

                        # Euler integration of the predicted velocity field
                        # with num_steps steps, starting from uniform noise.
                        num_steps = 1
                        noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                        traj = noise.expand(obs_cond.shape[0], -1, -1)
                        for i in range(num_steps):
                            timestep = torch.tensor([i / num_steps]).to(cfg.device)
                            vt = model['noise_pred_net'](traj, timestep, global_cond=obs_cond)
                            traj = vt * 1 / num_steps + traj

                    action_pred = traj.detach().to('cpu').numpy()[0]

                    # Execute only a window of the predicted action sequence.
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end, :]

                    for step_action in action:
                        obs, reward, done, info = env.step(step_action)
                        obs_deque.append(obs)
                        episode_reward += reward

                        if render_mode and step_idx % render_frequency == 0:
                            try:
                                frame = env.render(mode=render_mode)
                                if save_frames and frame is not None:
                                    write_frame(frame, epoch, pp, step_idx)
                            except Exception as e:
                                print(f"Warning: Failed to render environment: {e}")
                                print("Continuing without rendering...")
                                render_mode = None

                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)

                        # NOTE: '>' lets a run take max_steps + 1 steps; kept
                        # as is to stay comparable with earlier evaluations.
                        if step_idx > max_steps or episode_reward == success_reward:
                            done = True
                        if done:
                            break

            total_episodes += 1
            total_rewards += episode_reward
            if episode_reward == success_reward:
                success_count += 1
                print(f"Episode {seed}-{pp} succeeded! Total reward: {episode_reward}")
            else:
                print(f"Episode {seed}-{pp} failed. Total reward: {episode_reward}")

        success_rates.append(success_count / total_episodes)
        total_rewards_list.append(total_rewards)
        reward_rates.append(total_rewards / (total_episodes * success_reward))

        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {total_episodes}")
        print(f"Success Count: {success_count}")
        print(f"Success Rate: {success_rates[-1]:.2%}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / total_episodes:.2f}")
        print(f"Reward Rate: {reward_rates[-1]:.2%}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_pusht(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the policy in the PUSH-T environment (after flow_pusht.py).

    For each of ``test_episodes`` seeds the environment is seeded once and
    rolled out ``test_runs_per_episode`` times. Only summed rewards are
    tracked. Rendered frames are written to ``./rendered_frames_pusht``.

    Args:
        model: Mapping with a ``'noise_pred_net'`` entry and optionally a
            ``'vision_encoder'`` entry.
        env: Environment exposing ``seed``, ``reset() -> (obs, info)``, a
            five-value ``step`` and ``render(mode=...)``. Observations are
            mappings with ``'image'`` and ``'agent_pos'`` entries.
        test_start_seed: Seed of the first epoch; epoch ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeds (epochs) to evaluate.
        test_runs_per_episode: Rollouts per seed.
        max_steps: A rollout is stopped once its step count exceeds this.
        cfg: Configuration; reads ``cfg.device`` and ``cfg.env.dataset``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``; the first
        and last lists are always empty for this environment.
    """
    import collections
    import os
    import cv2
    from datetime import datetime
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    dataset_cfg = cfg.env.dataset
    obs_horizon = getattr(dataset_cfg, 'obs_horizon', 1)
    pred_horizon = getattr(dataset_cfg, 'pred_horizon', 16)
    action_horizon = getattr(dataset_cfg, 'action_horizon', 8)
    action_dim = getattr(dataset_cfg, 'action_dim', 2)

    # Rendering configuration: 'rgb_array' returns frames that can be saved.
    render_mode = 'rgb_array'
    render_frequency = 1  # render every N environment steps

    # Frame dump configuration.
    save_frames = True
    output_dir = "./rendered_frames_pusht"
    frame_format = "png"  # one of: png, jpg, jpeg

    print(f"PUSH-T test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")
    print(f"Rendering enabled: mode={render_mode}, frequency={render_frequency}")
    print(f"Frame saving: {save_frames}, output_dir={output_dir}, format={frame_format}")

    if save_frames:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    def write_frame(frame, epoch, run, step):
        """Write one rendered frame into ``output_dir`` and report the outcome."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"frame_epoch{epoch}_run{run}_step{step:04d}_{timestamp}.{frame_format}"
        filepath = os.path.join(output_dir, filename)

        if frame_format.lower() not in ['png', 'jpg', 'jpeg']:
            print(f"Unsupported frame format: {frame_format}")
            return

        # The frame is assumed to be RGB; OpenCV writes BGR, so swap channels.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure through its return value, not by raising.
        if cv2.imwrite(filepath, frame):
            print(f"Saved frame: {filename}")
        else:
            print(f"Warning: Failed to write frame: {filepath}")

    if test_runs_per_episode <= 0:
        # Returning early avoids dividing by zero in the per-epoch summary.
        print(f"No runs requested (runs_per_episode={test_runs_per_episode}); nothing to evaluate.")
        return success_rates, total_rewards_list, reward_rates

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            obs, info = env.reset()

            # Observation history, pre-filled with the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            episode_reward = 0
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval PUSH-T Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    images = np.stack([o['image'] for o in obs_deque])
                    positions = np.stack([o['agent_pos'] for o in obs_deque])

                    x_img = torch.from_numpy(images).to(cfg.device, dtype=torch.float32)
                    x_pos = torch.from_numpy(positions).to(cfg.device, dtype=torch.float32)

                    with torch.no_grad():
                        if 'vision_encoder' in model:
                            # squeeze(0) followed by unsqueeze(0) leaves a
                            # history of length 1 unchanged and prepends a
                            # batch axis to longer histories.
                            x_img_batch = x_img.squeeze(0).unsqueeze(0)
                            image_features = model['vision_encoder'](x_img_batch)
                            x_pos_batch = x_pos.squeeze(0).unsqueeze(0)
                            obs_cond = torch.cat([image_features, x_pos_batch], dim=-1)
                        else:
                            # Without an encoder the raw pixels and positions
                            # are flattened into one conditioning vector.
                            x_img_batch = x_img.unsqueeze(0)
                            x_pos_batch = x_pos.unsqueeze(0)
                            obs_cond = torch.cat([x_img_batch.flatten(start_dim=1), x_pos_batch.flatten(start_dim=1)], dim=-1)

                        # Euler integration of the predicted velocity field
                        # with num_steps steps, starting from uniform noise.
                        num_steps = 1
                        noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                        traj = noise.expand(obs_cond.shape[0], -1, -1)
                        for i in range(num_steps):
                            timestep = torch.tensor([i / num_steps]).to(cfg.device)
                            vt = model['noise_pred_net'](traj, timestep, global_cond=obs_cond)
                            traj = vt * 1 / num_steps + traj

                    action_pred = traj.detach().to('cpu').numpy()[0]

                    # Execute only a window of the predicted action sequence.
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end, :]

                    for step_action in action:
                        obs, reward, terminated, truncated, info = env.step(step_action)
                        # Honour both end-of-episode flags of the five-value
                        # step API instead of discarding the truncation flag.
                        done = bool(terminated or truncated)
                        obs_deque.append(obs)
                        episode_reward += reward

                        if render_mode and step_idx % render_frequency == 0:
                            try:
                                frame = env.render(mode=render_mode)
                                if save_frames and frame is not None:
                                    write_frame(frame, epoch, pp, step_idx)
                            except Exception as e:
                                # Any failure disables rendering for the rest
                                # of the evaluation.
                                print(f"Warning: Failed to render environment: {e}")
                                print("Continuing without rendering...")
                                render_mode = None

                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)

                        # NOTE: '>' lets a run take max_steps + 1 steps; kept
                        # as is to stay comparable with earlier evaluations.
                        if step_idx > max_steps:
                            done = True
                        if done:
                            break

            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_mimic(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the policy in the Mimic environment (after flow_mimic.py).

    For each of ``test_episodes`` seeds the environment is seeded once and
    rolled out ``test_runs_per_episode`` times. A run ends when a step reward
    equals 1 or the step budget is exceeded. Only summed rewards are tracked.
    Rendered frames are written to ``./rendered_frames_mimic``.

    Args:
        model: Mapping with a ``'noise_pred_net'`` entry, called as
            ``net(x, timestep, global_cond=obs_cond)``.
        env: Environment exposing ``seed``, ``reset() -> obs``,
            ``step(a) -> (obs, reward, done, info)`` and ``render(mode=...)``.
        test_start_seed: Seed of the first epoch; epoch ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeds (epochs) to evaluate.
        test_runs_per_episode: Rollouts per seed.
        max_steps: A rollout is stopped once its step count exceeds this.
        cfg: Configuration; reads ``cfg.device`` and ``cfg.env.dataset``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``; the first
        and last lists are always empty for this environment.
    """
    import collections
    import os
    import cv2
    from datetime import datetime
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    dataset_cfg = cfg.env.dataset
    obs_horizon = getattr(dataset_cfg, 'obs_horizon', 1)
    pred_horizon = getattr(dataset_cfg, 'pred_horizon', 16)
    action_horizon = getattr(dataset_cfg, 'action_horizon', 8)
    action_dim = getattr(dataset_cfg, 'action_dim', 20)

    # Rendering configuration: 'rgb_array' returns frames that can be saved.
    render_mode = 'rgb_array'
    render_frequency = 1  # render every N environment steps

    # Frame dump configuration.
    save_frames = True
    output_dir = "./rendered_frames_mimic"
    frame_format = "png"  # one of: png, jpg, jpeg

    def undo_transform_action(action):
        """Convert a 20-value policy action into the 14-value env action.

        The action is viewed as two rows of 10 values and the first 7 values
        of each row are concatenated (presumably one row per arm — TODO
        confirm). Actions of any other width are returned unchanged.
        """
        if action.shape[-1] != 20:
            return action
        rows = action.reshape(-1, 2, 10)
        return np.concatenate([rows[0, 0, :7], rows[0, 1, :7]])

    def write_frame(frame, epoch, run, step):
        """Write one rendered frame into ``output_dir`` and report the outcome."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"frame_epoch{epoch}_run{run}_step{step:04d}_{timestamp}.{frame_format}"
        filepath = os.path.join(output_dir, filename)

        if frame_format.lower() not in ['png', 'jpg', 'jpeg']:
            print(f"Unsupported frame format: {frame_format}")
            return

        # The frame is assumed to be RGB; OpenCV writes BGR, so swap channels.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure through its return value, not by raising.
        if cv2.imwrite(filepath, frame):
            print(f"Saved frame: {filename}")
        else:
            print(f"Warning: Failed to write frame: {filepath}")

    print(f"Mimic test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")
    print(f"Rendering enabled: mode={render_mode}, frequency={render_frequency}")
    print(f"Frame saving: {save_frames}, output_dir={output_dir}, format={frame_format}")

    if save_frames:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    if test_runs_per_episode <= 0:
        # Returning early avoids dividing by zero in the per-epoch summary.
        print(f"No runs requested (runs_per_episode={test_runs_per_episode}); nothing to evaluate.")
        return success_rates, total_rewards_list, reward_rates

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            obs = env.reset()

            # Observation history, pre-filled with the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            episode_reward = 0
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval Mimic Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    obs_seq = np.stack(list(obs_deque))
                    obs_seq = torch.from_numpy(obs_seq).to(cfg.device, dtype=torch.float32)

                    with torch.no_grad():
                        # NOTE(review): unlike the Kitchen tester, no batch
                        # axis is added here, so the observation-history axis
                        # acts as the batch axis — confirm this matches
                        # training when obs_horizon > 1.
                        obs_cond = obs_seq.flatten(start_dim=1)

                        # Euler integration of the predicted velocity field
                        # with num_steps steps, starting from uniform noise.
                        num_steps = 1
                        noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                        traj = noise.expand(obs_seq.shape[0], -1, -1)
                        for i in range(num_steps):
                            timestep = torch.tensor([i / num_steps]).to(cfg.device)
                            vt = model['noise_pred_net'](traj, timestep, global_cond=obs_cond)
                            traj = vt * 1 / num_steps + traj

                    action_pred = traj.detach().to('cpu').numpy()[0]

                    # Execute only a window of the predicted action sequence.
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end, :]

                    for step_action in action:
                        env_action = undo_transform_action(step_action)
                        obs, reward, done, info = env.step(env_action)
                        obs_deque.append(obs)
                        episode_reward += reward

                        if render_mode and step_idx % render_frequency == 0:
                            try:
                                frame = env.render(mode=render_mode)
                                if save_frames and frame is not None:
                                    write_frame(frame, epoch, pp, step_idx)
                            except Exception as e:
                                # Any failure disables rendering for the rest
                                # of the evaluation.
                                print(f"Warning: Failed to render environment: {e}")
                                print("Continuing without rendering...")
                                render_mode = None

                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)

                        # A step reward of exactly 1 ends the run.
                        # NOTE: '>' lets a run take max_steps + 1 steps; kept
                        # as is to stay comparable with earlier evaluations.
                        if step_idx > max_steps or reward == 1:
                            done = True
                        if done:
                            break

            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_generic(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Fallback evaluation loop for environment names without a dedicated tester.

    Supports both the classic Gym API (``reset() -> obs``, four-value
    ``step``) and the Gymnasium-style API (``reset() -> (obs, info)``,
    five-value ``step``). Only summed rewards are tracked.

    Args:
        model: Callable invoked as ``model(x, timestep, global_cond=obs_cond)``.
        env: Environment exposing ``reset`` and ``step``; ``seed`` is used
            when present.
        test_start_seed: Seed of the first epoch; epoch ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeds (epochs) to evaluate.
        test_runs_per_episode: Rollouts per seed.
        max_steps: A rollout is stopped once its step count exceeds this.
        cfg: Configuration; reads ``cfg.device`` and the top-level keys
            ``obs_horizon``, ``pred_horizon``, ``action_horizon`` and
            ``action_dim``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``; the first
        and last lists are always empty here.
    """
    import collections
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    # NOTE(review): the dedicated testers read these keys from
    # cfg.env.dataset, this one reads them from the top level of cfg —
    # confirm that is intended.
    obs_horizon = getattr(cfg, 'obs_horizon', 1)
    pred_horizon = getattr(cfg, 'pred_horizon', 16)
    action_horizon = getattr(cfg, 'action_horizon', 8)
    action_dim = getattr(cfg, 'action_dim', 9)

    print(f"Generic test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")

    if test_runs_per_episode <= 0:
        # Returning early avoids dividing by zero in the per-epoch summary.
        print(f"No runs requested (runs_per_episode={test_runs_per_episode}); nothing to evaluate.")
        return success_rates, total_rewards_list, reward_rates

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        if hasattr(env, 'seed'):
            env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            # Reset exactly once and detect the API from the result shape. A
            # retry in an except branch never triggers for tuple returns and
            # would leave the (obs, info) pair stored as the observation.
            reset_result = env.reset()
            if (
                isinstance(reset_result, tuple)
                and len(reset_result) == 2
                and isinstance(reset_result[1], dict)
            ):
                obs = reset_result[0]
            else:
                obs = reset_result

            # Observation history, pre-filled with the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            episode_reward = 0
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval Generic Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    obs_seq = np.stack(list(obs_deque))
                    obs_seq = torch.from_numpy(obs_seq).to(cfg.device, dtype=torch.float32)

                    with torch.no_grad():
                        obs_cond = obs_seq.flatten(start_dim=1)

                        # Euler integration of the predicted velocity field
                        # with num_steps steps, starting from uniform noise.
                        num_steps = 1
                        noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                        traj = noise.expand(obs_seq.shape[0], -1, -1)
                        for i in range(num_steps):
                            timestep = torch.tensor([i / num_steps]).to(cfg.device)
                            vt = model(traj, timestep, global_cond=obs_cond)
                            traj = vt * 1 / num_steps + traj

                    action_pred = traj.detach().to('cpu').numpy()[0]

                    # Execute only a window of the predicted action sequence.
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action = action_pred[start:end, :]

                    for step_action in action:
                        # Step exactly once. Unpacking inside try/except and
                        # stepping again on failure advanced five-value
                        # environments twice per action.
                        step_result = env.step(step_action)
                        if len(step_result) == 5:
                            obs, reward, terminated, truncated, info = step_result
                            done = bool(terminated or truncated)
                        else:
                            obs, reward, done, info = step_result

                        obs_deque.append(obs)
                        episode_reward += reward

                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=reward)

                        # NOTE: '>' lets a run take max_steps + 1 steps; kept
                        # consistent with the dedicated testers.
                        if step_idx > max_steps:
                            done = True
                        if done:
                            break

            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


@hydra.main(version_base=None, config_path="conf", config_name="config-test")
def main(cfg: DictConfig):
    """Hydra entry point: validate the config, seed the RNGs and evaluate.

    Args:
        cfg: Configuration composed by Hydra from ``conf/config-test``.

    Raises:
        ValueError: If ``cfg.checkpoint.load_checkpoint`` is empty.
    """
    print("--- Starting Flow Matching Evaluation ---")

    checkpoint_path = cfg.checkpoint.load_checkpoint
    if not checkpoint_path:
        raise ValueError("Checkpoint path must be specified for testing via 'checkpoint.load_checkpoint=/path/to/ckpt'")

    # Echo the two settings that identify this evaluation run.
    for label, value in (("Environment", cfg.env.name), ("Checkpoint", checkpoint_path)):
        print(f"{label}: {value}")

    set_random_seed(cfg.seed)
    run_testing(create_experiment_from_config(cfg), cfg)

    print("--- Evaluation Completed ---")


# Script entry point: run the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
