from dataclasses import dataclass
import numpy as np, time, os, mujoco as mj
import gymnasium as gym
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.envs.mujoco.ant_v4 import AntEnv
from gymnasium.spaces import Box
from gymnasium import utils
from warnings import filterwarnings
from utils import *
from dotenv import load_dotenv
from gymnasium.envs.registration import register
from gymnasium import spaces
from typing import List, Union, Dict


# Pull variables from a .env file (if one is present) into the process environment, and
# silence the numpy `np.bool8` deprecation warning emitted by dependencies.
load_dotenv()
filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    message="`np.bool8` is a deprecated alias for `np.bool_`",
)


# Directory holding custom MuJoCo XML assets (e.g. ant_maze.xml); None when the variable is unset.
BASE_XML_DIR = os.environ.get('BASE_XML_DIR')


# General - Observation is raw vector
class MPEnv(gym.Wrapper):
    """
    Environment wrapper that uses manual indices as actions and runs them through the policy network.

    The outer agent chooses a discrete index in ``[0, num_actions)``. On every step the wrapper feeds
    the latest raw observation of the wrapped env, together with that index (as ``manual_indices``),
    to ``policy_network`` and executes the resulting continuous action in the wrapped env. The
    observation space is left unchanged.

    Args:
        env: environment to wrap; its observations are passed to ``policy_network`` unchanged.
        policy_network: callable invoked as
            ``policy_network(obs, deterministic=..., bias_config=..., manual_indices=...)`` whose
            first return value is a batched action tensor (batch size 1 here).
        device: torch device that the observation and index tensors are moved to.
        num_actions: size of the discrete action space exposed to the outer agent.
        std: passed to the policy as ``bias_config={"std": std}`` when ``hard`` is falsy.
        hard: if truthy, query the policy with ``deterministic=True`` and ``bias_config=None``.
    """

    def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False):
        # gym.Wrapper.__init__ must run first: it stores `env` and initialises the private state
        # (cached spec, space/metadata overrides) that the Wrapper properties rely on.
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        # Discrete action space of manual indices into the policy.
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.std = std
        self.hard = hard
        logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        # Latest raw observation; conditions the policy on the next step. Set by reset()/step().
        self.obs = None

    def _policy_action(self, action):
        """Return the continuous action (first batch row) the policy produces for index ``action``."""
        if self.obs is None:
            raise RuntimeError("step() called before reset(): no observation to condition the policy on")
        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs, axis=0),
            dtype=torch.float32,
        ).to(self.device)
        with torch.no_grad():
            a, *_ = self.policy_network(
                obs_tensor,
                deterministic=bool(self.hard),
                bias_config=None if self.hard else {"std": self.std},
                manual_indices=torch.tensor([action]).to(self.device),
            )
        return a.cpu().numpy()[0]

    def step(self, action):
        """
        Takes a discrete action (manual index), runs it through the policy network,
        and executes the resulting continuous action in the environment.

        Returns the wrapped env's ``(observation, reward, terminated, truncated, info)`` unchanged.
        """
        self.steps += 1
        continuous_action = self._policy_action(action)
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action)
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment, the step counter and the cached observation."""
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info

# General - Observation is dict
class MPEnvDict(gym.Wrapper):
    """
    Environment wrapper that uses manual indices as actions and runs them through the policy network,
    for wrapped envs whose observations are dicts.

    Only ``obs['observation']`` is fed to ``policy_network``; the full dict observation is returned
    to the outer agent unchanged. Rewards are multiplied by ``rewards_scale``, and every reset uses
    ``self.options`` unless the caller passes explicit (non-None) options.

    Args:
        env: environment to wrap; must return dict observations with an ``'observation'`` key.
        policy_network: callable invoked as
            ``policy_network(obs, deterministic=..., bias_config=..., manual_indices=...)`` whose
            first return value is a batched action tensor (batch size 1 here).
        device: torch device that the observation and index tensors are moved to.
        num_actions: size of the discrete action space exposed to the outer agent.
        std: passed to the policy as ``bias_config={"std": std}`` when ``hard`` is falsy.
        hard: if truthy, query the policy with ``deterministic=True`` and ``bias_config=None``.
        options: default ``options`` forwarded to the wrapped env's ``reset``.
        rewards_scale: multiplier applied to every reward.
        reward_threshold: value reported as ``spec.reward_threshold`` when the wrapped env's spec
            does not define one.
    """

    class _EnhancedSpec:
        """Proxy over an env spec that reports a substitute ``reward_threshold``."""

        def __init__(self, original_spec, reward_threshold):
            self._original_spec = original_spec
            self._reward_threshold = reward_threshold

        @property
        def reward_threshold(self):
            return self._reward_threshold

        def __getattr__(self, name):
            # Only reached for attributes not found normally. Private/dunder names are refused so
            # that copy/deepcopy probing a half-built instance (no `_original_spec` yet) gets an
            # AttributeError instead of recursing forever.
            if name.startswith("_"):
                raise AttributeError(name)
            return getattr(self._original_spec, name)

    def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False, options=None,
                 rewards_scale=1.0, reward_threshold=100.0):
        super().__init__(env)  # Call the parent class constructor first
        self.policy_network = policy_network
        self.device = device
        # Discrete action space of manual indices into the policy.
        self.action_space = spaces.Discrete(num_actions)
        self.steps = 0
        self.std = std
        self.hard = hard
        logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        # Latest dict observation; its 'observation' entry conditions the policy on the next step.
        self.obs = None
        self.options = options
        self.rewards_scale = rewards_scale
        self._reward_threshold = reward_threshold

    @property
    def spec(self):
        """
        The wrapped env's spec, with ``reward_threshold`` filled in if it is missing.

        Returns None when the wrapped env has no spec. The spec object itself cannot be modified,
        so a missing threshold is supplied through a proxy that forwards every other public
        attribute to the original spec.
        """
        spec = super().spec
        if spec is None:
            return None
        if spec.reward_threshold is None:
            return self._EnhancedSpec(spec, self._reward_threshold)
        return spec

    def step(self, action):
        """
        Takes a discrete action (manual index), runs it through the policy network,
        and executes the resulting continuous action in the environment.

        Returns the wrapped env's ``(observation, reward, terminated, truncated, info)`` with the
        reward multiplied by ``rewards_scale``.
        """
        self.steps += 1
        if self.obs is None:
            raise RuntimeError("step() called before reset(): no observation to condition the policy on")
        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs['observation'], axis=0),
            dtype=torch.float32,
        ).to(self.device)
        with torch.no_grad():
            a, *_ = self.policy_network(
                obs_tensor,
                deterministic=bool(self.hard),
                bias_config=None if self.hard else {"std": self.std},
                manual_indices=torch.tensor([action]).to(self.device),
            )

        continuous_action = a.cpu().numpy()
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        self.obs = observation
        rewards *= self.rewards_scale
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the wrapped environment with the configured options.

        ``self.options`` is used unless the caller supplies non-None ``options``. Popping the key
        first avoids a duplicate-keyword TypeError when an outer wrapper forwards ``options=``
        explicitly (e.g. ``options=None``).
        """
        options = kwargs.pop("options", None)
        if options is None:
            options = self.options
        self.steps = 0
        observation, info = self.env.reset(options=options, **kwargs)
        self.obs = observation
        return observation, info

  
class NoCtrlCostWrapper(gym.Wrapper):
    """
    Replace the wrapped env's reward with ``info['reward_forward']``.

    Observation, termination flags and info are passed through untouched. The entire base reward
    is discarded, so every term besides the forward reward (not only a control cost) is dropped.
    The wrapped env must report ``'reward_forward'`` in its step info, otherwise KeyError is raised.
    """

    def step(self, action):
        obs, _base_reward, terminated, truncated, info = self.env.step(action)
        forward_reward = info['reward_forward']
        return obs, forward_reward, terminated, truncated, info


def make_no_ctrl_env(env_id: str, **kwargs) -> gym.Env:
    """
    Gym entry-point: build ``env_id`` with ``gym.make`` and wrap it in ``NoCtrlCostWrapper``.

    Args:
        env_id: registered Gymnasium ID of the base environment. Its ``step`` must put
            ``'reward_forward'`` in the info dict, which becomes the reward.
        **kwargs: forwarded unchanged to ``gym.make``.

    Returns:
        The wrapped environment.
    """
    return NoCtrlCostWrapper(gym.make(env_id, **kwargs))


class AntMaze(gym.Wrapper):
    """
    ``AntMaze_UMaze-v5`` on a fixed 5x9 maze layout with a fixed goal and a dense progress reward.

    * The base env is built from ``{BASE_XML_DIR}/ant_maze.xml`` with the layout below.
    * Every reset uses ``options={"goal_cell": goal_cell}``; caller-supplied ``options`` are ignored.
    * The base reward is replaced by ``(previous_distance - current_distance) * rewards_scale``,
      where distances are Euclidean between ``obs["achieved_goal"]`` and ``obs["desired_goal"]``.
    * When ``sensordata[0]`` of the innermost env's MuJoCo data is positive, both ``terminated``
      and ``truncated`` are set to True. NOTE(review): presumably a contact sensor declared in
      ant_maze.xml — confirm.

    Args:
        goal_cell: goal cell handed to the maze env, presumably ``[row, col]`` (the default
            ``(2, 4)`` is the only open cell of row 2 in the layout). Copied into a fresh
            ndarray; ``None`` is passed through unchanged.
        rewards_scale: multiplier applied to the distance-progress reward.
        **kwargs: forwarded to ``gym.make``.
    """

    metadata = {
        "render_modes": ["human", "rgb_array"],
    }

    def __init__(
        self,
        goal_cell: Union[np.ndarray, tuple, None] = (2, 4),
        rewards_scale: float = 1.0,
        **kwargs
    ):
        # Immutable default + per-instance copy: an ndarray default would be shared (and could be
        # mutated) across every AntMaze instance.
        if goal_cell is not None:
            goal_cell = np.array(goal_cell)
        # store goal for reset and reward scaling
        self.options = {"goal_cell": goal_cell}
        self.rewards_scale = rewards_scale
        # Goal distance after the previous reset/step; the reward is its decrease.
        self.prev_dist_to_goal = None

        # exact maze layout (1 = wall, 0 = free)
        maze_map = [
            [1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 1, 1, 1, 0, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 1, 1, 1, 1, 1, 1, 1, 1]
        ]

        base_env = gym.make(
            "AntMaze_UMaze-v5",
            xml_file=f"{BASE_XML_DIR}/ant_maze.xml",
            maze_map=maze_map,
            **kwargs
        )
        super().__init__(base_env)

    def reset(self, **kwargs):
        """Reset with the fixed goal options (any caller ``options`` are dropped) and record the goal distance."""
        kwargs.pop("options", None)
        obs, info = self.env.reset(options=self.options, **kwargs)
        self.prev_dist_to_goal = np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"])
        return obs, info

    def step(self, action):
        """Step the maze; return the distance-progress reward and end the episode on sensor contact."""
        obs, _, terminated, truncated, info = super().step(action)

        curr_dist = np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"])
        reward = (self.prev_dist_to_goal - curr_dist) * self.rewards_scale

        # `unwrapped` reaches the innermost env no matter how many wrappers gym.make added
        # (env checker, order enforcing, time limit, render collection, ...).
        contact = self.env.unwrapped.data.sensordata[0]
        if contact > 0.0:
            terminated = True
            truncated = True

        self.prev_dist_to_goal = curr_dist
        return obs, reward, terminated, truncated, info

# Modified MPEnv - Raw observation for policy, action history for outer usage
class MPEnvSequential(gym.Wrapper):
    """
    Environment wrapper that uses manual indices as actions and runs them through the policy network.

    The wrapped env's raw observation is kept internally to condition ``policy_network``. The
    observation returned to the outer agent is instead the history of the last ``history_length``
    discrete actions (oldest first, padded with -1 after a reset), as a float32 vector matching
    ``observation_space``. The same history is also put, as a list, in ``info['action_history']``.

    Args:
        env: environment to wrap; its raw observations are passed to ``policy_network``.
        policy_network: callable invoked as
            ``policy_network(obs, deterministic=..., bias_config=..., manual_indices=...)`` whose
            first return value is a batched action tensor (batch size 1 here).
        device: torch device that the observation and index tensors are moved to.
        num_actions: size of the discrete action space exposed to the outer agent.
        std: passed to the policy as ``bias_config={"std": std}`` when ``hard`` is falsy.
        hard: if truthy, query the policy with ``deterministic=True`` and ``bias_config=None``.
        history_length: number of past actions kept in the observation (must be >= 1).

    Raises:
        ValueError: if ``history_length`` < 1.
    """

    def __init__(self, env, policy_network, device, num_actions=7, std=0.2, hard=False, history_length=10):
        if history_length < 1:
            raise ValueError(f"history_length must be >= 1, got {history_length}")
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        # Discrete action space of manual indices into the policy.
        self.action_space = spaces.Discrete(num_actions)
        # History entries are -1 (padding) or an action index in [0, num_actions).
        self.observation_space = Box(
            low=-1,
            high=num_actions,
            shape=(history_length,),
            dtype=np.float32,
        )
        self.metadata = env.metadata
        self.steps = 0
        self.std = std
        self.hard = hard
        logger.info(f"{GREEN} hard: {self.hard}, std: {self.std} {ENDC}")
        # Latest raw observation of the wrapped env; conditions the policy on the next step.
        self.obs = None
        self.history_length = history_length
        self.action_history = [-1] * self.history_length

    def _history_observation(self):
        """Return the action history as a float32 array, matching ``observation_space``'s dtype."""
        return np.asarray(self.action_history, dtype=np.float32)

    def step(self, action):
        """
        Takes a discrete action (manual index), runs it through the policy network,
        and executes the resulting continuous action in the environment.

        Returns ``(history_obs, reward, terminated, truncated, info)`` where ``history_obs`` is the
        updated action history and ``info['action_history']`` holds a list copy of it.
        """
        if self.obs is None:
            raise RuntimeError("step() called before reset(): no observation to condition the policy on")
        self.steps += 1

        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs, axis=0),
            dtype=torch.float32,
        ).to(self.device)
        with torch.no_grad():
            a, *_ = self.policy_network(
                obs_tensor,
                deterministic=bool(self.hard),
                bias_config=None if self.hard else {"std": self.std},
                manual_indices=torch.tensor([action]).to(self.device),
            )

        # Slide the window: drop the oldest action and append the current one.
        self.action_history.pop(0)
        self.action_history.append(int(action))

        continuous_action = a.cpu().numpy()
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action[0])
        self.obs = observation

        info['action_history'] = self.action_history.copy()
        return self._history_observation(), rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment and action history
        """
        self.steps = 0
        self.action_history = [-1] * self.history_length
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        info['action_history'] = self.action_history.copy()
        return self._history_observation(), info


class MPEnvContFull(gym.Wrapper):
    """
    Environment wrapper that uses continuous mixing weights as actions.

    The outer agent (e.g. SAC) outputs one weight per policy component. The weights are made
    non-negative and optionally normalised (see ``_process_mixing_weights``), then used to average
    the component means reported by ``policy_network``. The averaged pre-activation is squashed with
    ``tanh`` and scaled by ``policy_network.act_limit`` before being executed in the wrapped env.

    Args:
        env: environment to wrap; its raw observation conditions ``policy_network``.
        policy_network: callable returning a 7-tuple whose last element is a dict with
            ``'all_mus'`` and ``'all_stds'`` (assumed shape [1, K, D] — TODO confirm), and exposing
            ``act_limit``.
        device: torch device for observation and weight tensors.
        n_components: length of the mixing-weight action vector; must equal the policy's K.
        std: forwarded as ``bias_config={"std": std}`` in stochastic mode.
        hard: requested determinism. NOTE(review): currently overridden — ``self.hard`` is always
            True, so the stochastic sampling branch is never used and ``std`` has no effect.
        normalize_weights: if True, weights are normalised to sum to 1.
        min_weight: floor applied (followed by renormalisation) when ``normalize_weights`` is True.
    """

    def __init__(self, env, policy_network, device, n_components=5, std=0.2, hard=False,
                 normalize_weights=True, min_weight=0.01):
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        self.n_components = n_components
        self.std = std
        self.normalize_weights = normalize_weights
        self.min_weight = min_weight

        # Continuous action space of mixing weights in [0, 1]^n_components.
        self.action_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(n_components,),
            dtype=np.float32
        )

        # Keep original observation space
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        # NOTE(review): `hard` is forced to True whatever the caller passed. Kept as-is because
        # honouring the argument would switch the default (hard=False) to stochastic sampling.
        self.hard = True
        self.steps = 0
        self.obs = None

        print(f"MPEnvContFull initialized: n_components={n_components}, std={std}, "
              f"hard={self.hard} (requested hard={hard})")

    def _process_mixing_weights(self, raw_weights):
        """
        Turn raw mixing weights from the outer agent into the weight vector used for averaging.

        Accepts a numpy array, torch tensor or any array-like. Weights are made non-negative with
        ``abs``. They are then either normalised to sum to 1 (with an optional ``min_weight`` floor
        followed by renormalisation) or, when ``normalize_weights`` is False, clamped to [0, 1].
        """
        # as_tensor handles ndarrays, tensors and plain sequences in one call.
        weights = torch.as_tensor(raw_weights, dtype=torch.float32, device=self.device)
        weights = torch.abs(weights)

        if not self.normalize_weights:
            return torch.clamp(weights, 0.0, 1.0)

        # The epsilon prevents division by zero for an all-zero vector (which then stays all zero).
        weights = weights / (weights.sum() + 1e-8)
        if self.min_weight > 0:
            # Floor every component so none is ignored entirely, then renormalise.
            weights = torch.clamp(weights, min=self.min_weight)
            weights = weights / weights.sum()
        return weights

    def _compute_weighted_action(self, obs_tensor, mixing_weights):
        """
        Compute the weighted-average action from the mixing weights and the component means.

        Returns ``(continuous_action, debug_info)``. ``continuous_action`` is
        ``tanh(pi_action) * act_limit``.

        Raises:
            ValueError: if the number of weights differs from the number of policy components.
        """
        with torch.no_grad():
            # Get all component parameters from policy network
            _, _, _, _, _, _, info = self.policy_network(
                obs_tensor,
                deterministic=self.hard,
                with_logprob=False,
                bias_config=None if self.hard else {"std": self.std}
            )

            # Drop the batch dimension: [B, K, D] -> [K, D] for a single observation.
            all_mus = info['all_mus'].squeeze(0)
            all_stds = info['all_stds'].squeeze(0)
            if mixing_weights.shape[0] != all_mus.shape[0]:
                raise ValueError(
                    f"Got {mixing_weights.shape[0]} mixing weights but the policy reported "
                    f"{all_mus.shape[0]} components"
                )

            # mixing_weights: [K], all_mus: [K, D] -> weighted_mu: [D]
            weighted_mu = torch.sum(mixing_weights.unsqueeze(-1) * all_mus, dim=0)

            if not self.hard:
                # Mix variances with the same weights and sample from the resulting Gaussian.
                weighted_std = torch.sqrt(
                    torch.sum(mixing_weights.unsqueeze(-1) * (all_stds ** 2), dim=0)
                )
                from torch.distributions import Normal
                pi_action = Normal(weighted_mu, weighted_std).rsample()
            else:
                pi_action = weighted_mu

            # Apply tanh + scaling (same as original policy)
            continuous_action = torch.tanh(pi_action) * self.policy_network.act_limit

            return continuous_action, {
                'all_mus': all_mus,
                'all_stds': all_stds,
                'weighted_mu': weighted_mu,
                'mixing_weights': mixing_weights,
                'pi_action': pi_action
            }

    def step(self, action):
        """
        Take continuous mixing weights, compute the weighted action and execute it in the environment.

        Args:
            action: array-like of shape (n_components,) holding raw mixing weights.

        Returns:
            The wrapped env's step tuple; ``info`` additionally carries ``'action_info'``,
            ``'raw_mixing_weights'`` and ``'processed_mixing_weights'``.
        """
        if self.obs is None:
            raise RuntimeError("step() called before reset(): no observation to condition the policy on")
        self.steps += 1

        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs, axis=0),
            dtype=torch.float32
        ).to(self.device)

        mixing_weights = self._process_mixing_weights(action)
        continuous_action, action_info = self._compute_weighted_action(obs_tensor, mixing_weights)

        observation, rewards, terminated, truncated, info = self.env.step(continuous_action.cpu().numpy())

        # Expose the mixing details for debugging/logging.
        info['action_info'] = action_info
        info['raw_mixing_weights'] = action
        info['processed_mixing_weights'] = mixing_weights.cpu().numpy()

        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info
   

class MPEnvCont(gym.Wrapper):
    """
    Environment wrapper that uses a continuous vector of per-component weights as actions.

    The weights (one per policy component, action space [-1, 1]) are used as given, with no
    normalisation, to form a linear combination of the component means reported by
    ``policy_network``. The combination is squashed with ``tanh`` and scaled by
    ``policy_network.act_limit`` before being executed in the wrapped env.

    Args:
        env: environment to wrap; its raw observation conditions ``policy_network``.
        policy_network: callable returning a 7-tuple whose last element is a dict with
            ``'all_mus'`` (assumed shape [1, K, D] — TODO confirm), and exposing ``act_limit``.
        device: torch device for observation and weight tensors.
        n_components: length of the weight action vector; must equal the policy's K.
        hard: forwarded as ``deterministic=`` when querying the policy. Only the component means
            are used afterwards.
    """

    def __init__(self, env, policy_network, device, n_components=64, hard=True):
        super().__init__(env)
        self.policy_network = policy_network
        self.device = device
        self.n_components = n_components
        # Read by `_compute_weighted_action`; must exist before the first step.
        self.hard = hard

        # Continuous action space of per-component weights in [-1, 1]^n_components.
        self.action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(n_components,),
            dtype=np.float32
        )

        # Keep original observation space
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.obs = None

        print(f"MPEnvCont initialized: n_components={n_components}, hard={hard}")

    def _process_mixing_weights(self, raw_weights):
        """Convert raw weights (ndarray, tensor or array-like) to a float32 tensor on ``self.device``, unchanged in value."""
        return torch.as_tensor(raw_weights, dtype=torch.float32, device=self.device)

    def _compute_weighted_action(self, obs_tensor, mixing_weights):
        """
        Compute the weighted combination of component means and squash it into an action.

        Returns ``(continuous_action, debug_info)``.

        Raises:
            ValueError: if the number of weights differs from the number of policy components.
        """
        with torch.no_grad():
            # Get all component parameters from policy network
            _, _, _, _, _, _, info = self.policy_network(
                obs_tensor,
                deterministic=self.hard,
                with_logprob=False,
                bias_config=None
            )

            # Drop the batch dimension: [B, K, D] -> [K, D] for a single observation.
            all_mus = info['all_mus'].squeeze(0)
            if mixing_weights.shape[0] != all_mus.shape[0]:
                raise ValueError(
                    f"Got {mixing_weights.shape[0]} mixing weights but the policy reported "
                    f"{all_mus.shape[0]} components"
                )

            # mixing_weights: [K], all_mus: [K, D] -> weighted_mu: [D]
            weighted_mu = torch.sum(mixing_weights.unsqueeze(-1) * all_mus, dim=0)
            pi_action = weighted_mu

            # Apply tanh + scaling (same as original policy)
            continuous_action = torch.tanh(pi_action) * self.policy_network.act_limit

            return continuous_action, {
                'all_mus': all_mus,
                'weighted_mu': weighted_mu,
                'mixing_weights': mixing_weights,
                'pi_action': pi_action
            }

    def step(self, action):
        """
        Take continuous weights, compute the combined action and execute it in the environment.

        Args:
            action: array-like of shape (n_components,) holding per-component weights.

        Returns:
            The wrapped env's step tuple; ``info`` additionally carries ``'action_info'``,
            ``'raw_mixing_weights'`` and ``'processed_mixing_weights'``.
        """
        if self.obs is None:
            raise RuntimeError("step() called before reset(): no observation to condition the policy on")
        self.steps += 1

        obs_tensor = torch.as_tensor(
            np.expand_dims(self.obs, axis=0),
            dtype=torch.float32
        ).to(self.device)

        mixing_weights = self._process_mixing_weights(action)
        continuous_action, action_info = self._compute_weighted_action(obs_tensor, mixing_weights)

        observation, rewards, terminated, truncated, info = self.env.step(continuous_action.cpu().numpy())

        # Expose the mixing details for debugging/logging.
        info['action_info'] = action_info
        info['raw_mixing_weights'] = action
        info['processed_mixing_weights'] = mixing_weights.cpu().numpy()

        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment
        """
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info
 

# RANDOM ACTION SET
class MPEnvBaseline(gym.Wrapper):
    """
    Random-action-set baseline.

    A bank of ``num_actions`` continuous actions is drawn once, at construction. Discrete action
    ``k`` then replays bank entry ``k`` on every step. Each coordinate of every entry is drawn
    independently from ``{-action_magnitude, 0, action_magnitude}`` using the global ``np.random``
    state.
    """

    def __init__(self, env, num_actions=64, action_magnitude=0.5):
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.obs = None
        self.action_magnitude = action_magnitude
        self.num_actions = num_actions

        dim = env.action_space.shape[0]
        self.actions = [
            np.random.choice([-self.action_magnitude, 0, self.action_magnitude], size=dim)
            for _ in range(num_actions)
        ]

    def step(self, k):
        """Check that ``k`` is a valid bank index, then execute that bank entry in the wrapped env."""
        if not 0 <= k < self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

        self.steps += 1
        obs, reward, terminated, truncated, info = self.env.step(self.actions[k])
        self.obs = obs
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Zero the step counter and reset the wrapped environment."""
        self.steps = 0
        self.obs, info = self.env.reset(**kwargs)
        return self.obs, info

  
class UniformActionSetEnv(gym.Wrapper):
    """
    Fixed bank of ``num_actions`` continuous actions, each drawn once at construction with
    ``np.random.uniform(low, high)`` between the wrapped env's action bounds.

    Discrete action ``k`` replays bank entry ``k``.
    """

    def __init__(self, env, num_actions=64):
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.obs = None

        self.num_actions = num_actions
        # The value is unused; the lookup still raises for action spaces without a leading dimension.
        _ = env.action_space.shape[0]

        inner_space = env.action_space
        if not (hasattr(inner_space, 'low') and hasattr(inner_space, 'high')):
            raise ValueError("Environment action space must have 'low' and 'high' attributes")
        low, high = inner_space.low, inner_space.high

        self.actions = [np.random.uniform(low, high) for _ in range(num_actions)]

    def _check_action_id(self, k):
        """Raise ValueError unless ``0 <= k < num_actions``."""
        if not 0 <= k < self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

    def step(self, k):
        """Execute bank entry ``k`` in the wrapped environment."""
        self._check_action_id(k)
        self.steps += 1
        obs, reward, terminated, truncated, info = self.env.step(self.actions[k])
        self.obs = obs
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Zero the step counter and reset the wrapped environment."""
        self.steps = 0
        self.obs, info = self.env.reset(**kwargs)
        return self.obs, info

    def get_action(self, k):
        """Return the continuous action stored for action ID ``k``."""
        self._check_action_id(k)
        return self.actions[k]

    def get_all_actions(self):
        """Return the whole action bank as a single numpy array."""
        return np.asarray(self.actions)

class GaussianActionSetEnv(gym.Wrapper):
    """
    Fixed bank of ``num_actions`` continuous actions, each drawn once at construction from a
    Gaussian and clipped to the wrapped env's action bounds.

    Discrete action ``k`` replays bank entry ``k``. ``mean`` defaults to the centre of the action
    box. ``std`` defaults to one sixth of its width, so roughly 99.7% of unclipped samples already
    fall inside the bounds. Both are broadcast to the action dimension. Sampling uses the global
    ``np.random`` state.
    """

    def __init__(self, env, num_actions=64, mean=None, std=None):
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.obs = None

        self.num_actions = num_actions
        dim = env.action_space.shape[0]

        inner_space = env.action_space
        if not (hasattr(inner_space, 'low') and hasattr(inner_space, 'high')):
            raise ValueError("Environment action space must have 'low' and 'high' attributes")
        self.action_low = inner_space.low
        self.action_high = inner_space.high

        centre = (self.action_low + self.action_high) / 2 if mean is None else mean
        spread = (self.action_high - self.action_low) / 6 if std is None else std
        self.mean = np.broadcast_to(centre, dim)
        self.std = np.broadcast_to(spread, dim)

        self.actions = [
            np.clip(np.random.normal(self.mean, self.std), self.action_low, self.action_high)
            for _ in range(num_actions)
        ]

    def _check_action_id(self, k):
        """Raise ValueError unless ``0 <= k < num_actions``."""
        if not 0 <= k < self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

    def step(self, k):
        """Execute bank entry ``k`` in the wrapped environment."""
        self._check_action_id(k)
        self.steps += 1
        obs, reward, terminated, truncated, info = self.env.step(self.actions[k])
        self.obs = obs
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Zero the step counter and reset the wrapped environment."""
        self.steps = 0
        self.obs, info = self.env.reset(**kwargs)
        return self.obs, info

    def get_action(self, k):
        """Return the continuous action stored for action ID ``k``."""
        self._check_action_id(k)
        return self.actions[k]

    def get_all_actions(self):
        """Return the whole action bank as a single numpy array."""
        return np.asarray(self.actions)


class MPEnvDistribution(gym.Wrapper):
    """
    Discrete-action wrapper whose K actions are drawn from per-dimension Gaussians
    fitted to logged actions.

    Every ``*.csv`` file in ``csv_dir`` is read and concatenated. Each column
    named ``a_<int>`` (``a_0``, ``a_1``, ...) is one action dimension, and an
    independent Gaussian (sample mean, population std) is fitted to it.
    ``num_actions`` continuous actions are sampled from these Gaussians,
    clipped to the wrapped env's action bounds, and exposed as
    ``spaces.Discrete(num_actions)``. Actions are resampled only when
    ``resample_actions`` is called.
    """
    def __init__(self, env, csv_dir: str, num_actions: int = 64):
        """
        Args:
            env: Wrapped env; its action space must provide ``shape``, ``low``
                and ``high`` (e.g. a ``Box``).
            csv_dir: Directory containing the CSV files with ``a_<int>`` columns.
            num_actions: Number K of discrete actions to sample.

        Raises:
            ValueError: if the directory/CSV data is missing or malformed, or the
                number of ``a_<int>`` columns differs from the env's action dimension.
        """
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0
        self.obs = None

        # Store parameters
        self.num_actions = num_actions
        self.csv_dir = csv_dir
        action_dim = env.action_space.shape[0]
        # Bounds used to clip samples: Gaussian draws are unbounded, and
        # out-of-range actions would otherwise reach the wrapped env unchecked.
        self.action_low = np.asarray(env.action_space.low)
        self.action_high = np.asarray(env.action_space.high)

        # Read CSV files and build distributions
        self.action_distributions = self._build_action_distributions()

        # Validate that we have the right number of dimensions
        if len(self.action_distributions) != action_dim:
            raise ValueError(f"Expected {action_dim} action dimensions, but found {len(self.action_distributions)} in CSV files")

        # Sample K actions from the distributions
        self.actions = self._sample_actions_from_distributions()

    def _read_csv_files(self) -> pd.DataFrame:
        """
        Read every ``*.csv`` file in ``self.csv_dir`` and concatenate them.

        Files are processed in sorted name order so the combined frame (and the
        fitted statistics printed from it) is reproducible across filesystems.
        Unreadable files are skipped with a warning (best effort).

        Raises:
            ValueError: if the directory is missing, contains no CSV files, or
                none of them could be read.
        """
        if not os.path.isdir(self.csv_dir):
            raise ValueError(f"Directory {self.csv_dir} does not exist or is not a directory")

        csv_files = sorted(f for f in os.listdir(self.csv_dir) if f.endswith('.csv'))

        if not csv_files:
            raise ValueError(f"No CSV files found in directory {self.csv_dir}")

        dataframes = []
        for csv_file in csv_files:
            file_path = os.path.join(self.csv_dir, csv_file)
            try:
                dataframes.append(pd.read_csv(file_path))
            except Exception as e:
                # Deliberate best effort: one bad file should not abort loading.
                print(f"Warning: Could not read {csv_file}: {e}")

        if not dataframes:
            raise ValueError("No valid CSV files could be read")

        combined_df = pd.concat(dataframes, ignore_index=True)
        print(f"Total combined data: {len(combined_df)} rows")

        return combined_df

    def _build_action_distributions(self) -> List[dict]:
        """
        Fit one Gaussian per ``a_<int>`` column of the combined CSV data.

        Returns:
            A list ordered by the column's integer suffix; each entry is a dict
            ``{'mean': float, 'std': float}``. A zero std is replaced by ``1e-6``.

        Raises:
            ValueError: if no ``a_<int>`` columns exist or a column has only NaNs.
        """
        df = self._read_csv_files()

        # Action columns are a_<digits>; sort numerically so a_10 follows a_9.
        action_columns = [col for col in df.columns if col.startswith('a_') and col[2:].isdigit()]
        action_columns.sort(key=lambda x: int(x[2:]))

        if not action_columns:
            raise ValueError("No action columns (a_0, a_1, etc.) found in CSV files")

        print(f"Found action columns: {action_columns}")

        distributions = []
        for col in action_columns:
            values = df[col].dropna().values

            if len(values) == 0:
                raise ValueError(f"No valid values found for column {col}")

            mean = np.mean(values)
            std = np.std(values)

            if std == 0:
                print(f"Warning: Standard deviation is 0 for {col}, using small value (1e-6)")
                std = 1e-6

            distributions.append({'mean': mean, 'std': std})
            print(f"Dimension {col}: {len(values)} samples, Gaussian(μ={mean:.3f}, σ={std:.3f})")

        return distributions

    def _sample_actions_from_distributions(self) -> List[np.ndarray]:
        """
        Draw ``self.num_actions`` actions from the fitted Gaussians.

        All samples are drawn in one vectorised call (row-major, so the global
        NumPy RNG is consumed in the same order as a per-scalar loop) and then
        clipped to the wrapped env's action bounds.

        Returns:
            A list of ``num_actions`` 1-D float arrays of length ``action_dim``.
        """
        means = np.array([d['mean'] for d in self.action_distributions], dtype=np.float64)
        stds = np.array([d['std'] for d in self.action_distributions], dtype=np.float64)

        samples = np.random.normal(means, stds, size=(self.num_actions, len(means)))
        samples = np.clip(samples, self.action_low, self.action_high)
        actions = list(samples)

        print(f"Generated {len(actions)} actions from Gaussian distributions")
        return actions

    def resample_actions(self, num_actions: Union[int, None] = None):
        """
        Generate a fresh action set from the same fitted Gaussians.

        Args:
            num_actions: New number of actions; when given, ``num_actions`` and
                ``action_space`` are updated too. If None, the current K is kept.
        """
        if num_actions is not None:
            self.num_actions = num_actions
            self.action_space = spaces.Discrete(num_actions)

        self.actions = self._sample_actions_from_distributions()
        print(f"Resampled {len(self.actions)} new actions")

    def step(self, k: int):
        """
        Execute the continuous action stored at discrete ID ``k``.

        Returns:
            The wrapped env's ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions - 1]``.
        """
        if k < 0 or k >= self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

        self.steps += 1
        continuous_action = self.actions[k]
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action)
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env, zero the step counter and cache the first observation."""
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info


class MLPEnvNetwork(nn.Module):
    """
    Plain feed-forward network.

    Layout: one ``Linear -> ReLU`` stage per entry of ``hidden_dims``, followed
    by a final ``Linear`` projection to ``output_dim`` with no output activation.
    Layers live in ``self.model`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim: int, hidden_dims: list, output_dim: int):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # Consecutive width pairs give each hidden stage's (fan_in, fan_out).
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stages.append(nn.Linear(widths[-1], output_dim))
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the raw output."""
        return self.model(x)


class MPEnvNN(gym.Wrapper):
    """
    Discrete-action wrapper that maps ``num_actions`` slot IDs to state-dependent
    continuous actions through a frozen, randomly initialised MLP.

    Each slot ``i`` owns a fixed random embedding ``e_i`` drawn once at
    construction. On every step the current observation ``s`` is flattened and
    paired with every embedding. ``tanh(MLP([s; e_i]))`` then gives K proposals
    in ``[-1, 1]^D``, and the chosen slot's proposal is sent to the wrapped env.

    NOTE(review): proposals are not rescaled to ``env.action_space`` bounds, so
    envs whose bounds differ from [-1, 1] still receive values in [-1, 1].
    NOTE(review): embeddings and MLP weights come from the global torch RNG;
    seed torch before constructing the wrapper if the action set must be
    reproducible.
    """
    def __init__(self, env: gym.Env,
                 num_actions: int = 64,
                 embed_dim: int = 16,
                 hidden_dims: Union[List[int], tuple] = (128, 128),
                 add_noise: bool = True):
        """
        Args:
            env: Wrapped env; its action space needs a ``shape`` (e.g. ``Box``)
                and its observation space a fixed ``shape``.
            num_actions: Number K of discrete slots, exposed as ``Discrete(K)``.
            embed_dim: Size of each fixed slot embedding.
            hidden_dims: Hidden-layer widths of the MLP. The default is an
                immutable tuple (a list default would be shared between calls).
            add_noise: Currently unused; no noise is added to the proposals.
                Kept for backward compatibility with existing callers.
        """
        super().__init__(env)

        self.num_actions = num_actions
        self.action_dim = env.action_space.shape[0]
        self.action_space = spaces.Discrete(self.num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.add_noise = add_noise  # stored only; see docstring

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Observations are flattened before being fed to the MLP.
        self.state_dim = int(np.prod(env.observation_space.shape))

        # Fixed (never trained) slot embeddings: one row per discrete action.
        # Drawn before the MLP so RNG consumption order matches earlier versions.
        self.embeddings = torch.randn(self.num_actions, embed_dim, device=self.device)
        self.embeddings.requires_grad_(False)

        # Frozen MLP: [state_dim + embed_dim] -> action_dim
        self.net = MLPEnvNetwork(self.state_dim + embed_dim, list(hidden_dims), self.action_dim).to(self.device)
        self.net.eval()
        for p in self.net.parameters():
            p.requires_grad = False

        self.obs = None
        self.steps = 0

    def reset(self, **kwargs):
        """Reset the wrapped env, zero the step counter and cache the first observation."""
        self.steps = 0
        obs, info = self.env.reset(**kwargs)
        self.obs = obs
        return obs, info

    def _compute_actions(self) -> np.ndarray:
        """
        Compute the K proposals ``tanh(MLP([s; e_i]))`` for the cached observation.

        Returns:
            float32 array of shape ``(num_actions, action_dim)`` with values in [-1, 1].

        Raises:
            RuntimeError: if no observation is cached yet (``reset()`` not called).
        """
        if self.obs is None:
            raise RuntimeError("MPEnvNN: call reset() before step(); no observation is available yet")

        s = np.array(self.obs, dtype=np.float32).reshape(1, -1)
        with torch.no_grad():
            t_state = torch.from_numpy(s).to(self.device)           # (1, state_dim)
            t_states = t_state.repeat(self.num_actions, 1)          # (K, state_dim)
            t_in = torch.cat([t_states, self.embeddings], dim=1)    # (K, state_dim + embed_dim)
            actions = torch.tanh(self.net(t_in))                    # (K, action_dim), in [-1, 1]

        return actions.cpu().numpy()

    def step(self, action_id: int):
        """
        Execute the proposal for slot ``action_id`` given the current observation.

        Returns:
            The wrapped env's ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action_id`` is outside ``[0, num_actions - 1]``.
            RuntimeError: if called before ``reset()``.
        """
        if not (0 <= action_id < self.num_actions):
            raise ValueError(f"Action ID must be in [0, {self.num_actions-1}]")

        # Proposals depend on the current state, so they are recomputed each step.
        act = self._compute_actions()[action_id]

        self.steps += 1
        obs, reward, terminated, truncated, info = self.env.step(act)
        self.obs = obs
        return obs, reward, terminated, truncated, info


def _make_sized_env(env_id, render_mode, size=600, **env_kwargs):
    """Create ``env_id`` via ``gym.make`` with a square ``size`` x ``size`` render buffer."""
    return gym.make(env_id, width=size, height=size, render_mode=render_mode, **env_kwargs)


def setup_hrl_environment(args, render_mode="rgb_array"):
    """
    Build the discrete-action (HRL) environment selected by ``args.env``.

    Branches are checked in order:

    * ``<id>_baseline`` / ``_uniform`` / ``_gaussian`` / ``_dist`` / ``_nn`` wrap
      ``<id>`` with an action-set wrapper that needs no pretrained policy.
    * Every other id first loads the pretrained actor-critic checkpoint from
      ``args.path`` and wraps the env with a policy-driven wrapper. Ids
      containing ``ant_maze`` or ``Fetch`` use ``MPEnvDict``. ``Humanoid-v5``,
      ``_no_ctrl``, ``_seq``, ``_cont`` and the default case each have
      dedicated handling.

    Args:
        args: Parsed CLI namespace; the ``hrl_*`` fields read depend on the branch.
        render_mode: Forwarded to ``gym.make``.

    Returns:
        The wrapped environment.
    """
    # --- Action sets that do not need a pretrained policy ---------------
    if args.env.endswith("_baseline"):
        base_id = args.env[: -len("_baseline")]
        print(args.env, base_id)
        base_env = _make_sized_env(base_id, render_mode)
        return MPEnvBaseline(base_env, num_actions=args.hrl_nc, action_magnitude=args.hrl_action_magnitude)
    if args.env.endswith("_uniform"):
        print(args.env)
        base_env = _make_sized_env(args.env[: -len("_uniform")], render_mode)
        return UniformActionSetEnv(base_env, num_actions=args.hrl_nc)
    if args.env.endswith("_gaussian"):
        print(args.env)
        base_env = _make_sized_env(args.env[: -len("_gaussian")], render_mode)
        return GaussianActionSetEnv(base_env, num_actions=args.hrl_nc, mean=args.hrl_gmean, std=args.hrl_gstd)
    if args.env.endswith("_dist"):
        print(args.env)
        base_env = _make_sized_env(args.env[: -len("_dist")], render_mode)
        return MPEnvDistribution(base_env, csv_dir=args.hrl_csv_dir, num_actions=args.hrl_nc)
    if args.env.endswith("_nn"):
        print(args.env)
        base_env = _make_sized_env(args.env[: -len("_nn")], render_mode)
        return MPEnvNN(base_env, num_actions=args.hrl_nc)

    # --- Policy-driven wrappers ------------------------------------------
    # `device` is not defined in this function; presumably a module-level global
    # (e.g. provided via `from utils import *`) -- TODO confirm.
    # The checkpoint is a whole pickled module rather than a state_dict, so it
    # needs weights_only=False: PyTorch >= 2.6 defaults to weights_only=True and
    # would refuse to load it. This unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    ac = torch.load(args.path, map_location=device, weights_only=False).to(device)
    ac.pi.eval()

    # Arguments shared by every policy-driven wrapper below.
    policy_kwargs = dict(
        policy_network=ac.pi,
        device=device,
        num_actions=args.hrl_nc,
        std=args.hrl_std,
        hard=args.hrl_hard,
    )

    if "ant_maze" in args.env:
        base_env = _make_sized_env(
            args.env,
            render_mode,
            size=800,
            include_cfrc_ext_in_observation=False,
            continuing_task=False,
            camera_id=0,
        )
        return MPEnvDict(base_env, rewards_scale=args.hrl_rewards_scale, **policy_kwargs)

    if "Fetch" in args.env:
        print(f"{RED} Fetch env {ENDC}")
        print(f"{RED} env: {args.env} {ENDC}")
        base_env = _make_sized_env(args.env, render_mode)
        return MPEnvDict(base_env, rewards_scale=args.hrl_rewards_scale, **policy_kwargs)

    if args.env == "Humanoid-v5":
        base_env = gym.make(
            args.env,
            include_cinert_in_observation=False,
            include_cvel_in_observation=False,
            include_cfrc_ext_in_observation=False,
            include_qfrc_actuator_in_observation=False,
            render_mode=render_mode,
        )
        return MPEnv(base_env, **policy_kwargs)

    if "_no_ctrl" in args.env:
        # NOTE(review): matched with `in` but stripped as a suffix; an id with
        # "_no_ctrl" in the middle would be sliced incorrectly.
        print("-"*1000)
        env_name = args.env[: -len("_no_ctrl")]
        base_env = NoCtrlCostWrapper(_make_sized_env(env_name, render_mode))
        return MPEnv(base_env, **policy_kwargs)

    if args.env.endswith("_seq"):
        print(args.env)
        base_env = _make_sized_env(args.env[: -len("_seq")], render_mode)
        return MPEnvSequential(base_env, **policy_kwargs)

    if args.env.endswith("_cont"):
        print(args.env)
        # NOTE(review): unlike the other suffix branches, "_cont" is NOT stripped
        # before gym.make -- correct only if "<id>_cont" is itself a registered id.
        base_env = _make_sized_env(args.env, render_mode)
        return MPEnvCont(base_env, **policy_kwargs)

    print(args.env)
    base_env = _make_sized_env(args.env, render_mode)
    return MPEnv(base_env, **policy_kwargs)


if __name__ == "__main__":
    # Smoke test: roll out random actions in Ant-v5 until the episode ends.
    env = gym.make("Ant-v5", width=600, height=600, render_mode="rgb_array")
    try:
        # Gymnasium's reset() returns (observation, info), not just the observation.
        obs, info = env.reset()
        done = False
        while not done:
            action = env.action_space.sample()  # Replace with your action
            print(action)
            obs, reward, terminated, truncated, info = env.step(action)
            print(f"Reward: {reward}")
            print(f"Terminated: {terminated}")
            print(f"Truncated: {truncated}")
            print(f"Info: {info}")
            print(f"Observation: {obs}")
            done = terminated or truncated
    finally:
        # Release the MuJoCo renderer / simulation resources even on error.
        env.close()