#!/usr/bin/env python
#
# Copyright (c) 2024, Flow Matching Research Team
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#  this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#  this list of conditions and the following disclaimer in the documentation
#  and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#  contributors may be used to endorse or promote products derived from
#  this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
# IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Unified Flow Matching experiment framework with Hydra support

import os
import sys
import torch
import hydra
from omegaconf import DictConfig
import wandb
from utils.experiments.kitchen_experiment_hydra import KitchenExperimentHydra
from utils.experiments.pusht_experiment_hydra import PushTExperimentHydra
from utils.experiments.mimic_experiment_hydra import MimicExperimentHydra
import numpy as np
from datetime import datetime
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Do not write .pyc files for modules imported from this point on.
# NOTE(review): the imports at the top of the file have already executed by
# the time this flag is set.
sys.dont_write_bytecode = True

# Extend the module search path with the 'utils' and 'utils/experiments'
# directories.
# NOTE(review): these are relative entries (resolved against the current
# working directory) and are appended after the `utils.experiments.*` imports
# above have already run.
sys.path.append('utils')
sys.path.append('utils/experiments')


from utils.experiments.pusht_experiment_hydra import PushTExperimentHydra


# Import other experiment classes as needed
from utils.experiments.kitchen_experiment_hydra import KitchenExperimentHydra
# from experiments.mimic_experiment_hydra import MimicExperimentHydra


def set_random_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA), then puts cuDNN into deterministic mode with autotuning
    disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so this function does not need a new file-level import.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer run-to-run determinism over cuDNN's autotuned kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def create_experiment_from_config(cfg: DictConfig):
    """Build the experiment object selected by ``cfg.env.name``.

    Args:
        cfg: Configuration whose ``env.name`` is one of ``"pusht"``,
            ``"kitchen"`` or ``"mimic"`` (compared exactly, case-sensitive).

    Returns:
        The experiment instance constructed from ``cfg``.

    Raises:
        ValueError: If ``cfg.env.name`` is not a known environment.
    """
    requested = cfg.env.name

    if requested == "pusht":
        return PushTExperimentHydra(cfg)

    if requested == "kitchen":
        return KitchenExperimentHydra(cfg)

    if requested == "mimic":
        # Imported at the point of use rather than relying on a module-level name.
        from utils.experiments.mimic_experiment_hydra import MimicExperimentHydra
        return MimicExperimentHydra(cfg)

    raise ValueError(f"Unknown environment: {requested}")


def print_experiment_info(experiment, cfg: DictConfig):
    """Print a banner summarising the main settings of the run.

    Args:
        experiment: Experiment object; when it has a ``config`` attribute,
            ``experiment.config.print_config()`` is called after the banner.
        cfg: Configuration providing ``env.name``, ``model.type``,
            ``env.dataset.path``, ``training.epochs``, ``training.batch_size``,
            ``training.learning_rate``, ``device`` and ``checkpoint.save_dir``.
    """
    rule = "=" * 60

    # (label, accessor) pairs. Each accessor is evaluated only when its line is
    # printed, so a missing key surfaces at the same point as with plain prints.
    summary = (
        ("Environment", lambda: cfg.env.name),
        ("Model Type", lambda: cfg.model.type),
        ("Dataset Path", lambda: cfg.env.dataset.path),
        ("Training Epochs", lambda: cfg.training.epochs),
        ("Batch Size", lambda: cfg.training.batch_size),
        ("Learning Rate", lambda: cfg.training.learning_rate),
        ("Device", lambda: cfg.device),
        ("Checkpoint Dir", lambda: cfg.checkpoint.save_dir),
    )

    print(rule)
    print("Experiment Configuration")
    print(rule)
    for label, read in summary:
        print(f"{label}: {read()}")
    print(rule)

    # Detailed configuration is optional: only experiments with a `config` have it.
    if not hasattr(experiment, 'config'):
        return
    experiment.config.print_config()


def run_training(experiment, cfg: DictConfig):
    """Run flow-matching training through the experiment object.

    Args:
        experiment: Experiment object exposing ``train_fm()``.
        cfg: Configuration; only ``env.name`` is read, for the log message.
    """
    env_label = cfg.env.name
    print(f"Starting training for {env_label} experiment...")

    # No dataset/model setup happens here (the explicit setup calls are
    # disabled); train_fm() is presumably responsible for it — TODO confirm.
    experiment.train_fm()

    print("Training completed successfully!")


def run_testing(experiment, cfg: DictConfig):
    """Load a checkpoint and evaluate the model in the configured environment.

    Args:
        experiment: Experiment object providing ``setup_model()`` and
            ``setup_environment()``.
        cfg: Configuration; reads ``checkpoint.load_checkpoint``, ``device``,
            ``env.name`` and the optional ``execution.testing`` keys
            ``start_seed``, ``episodes``, ``runs_per_episode`` and
            ``max_steps``.

    Raises:
        ValueError: If ``cfg.checkpoint.load_checkpoint`` is empty.
    """
    print(f"Starting testing for {cfg.env.name} experiment...")

    if not cfg.checkpoint.load_checkpoint:
        raise ValueError("Checkpoint path must be specified for testing")

    # Load model
    model = experiment.setup_model()
    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(cfg.checkpoint.load_checkpoint, map_location=cfg.device, weights_only=False)

    # Handle different checkpoint structures
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'model' in checkpoint:
        model.load_state_dict(checkpoint['model'])
    else:
        # Assume the checkpoint is the state dict itself
        model.load_state_dict(checkpoint)

    # --------------------------------------------------------
    # Disabled alternative loader, kept for reference.
    # # Inspect the checkpoint structure and adapt to it
    # if isinstance(checkpoint, dict):
    #     if 'noise_pred_net' in checkpoint:
    #         # standard format
    #         state_dict = checkpoint['noise_pred_net']
    #     elif 'model_state_dict' in checkpoint:
    #         # another common format
    #         state_dict = checkpoint['model_state_dict']
    #     else:
    #         # the dict is the weights themselves
    #         state_dict = checkpoint
    # else:
    #     # the checkpoint itself is the state_dict
    #     state_dict = checkpoint
    #
    # # Check whether a key prefix has to be added
    # if not any(k.startswith('noise_pred_net.') for k in state_dict.keys()):
    #     # add the noise_pred_net prefix
    #     state_dict = {f"noise_pred_net.{k}": v for k, v in state_dict.items()}
    #
    # model.load_state_dict(state_dict)
    # --------------------------------------------------------

    model.eval()

    # Setup environment
    env = experiment.setup_environment()

    # Get test configuration (each key falls back to a default when absent)
    testing_cfg = cfg.execution.testing
    test_start_seed = getattr(testing_cfg, 'start_seed', 1000)
    test_episodes = getattr(testing_cfg, 'episodes', 1)
    test_runs_per_episode = getattr(testing_cfg, 'runs_per_episode', 10)
    max_steps = getattr(testing_cfg, 'max_steps', 280)

    print(f"Test configuration: seed_start={test_start_seed}, episodes={test_episodes}, runs_per_episode={test_runs_per_episode}, max_steps={max_steps}")

    # Run testing based on environment type (name is matched case-insensitively)
    env_name = cfg.env.name.lower()

    if env_name == 'kitchen':
        success_rates, total_rewards_list, reward_rates = _test_kitchen(
            model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg
        )
    elif env_name == 'pusht':
        success_rates, total_rewards_list, reward_rates = _test_pusht(
            model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg
        )
    elif env_name == 'mimic':
        success_rates, total_rewards_list, reward_rates = _test_mimic(
            model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg
        )
    else:
        print(f"Unknown environment type: {env_name}, using generic testing")
        success_rates, total_rewards_list, reward_rates = _test_generic(
            model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg
        )

    # Print test results. Every metric is skipped when its list is empty so
    # that np.mean/np.std are never asked to reduce an empty sequence.
    print("\n" + "="*50)
    print("TEST RESULTS SUMMARY")
    print("="*50)
    if success_rates:
        print(f"Success Rate - Mean: {np.mean(success_rates):.2%}, Std: {np.std(success_rates):.2%}")
    if total_rewards_list:
        print(f"Total Rewards - Mean: {np.mean(total_rewards_list):.2f}, Std: {np.std(total_rewards_list):.2f}")
    if reward_rates:
        print(f"Reward Rate - Mean: {np.mean(reward_rates):.2%}, Std: {np.std(reward_rates):.2%}")

    print("Testing completed successfully!")


def run_mle_finetuning(experiment, cfg: DictConfig):
    """Run MLE fine-tuning, optionally starting from a saved checkpoint.

    Args:
        experiment: Experiment object providing ``setup_dataset()``,
            ``setup_model()`` and ``train_mle()``.
        cfg: Configuration; reads ``env.name``, ``checkpoint.load_checkpoint``
            and, when a checkpoint is given, ``device``.
    """
    print(f"Starting MLE fine-tuning for {cfg.env.name} experiment...")

    # Dataset first, then model; the dataset's return value is not used here.
    experiment.setup_dataset()
    net = experiment.setup_model()

    ckpt_path = cfg.checkpoint.load_checkpoint
    if not ckpt_path:
        print("No checkpoint specified, starting MLE training from scratch with new model")
    else:
        print(f"Loading checkpoint from: {ckpt_path}")
        payload = torch.load(ckpt_path, map_location=cfg.device, weights_only=False)

        # The weights may sit under one of two known keys; otherwise the
        # payload is taken to be the state dict itself.
        for key in ('model_state_dict', 'model'):
            if key in payload:
                weights = payload[key]
                break
        else:
            weights = payload
        net.load_state_dict(weights)

        print("Checkpoint loaded successfully!")

    # The actual training loop lives on the experiment object.
    experiment.train_mle()

    print("MLE fine-tuning completed successfully!")


def _test_kitchen(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the flow-matching policy in the Kitchen environment.

    Follows the flow_kitchen.py example. A run counts as a success when the
    sum of its rewards equals 4.

    Args:
        model: Mapping whose ``'noise_pred_net'`` entry is called as
            ``net(x, timestep, global_cond=...)``.
        env: Environment exposing ``seed()``, ``reset()`` and a ``step()``
            returning ``(obs, reward, done, info)``.
        test_start_seed: Seed of the first test episode; episode ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeded test episodes.
        test_runs_per_episode: Rollouts performed per test episode.
        max_steps: A rollout is stopped once its step count exceeds this value.
        cfg: Configuration; reads ``device`` and the optional ``env.dataset``
            keys ``obs_horizon``, ``pred_horizon``, ``action_horizon`` and
            ``action_dim``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)`` with one
        entry per test episode.
    """
    import numpy as np
    import collections
    from tqdm import tqdm

    # Horizons and action size, with defaults for keys missing from the config.
    obs_horizon = getattr(cfg.env.dataset, 'obs_horizon', 1)
    pred_horizon = getattr(cfg.env.dataset, 'pred_horizon', 16)
    action_horizon = getattr(cfg.env.dataset, 'action_horizon', 8)
    action_dim = getattr(cfg.env.dataset, 'action_dim', 9)

    print(f"Kitchen test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")

    success_rates, total_rewards_list, reward_rates = [], [], []

    for episode_idx in range(test_episodes):
        seed = test_start_seed + episode_idx
        env.seed(seed)

        solved_runs = 0
        finished_runs = 0
        reward_total = 0

        for run_idx in range(test_runs_per_episode):
            # Fill the observation history with copies of the first observation.
            history = collections.deque([env.reset()] * obs_horizon, maxlen=obs_horizon)
            rewards = []
            finished = False
            steps_taken = 0

            with tqdm(total=max_steps, desc=f"Eval Kitchen Epoch{episode_idx}-Run{run_idx}") as progress:
                while not finished:
                    stacked = torch.from_numpy(np.stack(list(history))).to(cfg.device, dtype=torch.float32)

                    with torch.no_grad():
                        # (obs_horizon, feature_dim) -> (1, obs_horizon * feature_dim)
                        obs_cond = stacked.unsqueeze(0).flatten(start_dim=1)

                        # Euler integration of the predicted velocity field,
                        # 16 steps as in the reference code.
                        net = model['noise_pred_net']
                        n_steps = 16
                        for k in range(n_steps):
                            # NOTE(review): noise is drawn on every step although
                            # only the k == 0 draw starts the trajectory; kept so
                            # the random stream is unchanged.
                            fresh_noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                            start_point = fresh_noise.expand(obs_cond.shape[0], -1, -1)
                            t = torch.tensor([k / n_steps]).to(cfg.device)

                            current = start_point if k == 0 else traj
                            velocity = net(current, t, global_cond=obs_cond)
                            traj = (velocity * 1 / n_steps + current)

                        plan = traj.detach().to('cpu').numpy()[0]

                        # Execute only the action_horizon slice of the plan.
                        first = obs_horizon - 1
                        chunk = plan[first:first + action_horizon, :]

                        for command in chunk:
                            obs, reward, finished, info = env.step(command)
                            history.append(obs)
                            rewards.append(reward)

                            steps_taken += 1
                            progress.update(1)
                            progress.set_postfix(reward=reward)

                            # Stop on the step budget or once all 4 reward units are collected.
                            if steps_taken > max_steps or sum(rewards) == 4:
                                finished = True
                            if finished:
                                break

            finished_runs += 1
            episode_reward = sum(rewards)
            reward_total += episode_reward
            if episode_reward == 4:
                solved_runs += 1
                print(f"Episode {seed}-{run_idx} succeeded! Total reward: {episode_reward}")
            else:
                print(f"Episode {seed}-{run_idx} failed. Total reward: {episode_reward}")

        # Per-episode metrics (same structure as the reference code).
        success_rates.append(solved_runs / finished_runs)
        total_rewards_list.append(reward_total)
        reward_rates.append(reward_total / (finished_runs * 4))

        print(f"\nTest {episode_idx+1} Summary:")
        print(f"Total Episodes: {finished_runs}")
        print(f"Success Count: {solved_runs}")
        print(f"Success Rate: {success_rates[-1]:.2%}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / finished_runs:.2f}")
        print(f"Reward Rate: {reward_rates[-1]:.2%}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_pusht(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the flow-matching policy in the PUSH-T environment.

    Follows the flow_pusht.py example. Observations are dicts with ``'image'``
    and ``'agent_pos'`` entries.

    Args:
        model: Mapping with a ``'noise_pred_net'`` entry called as
            ``net(x, timestep, global_cond=...)`` and, optionally, a
            ``'vision_encoder'`` entry applied to the stacked images.
        env: Environment exposing ``seed()``, a ``reset()`` returning
            ``(obs, info)`` and a ``step()`` returning a 5-tuple whose third
            element is used as the done flag.
        test_start_seed: Seed of the first test episode; episode ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeded test episodes.
        test_runs_per_episode: Rollouts performed per test episode.
        max_steps: A rollout is stopped once its step count exceeds this value.
        cfg: Configuration; reads ``device`` and the optional ``env.dataset``
            keys ``obs_horizon``, ``pred_horizon``, ``action_horizon`` and
            ``action_dim``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``. Only
        ``total_rewards_list`` is filled (one summed reward per test episode);
        the other two lists are returned empty.
    """
    import numpy as np
    import collections
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    # PUSH-T specific parameters
    obs_horizon = getattr(cfg.env.dataset, 'obs_horizon', 1)
    pred_horizon = getattr(cfg.env.dataset, 'pred_horizon', 16)
    action_horizon = getattr(cfg.env.dataset, 'action_horizon', 8)
    action_dim = getattr(cfg.env.dataset, 'action_dim', 2)

    print(f"PUSH-T test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            obs, _ = env.reset()

            # Fill the observation history with copies of the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            rewards = []
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval PUSH-T Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    x_img = np.stack([x['image'] for x in obs_deque])
                    x_pos = np.stack([x['agent_pos'] for x in obs_deque])

                    x_img = torch.from_numpy(x_img).to(cfg.device, dtype=torch.float32)
                    x_pos = torch.from_numpy(x_pos).to(cfg.device, dtype=torch.float32)

                    # Infer action using Flow Matching
                    with torch.no_grad():
                        if 'vision_encoder' in model:
                            # The observation history is the encoder batch, so
                            # any obs_horizon works (the previous
                            # squeeze/unsqueeze only handled obs_horizon == 1).
                            # Result for obs_horizon == 1 is unchanged.
                            # Assumes the network was trained on per-frame
                            # [image features, agent_pos] concatenated over the
                            # history — TODO confirm against the training code.
                            image_features = model['vision_encoder'](x_img)
                            obs_features = torch.cat([image_features, x_pos], dim=-1)
                            obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)
                        else:
                            # Fallback: flatten image directly (not recommended)
                            x_img_batch = x_img.unsqueeze(0)  # Add batch dimension
                            x_pos_batch = x_pos.unsqueeze(0)  # Add batch dimension
                            obs_cond = torch.cat([x_img_batch.flatten(start_dim=1), x_pos_batch.flatten(start_dim=1)], dim=-1)

                        timehorizon = 1  # PUSH-T uses single step prediction
                        for i in range(timehorizon):
                            noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                            x0 = noise.expand(obs_cond.shape[0], -1, -1)
                            timestep = torch.tensor([i / timehorizon]).to(cfg.device)

                            # Euler step: start from noise, then continue from
                            # the running trajectory.
                            current = x0 if i == 0 else traj
                            vt = model['noise_pred_net'](current, timestep, global_cond=obs_cond)
                            traj = (vt * 1 / timehorizon + current)

                        action_pred = traj.detach().to('cpu').numpy()[0]

                        # Take action_horizon number of actions
                        start = obs_horizon - 1
                        end = start + action_horizon
                        action = action_pred[start:end, :]

                        for command in action:
                            obs, reward, done, _, info = env.step(command)
                            obs_deque.append(obs)
                            rewards.append(reward)

                            step_idx += 1
                            pbar.update(1)
                            pbar.set_postfix(reward=reward)

                            if step_idx > max_steps:
                                done = True
                            if done:
                                break

            episode_reward = sum(rewards)
            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        # Record metrics for this test
        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_mimic(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Evaluate the flow-matching policy in the Mimic environment.

    Follows the flow_mimic.py example. A rollout stops when a step returns a
    reward equal to 1 or the step budget is exceeded.

    Args:
        model: Mapping whose ``'noise_pred_net'`` entry is called as
            ``net(x, timestep, global_cond=...)``.
        env: Environment exposing ``seed()``, ``reset()`` and a ``step()``
            returning ``(obs, reward, done, info)``.
        test_start_seed: Seed of the first test episode; episode ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeded test episodes.
        test_runs_per_episode: Rollouts performed per test episode.
        max_steps: A rollout is stopped once its step count exceeds this value.
        cfg: Configuration; reads ``device`` and the optional ``env.dataset``
            keys ``obs_horizon``, ``pred_horizon``, ``action_horizon`` and
            ``action_dim``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``. Only
        ``total_rewards_list`` is filled (one summed reward per test episode);
        the other two lists are returned empty.
    """
    import numpy as np
    import collections
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    # Mimic specific parameters
    obs_horizon = getattr(cfg.env.dataset, 'obs_horizon', 1)
    pred_horizon = getattr(cfg.env.dataset, 'pred_horizon', 16)
    action_horizon = getattr(cfg.env.dataset, 'action_horizon', 8)
    action_dim = getattr(cfg.env.dataset, 'action_dim', 20)

    def undo_transform_action(action):
        """Convert a 20-dim action to a 14-dim action for dual arm robots.

        The action is viewed as (batch, 2 arms, 10 values per arm) and the
        first 7 values of each arm in the first batch row are concatenated.
        Actions whose last dimension is not 20 are returned unchanged.

        NOTE(review): the original comments describe each arm's 10 values as
        position (3), rotation (6) and gripper (1) and mark this truncation
        as a temporary mapping — the exact conversion still needs confirming.
        """
        if action.shape[-1] != 20:
            return action

        arms = action.reshape(-1, 2, 10)
        return np.concatenate([arms[0, 0, :7], arms[0, 1, :7]])

    print(f"Mimic test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            obs = env.reset()

            # Fill the observation history with copies of the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            rewards = []
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval Mimic Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    x_img = np.stack(list(obs_deque))
                    x_img = torch.from_numpy(x_img).to(cfg.device, dtype=torch.float32)

                    # Infer action using Flow Matching
                    with torch.no_grad():
                        # Same conditioning layout as the Kitchen test: add a
                        # batch dimension and flatten the whole history into
                        # one vector, (1, obs_horizon * feature size). Without
                        # it the history axis was treated as the batch.
                        # Unchanged for obs_horizon == 1.
                        obs_cond = x_img.unsqueeze(0).flatten(start_dim=1)

                        timehorizon = 1  # Mimic uses single step prediction
                        for i in range(timehorizon):
                            noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                            x0 = noise.expand(obs_cond.shape[0], -1, -1)
                            timestep = torch.tensor([i / timehorizon]).to(cfg.device)

                            # Euler step: start from noise, then continue from
                            # the running trajectory.
                            current = x0 if i == 0 else traj
                            vt = model['noise_pred_net'](current, timestep, global_cond=obs_cond)
                            traj = (vt * 1 / timehorizon + current)

                        action_pred = traj.detach().to('cpu').numpy()[0]

                        # Take action_horizon number of actions
                        start = obs_horizon - 1
                        end = start + action_horizon
                        action = action_pred[start:end, :]

                        for command in action:
                            # Convert 20-dim action to 14-dim action for dual arm robots
                            env_action = undo_transform_action(command)
                            obs, reward, done, info = env.step(env_action)
                            obs_deque.append(obs)
                            rewards.append(reward)

                            step_idx += 1
                            pbar.update(1)
                            pbar.set_postfix(reward=reward)

                            if step_idx > max_steps or reward == 1:
                                done = True
                            if done:
                                break

            episode_reward = sum(rewards)
            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        # Record metrics for this test
        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


def _test_generic(model, env, test_start_seed, test_episodes, test_runs_per_episode, max_steps, cfg):
    """Generic evaluation loop for environments without a dedicated tester.

    Accepts both environment API styles: ``reset()`` may return either the
    observation or an ``(obs, info)`` pair, and ``step()`` may return either
    ``(obs, reward, done, info)`` or
    ``(obs, reward, terminated, truncated, info)``. ``step()`` is called
    exactly once per action.

    Args:
        model: Callable invoked as ``model(x, timestep, global_cond=...)``.
        env: Environment with ``reset()`` and ``step()``; ``seed()`` is called
            only when the environment has it.
        test_start_seed: Seed of the first test episode; episode ``k`` uses
            ``test_start_seed + k``.
        test_episodes: Number of seeded test episodes.
        test_runs_per_episode: Rollouts performed per test episode.
        max_steps: A rollout is stopped once its step count exceeds this value.
        cfg: Configuration; reads ``device`` and the optional top-level keys
            ``obs_horizon``, ``pred_horizon``, ``action_horizon`` and
            ``action_dim``.

    Returns:
        Tuple ``(success_rates, total_rewards_list, reward_rates)``. Only
        ``total_rewards_list`` is filled (one summed reward per test episode);
        the other two lists are returned empty.
    """
    import numpy as np
    import collections
    from tqdm import tqdm

    success_rates = []
    total_rewards_list = []
    reward_rates = []

    # Generic parameters (read from the top level of cfg)
    obs_horizon = getattr(cfg, 'obs_horizon', 1)
    pred_horizon = getattr(cfg, 'pred_horizon', 16)
    action_horizon = getattr(cfg, 'action_horizon', 8)
    action_dim = getattr(cfg, 'action_dim', 9)

    print(f"Generic test parameters: obs_horizon={obs_horizon}, pred_horizon={pred_horizon}, action_horizon={action_horizon}, action_dim={action_dim}")

    for epoch in range(test_episodes):
        seed = test_start_seed + epoch
        if hasattr(env, 'seed'):
            env.seed(seed)

        total_rewards = 0

        for pp in range(test_runs_per_episode):
            # reset() may return obs alone or an (obs, info) pair.
            reset_result = env.reset()
            if (
                isinstance(reset_result, tuple)
                and len(reset_result) == 2
                and isinstance(reset_result[1], dict)
            ):
                obs = reset_result[0]
            else:
                obs = reset_result

            # Fill the observation history with copies of the first observation.
            obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
            rewards = []
            done = False
            step_idx = 0

            with tqdm(total=max_steps, desc=f"Eval Generic Epoch{epoch}-Run{pp}") as pbar:
                while not done:
                    x_img = np.stack(list(obs_deque))
                    x_img = torch.from_numpy(x_img).to(cfg.device, dtype=torch.float32)

                    # Infer action using Flow Matching
                    with torch.no_grad():
                        # Same conditioning layout as the Kitchen test: add a
                        # batch dimension and flatten the whole history into
                        # one vector. Unchanged for obs_horizon == 1.
                        obs_cond = x_img.unsqueeze(0).flatten(start_dim=1)

                        timehorizon = 1
                        for i in range(timehorizon):
                            noise = torch.rand(1, pred_horizon, action_dim).to(cfg.device)
                            x0 = noise.expand(obs_cond.shape[0], -1, -1)
                            timestep = torch.tensor([i / timehorizon]).to(cfg.device)

                            # Euler step: start from noise, then continue from
                            # the running trajectory.
                            current = x0 if i == 0 else traj
                            vt = model(current, timestep, global_cond=obs_cond)
                            traj = (vt * 1 / timehorizon + current)

                        action_pred = traj.detach().to('cpu').numpy()[0]

                        # Take action_horizon number of actions
                        start = obs_horizon - 1
                        end = start + action_horizon
                        action = action_pred[start:end, :]

                        for command in action:
                            # Step once and unpack by length. Retrying step()
                            # after a failed unpack would advance the
                            # environment twice for a single action.
                            step_result = env.step(command)
                            if len(step_result) == 5:
                                obs, reward, terminated, truncated, info = step_result
                                done = bool(terminated) or bool(truncated)
                            else:
                                obs, reward, done, info = step_result

                            obs_deque.append(obs)
                            rewards.append(reward)

                            step_idx += 1
                            pbar.update(1)
                            pbar.set_postfix(reward=reward)

                            if step_idx > max_steps:
                                done = True
                            if done:
                                break

            episode_reward = sum(rewards)
            total_rewards += episode_reward
            print(f"Episode {seed}-{pp} completed. Total reward: {episode_reward}")

        # Record metrics for this test
        total_rewards_list.append(total_rewards)
        print(f"\nTest {epoch+1} Summary:")
        print(f"Total Episodes: {test_runs_per_episode}")
        print(f"Total Rewards: {total_rewards_list[-1]}")
        print(f"Average Reward per Episode: {total_rewards_list[-1] / test_runs_per_episode:.2f}\n")

    return success_rates, total_rewards_list, reward_rates


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Hydra entry point: build the Lightning module and trainer, then run.

    ``cfg.execution.mode`` selects the action: ``"fm_train"``,
    ``"mle_finetune"`` and ``"res_finetune"`` call ``trainer.fit`` (resuming
    from ``cfg.checkpoint.load_checkpoint`` when it is set), while ``"test"``
    calls ``trainer.test`` and requires a checkpoint.

    Args:
        cfg: Configuration composed by Hydra from ``conf/config``.

    Raises:
        ValueError: If the mode is unknown, or if the mode is ``"test"`` and
            no checkpoint path is configured.
    """
    print("Starting Flow Matching experiment with Hydra configuration...")

    set_random_seed(cfg.seed)

    # The Lightning module is imported at the point of use.
    from utils.lightning_module import create_lightning_module
    model = create_lightning_module(cfg)

    print_experiment_info(model.experiment, cfg)

    print(f"Starting experiment for environment: {cfg.env.name}")

    # Only TensorBoard logging is active; a Wandb logger is not configured here.
    tensorboard_logger = pl.loggers.TensorBoardLogger(save_dir="logs/")

    # Keep every checkpoint (save_top_k=-1), written every `save_interval` epochs.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filename=f"{cfg.env.name}-{{epoch}}",
        save_top_k=-1,
        every_n_epochs=cfg.training.save_interval,
    )

    gpu_count = cfg.training.get("gpus", 1)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=gpu_count,
        # DDP only when more than one GPU is requested.
        strategy="ddp" if gpu_count > 1 else "auto",
        max_epochs=cfg.training.epochs,
        logger=[tensorboard_logger],
        log_every_n_steps=10,
        callbacks=[checkpoint_callback],
    )

    mode = cfg.execution.mode
    ckpt_path = cfg.checkpoint.load_checkpoint
    training_modes = ("fm_train", "mle_finetune", "res_finetune")

    if mode in training_modes:
        if ckpt_path:
            print(f"Resuming training from checkpoint: {ckpt_path}")
            trainer.fit(model, ckpt_path=ckpt_path)
        else:
            print("Starting training from scratch.")
            trainer.fit(model)
    elif mode == "test":
        if not ckpt_path:
            raise ValueError("Checkpoint path must be specified for testing.")
        print(f"Starting testing from checkpoint: {ckpt_path}")
        trainer.test(model, ckpt_path=ckpt_path)
    else:
        raise ValueError(f"Unknown execution mode: {mode}")

    print("Experiment completed successfully!")



# Script entry point. main() is called without arguments because the
# @hydra.main decorator supplies the cfg argument.
if __name__ == "__main__":
    main()
