# benchrl/environments/wrappers/observation_wrappers.py

import numpy as np
import gymnasium as gym
from collections import deque
from typing import Union, Tuple, Dict, Any

from .base_wrapper import ObservationWrapper


class NormalizeObservation(ObservationWrapper):
    """
    Normalize observations using running statistics.

    Each coordinate is shifted by a running mean and scaled by a running
    standard deviation, then clipped to ``[-clip, clip]``. The statistics are
    updated with every observation *before* that observation is normalized.

    Args:
        env: The environment to wrap.
        epsilon: Small constant added to the variance for numerical stability.
        clip: Absolute bound applied to the normalized observation.
    """

    def __init__(
        self,
        env: gym.Env,
        epsilon: float = 1e-8,
        clip: float = 10.0
    ):
        super().__init__(env)
        self.epsilon = epsilon
        self.clip = clip

        # Running statistics (float64 to limit accumulation error)
        self.obs_mean = np.zeros(env.observation_space.shape, dtype=np.float64)
        self.obs_var = np.ones(env.observation_space.shape, dtype=np.float64)
        self.count = 0

        # Normalized observations always live in [-clip, clip]
        self.observation_space = gym.spaces.Box(
            low=-clip,
            high=clip,
            shape=env.observation_space.shape,
            dtype=np.float32
        )

    def observation(self, obs):
        """Update the running statistics with ``obs`` and return it normalized (float32)."""
        obs = np.asarray(obs, dtype=np.float64)
        self._update_stats(obs)

        normalized_obs = (obs - self.obs_mean) / np.sqrt(self.obs_var + self.epsilon)
        normalized_obs = np.clip(normalized_obs, -self.clip, self.clip)

        return normalized_obs.astype(np.float32)

    def _update_stats(self, obs):
        """Update running mean and (population) variance using Welford's algorithm."""
        obs = np.asarray(obs, dtype=np.float64)
        self.count += 1
        delta = obs - self.obs_mean
        self.obs_mean += delta / self.count
        # var_n = var_{n-1} + (delta * (x - mean_n) - var_{n-1}) / n.
        # The whole bracket must be divided by n; dividing only the old
        # variance makes obs_var grow like a sum of squares.
        self.obs_var += (delta * (obs - self.obs_mean) - self.obs_var) / self.count


class RescaleObservation(ObservationWrapper):
    """
    Rescale observations to a new range.

    A single linear map is applied to every coordinate: the global minimum of
    the source space's lower bounds maps to ``new_min`` and the global maximum
    of its upper bounds maps to ``new_max``.

    Args:
        env: The environment to wrap; its observation space must expose
            finite ``low``/``high`` arrays.
        new_min: Lower end of the target range.
        new_max: Upper end of the target range.

    Raises:
        ValueError: If the source bounds are infinite or span an empty range,
            since the linear map would then produce NaN/inf.
    """

    def __init__(self, env: gym.Env, new_min: float = 0.0, new_max: float = 1.0):
        super().__init__(env)
        self.new_min = new_min
        self.new_max = new_max
        # Python floats avoid integer overflow in ``old_max - old_min`` for
        # narrow integer dtypes (e.g. int8: 127 - (-128)).
        self.old_min = float(env.observation_space.low.min())
        self.old_max = float(env.observation_space.high.max())

        if not (np.isfinite(self.old_min) and np.isfinite(self.old_max)):
            raise ValueError(
                "RescaleObservation requires finite observation bounds, got "
                f"[{self.old_min}, {self.old_max}]"
            )
        if self.old_max <= self.old_min:
            raise ValueError(
                "RescaleObservation requires high > low, got "
                f"[{self.old_min}, {self.old_max}]"
            )

        self.observation_space = gym.spaces.Box(
            low=new_min,
            high=new_max,
            shape=env.observation_space.shape,
            dtype=np.float32
        )

    def observation(self, obs):
        """Linearly map ``obs`` from the source range to ``[new_min, new_max]``."""
        old_span = self.old_max - self.old_min
        new_span = self.new_max - self.new_min
        rescaled_obs = (np.asarray(obs) - self.old_min) / old_span * new_span + self.new_min
        return rescaled_obs.astype(self.observation_space.dtype)

class FrameStackWrapper(ObservationWrapper):
    """
    Stack k last frames.

    Observations are returned as a ``LazyFrames`` object that materializes to
    an array of shape ``(k,) + env.observation_space.shape``; frames shared
    between consecutive observations are not copied until materialization.

    Args:
        env: The environment to wrap (Box observation space).
        k: Number of most recent frames to stack; must be at least 1.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """

    def __init__(self, env: gym.Env, k: int):
        super().__init__(env)
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.k = k
        self.frames = deque([], maxlen=k)

        base_space = env.observation_space
        # Add a leading axis before repeating so the bounds have shape
        # (k,) + obs_shape. Repeating along the existing axis 0 would give
        # (k * d0, ...) and contradict the declared shape.
        self.observation_space = gym.spaces.Box(
            low=np.repeat(base_space.low[np.newaxis, ...], k, axis=0),
            high=np.repeat(base_space.high[np.newaxis, ...], k, axis=0),
            shape=(k,) + base_space.shape,
            dtype=base_space.dtype
        )

    def reset(self, **kwargs):
        """Reset the environment and fill the buffer with k copies of the first frame."""
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(obs)
        return self._get_obs(), info

    def step(self, action):
        """Step the environment and return the k most recent frames."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        self.frames.append(obs)
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        """Wrap the buffered frames in a ``LazyFrames`` of shape ``(k,) + obs_shape``."""
        assert len(self.frames) == self.k
        # LazyFrames concatenates along axis 0, so give every frame a leading
        # unit axis (a view, no copy) to obtain a stacked result.
        return LazyFrames([np.asarray(frame)[np.newaxis, ...] for frame in self.frames])

    def observation(self, obs):
        # Not used as we override reset and step
        return obs


class LazyFrames:
    """
    Lazy frame stacking to optimize memory usage.

    Holds references to the individual frames and only concatenates them
    (along axis 0) the first time the data is actually needed. The result is
    cached and the frame list is released afterwards.

    Args:
        frames: Sequence of array-like frames with matching trailing shapes.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None

    def _force(self):
        """Materialize (once) and return the concatenated array."""
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=0)
            self._frames = None  # drop references; the cached array is enough
        return self._out

    def __array__(self, dtype=None, copy=None):
        """
        NumPy array protocol.

        Args:
            dtype: Optional dtype to convert to.
            copy: NumPy 2 protocol flag. ``True`` forces a copy, ``False``
                forbids one, ``None`` copies only when needed.

        Raises:
            ValueError: If ``copy=False`` but a dtype conversion is required.
        """
        out = self._force()
        if dtype is not None:
            if copy is False:
                if np.dtype(dtype) != out.dtype:
                    raise ValueError(
                        "Unable to avoid copy while converting LazyFrames to the requested dtype"
                    )
                return out
            return out.astype(dtype)
        if copy:
            return out.copy()
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]


class GrayScaleObservation(ObservationWrapper):
    """
    Convert RGB observations to grayscale.

    Args:
        env: Environment whose observations are (H, W, 3) images.
        keep_dim: When true, keep a trailing channel axis of size 1.
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        super().__init__(env)
        self.keep_dim = keep_dim

        obs_shape = self.observation_space.shape
        assert len(obs_shape) == 3 and obs_shape[-1] == 3, \
            "GrayScaleObservation only works with RGB images"

        # Drop the colour axis, optionally replacing it with a singleton axis.
        spatial_shape = obs_shape[:-1]
        target_shape = spatial_shape + (1,) if keep_dim else spatial_shape
        self.observation_space = gym.spaces.Box(
            low=0, high=255,
            shape=target_shape,
            dtype=np.uint8
        )

    def observation(self, obs):
        """Return the weighted-sum (0.299 R + 0.587 G + 0.114 B) grayscale image as uint8."""
        luma_weights = [0.299, 0.587, 0.114]
        gray = np.dot(obs[..., :3], luma_weights).astype(np.uint8)

        if not self.keep_dim:
            return gray
        return np.expand_dims(gray, axis=-1)


class ResizeObservation(ObservationWrapper):
    """
    Resize image observations to a new shape.

    Args:
        env: Environment whose observations are (H, W) or (H, W, C) images.
        shape: Target ``(height, width)``; a single int means a square image.

    Raises:
        ValueError: If ``shape`` is not a pair of positive sizes.
    """

    def __init__(self, env: gym.Env, shape: Union[int, Tuple[int, int]]):
        super().__init__(env)

        if isinstance(shape, int):
            shape = (shape, shape)
        shape = tuple(shape)  # also accept lists; needed for tuple concatenation below
        if len(shape) != 2 or any(side <= 0 for side in shape):
            raise ValueError(f"shape must be a positive (height, width) pair, got {shape}")
        self.shape = shape

        obs_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=self.observation_space.low.min(),
            high=self.observation_space.high.max(),
            shape=shape + obs_shape[2:],
            dtype=self.observation_space.dtype
        )

    def observation(self, obs):
        """Resize ``obs`` to ``(height, width)`` keeping any channel axis."""
        import cv2

        height, width = self.shape
        # cv2.resize takes dsize as (width, height), the reverse of the
        # (height, width) order used by array shapes.
        resized = cv2.resize(obs, (width, height), interpolation=cv2.INTER_AREA)
        # cv2 drops a singleton channel axis; restore it so the output matches
        # the declared observation space.
        if resized.ndim < obs.ndim:
            resized = resized[..., np.newaxis]
        return resized


class FlattenObservation(ObservationWrapper):
    """
    Flatten observations.

    The observation space becomes the flattened version of the wrapped
    environment's space, and each observation is flattened to match.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        source_space = env.observation_space
        self.observation_space = gym.spaces.flatten_space(source_space)

    def observation(self, obs):
        """Return ``obs`` flattened according to the wrapped environment's space."""
        source_space = self.env.observation_space
        return gym.spaces.flatten(source_space, obs)


class DtypeObservation(ObservationWrapper):
    """
    Convert observations to a specific dtype.

    The observation space keeps the wrapped environment's bounds and shape
    and only changes its dtype.
    """

    def __init__(self, env: gym.Env, dtype: np.dtype):
        super().__init__(env)
        self.dtype = dtype

        source_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            low=source_space.low,
            high=source_space.high,
            shape=source_space.shape,
            dtype=dtype
        )

    def observation(self, obs):
        """Return ``obs`` cast to the configured dtype."""
        converted = obs.astype(self.dtype)
        return converted


class AddChannelDimension(ObservationWrapper):
    """
    Add a channel dimension to observations.

    Useful for converting 2D observations to 3D (H, W) -> (H, W, 1).

    Args:
        env: The environment to wrap (Box observation space).
        axis: Position of the new unit axis in the *output* shape, with the
            same meaning as in ``np.expand_dims`` (negative values count from
            the end).

    Raises:
        ValueError: If ``axis`` is out of range for the expanded shape.
    """

    def __init__(self, env: gym.Env, axis: int = -1):
        super().__init__(env)
        self.axis = axis

        obs_shape = list(env.observation_space.shape)
        ndim = len(obs_shape)
        # list.insert(-1, x) inserts *before* the last element, whereas
        # np.expand_dims(..., -1) appends. Convert to a non-negative position
        # in the output shape so the declared space matches observation().
        insert_at = axis if axis >= 0 else axis + ndim + 1
        if not 0 <= insert_at <= ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for an expanded array of dimension {ndim + 1}"
            )
        obs_shape.insert(insert_at, 1)

        self.observation_space = gym.spaces.Box(
            low=env.observation_space.low.min(),
            high=env.observation_space.high.max(),
            shape=tuple(obs_shape),
            dtype=env.observation_space.dtype
        )

    def observation(self, obs):
        """Add channel dimension."""
        return np.expand_dims(obs, axis=self.axis)


class PermuteObservation(ObservationWrapper):
    """
    Permute observation dimensions.

    Useful for converting between different channel conventions (CHW <-> HWC).
    The declared bounds are the scalar minimum/maximum of the wrapped space.
    """

    def __init__(self, env: gym.Env, permutation: Tuple[int, ...]):
        super().__init__(env)
        self.permutation = permutation

        source_space = env.observation_space
        source_shape = source_space.shape
        # Output axis j takes its size from source axis permutation[j].
        permuted_shape = tuple(source_shape[axis] for axis in permutation)

        self.observation_space = gym.spaces.Box(
            low=source_space.low.min(),
            high=source_space.high.max(),
            shape=permuted_shape,
            dtype=source_space.dtype
        )

    def observation(self, obs):
        """Return ``obs`` with its axes reordered by the configured permutation."""
        permuted = np.transpose(obs, self.permutation)
        return permuted


class ObservationDictToArray(ObservationWrapper):
    """
    Convert Dict observation space to array.

    Flattens the selected entries of the observation dictionary with
    ``gym.spaces.flatten`` and concatenates them, so the length of every
    observation equals the declared ``flatdim``-based size (e.g. a Discrete
    entry contributes its one-hot encoding).

    Args:
        env: Environment with a ``gym.spaces.Dict`` observation space.
        keys_to_keep: Keys to include, in output order. ``None`` selects all
            keys in sorted order.
    """

    def __init__(self, env: gym.Env, keys_to_keep: Union[None, Tuple[str, ...]] = None):
        super().__init__(env)

        assert isinstance(env.observation_space, gym.spaces.Dict), \
            "ObservationDictToArray only works with Dict observation spaces"

        self.keys_to_keep = keys_to_keep
        if keys_to_keep is None:
            self.keys_to_keep = sorted(env.observation_space.spaces.keys())

        # Remember each sub-space so observation() flattens values the same
        # way the size below was computed.
        self._subspaces = {key: env.observation_space[key] for key in self.keys_to_keep}
        flat_size = sum(gym.spaces.flatdim(space) for space in self._subspaces.values())

        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(flat_size,),
            dtype=np.float32
        )

    def observation(self, obs):
        """
        Convert dict observation to a flat float32 array.

        Raises:
            KeyError: If a selected key is missing from ``obs``; skipping it
                would yield an array shorter than the declared space.
        """
        parts = []
        for key in self.keys_to_keep:
            if key not in obs:
                raise KeyError(f"Observation is missing expected key {key!r}")
            flat = gym.spaces.flatten(self._subspaces[key], obs[key])
            parts.append(np.asarray(flat, dtype=np.float32).ravel())

        if not parts:
            return np.zeros((0,), dtype=np.float32)
        return np.concatenate(parts).astype(np.float32)


class SkipObservation(ObservationWrapper):
    """
    Skip observations by only returning every nth observation.

    Each call to ``step`` repeats the given action up to ``skip`` times,
    sums the rewards, and returns only the last observation. Intermediate
    observations are discarded. The loop stops early when the episode
    terminates or is truncated.

    Args:
        env: The environment to wrap.
        skip: Number of environment steps per wrapper step; must be >= 1.

    Raises:
        ValueError: If ``skip`` is smaller than 1 (``step`` would otherwise
            have no observation to return).
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip
        self._obs_buffer = None  # most recent observation returned to the caller

    def reset(self, **kwargs):
        """Reset the environment and remember the initial observation."""
        obs, info = self.env.reset(**kwargs)
        self._obs_buffer = obs
        return obs, info

    def step(self, action):
        """Repeat ``action`` up to ``skip`` times and return the accumulated result."""
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}

        for _ in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        self._obs_buffer = obs
        return obs, total_reward, terminated, truncated, info

    def observation(self, obs):
        # Not used as we override reset and step
        return obs


# Utility function to chain multiple observation wrappers
def apply_observation_wrappers(env: gym.Env, wrappers_config: list) -> gym.Env:
    """
    Wrap an environment with a sequence of observation wrappers.

    The wrappers are instantiated in list order, each one receiving the
    result of the previous step as its ``env`` argument.

    Args:
        env: Base environment
        wrappers_config: List of wrapper configurations
            Each item should be a dict with '_target_' key and wrapper parameters

    Returns:
        Wrapped environment
    """
    from hydra.utils import instantiate

    wrapped = env
    for cfg in wrappers_config:
        wrapped = instantiate(cfg, env=wrapped)
    return wrapped



import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Literal, Tuple, List
import gymnasium as gym

from .base_wrapper import ObservationWrapper


class FGSMAttackWrapper(ObservationWrapper):
    """
    Applies the Fast Gradient Sign Method (FGSM) attack to environment observations.

    This wrapper perturbs observations in the direction that increases a
    surrogate loss, i.e. that makes the agent's currently preferred action
    less likely. Supports both value-based (DQN) and policy-based (PPO) agents.

    For DQN: cross-entropy between softmax(Q-values) and the greedy action.
    For PPO: negative log probability of the most likely action.
    (Both reduce to ``cross_entropy(model_output, argmax(model_output))``.)

    Args:
        env: The environment to wrap
        agent_model: The agent's model (Q-network for DQN, actor network for
            PPO). If the object has an ``actor`` attribute, that sub-module is
            used; otherwise the object itself is used.
        epsilon: Attack strength (magnitude of perturbation)
        norm_type: Norm constraint ('inf' for L-infinity, 'l2' for L2 norm)
        device: Device for computations ('cpu' or 'cuda')
        model_type: Type of model ('q_network' for DQN, 'policy' for PPO)
        track_perturbations: Whether to track perturbation history

    Raises:
        ValueError: For an unsupported ``model_type`` or ``norm_type``.
        TypeError: If the observation space is not a Box.
    """

    def __init__(
        self,
        env: Union[gym.Env, gym.vector.sync_vector_env.SyncVectorEnv],
        agent_model: nn.Module,
        epsilon: float = 0.01,
        norm_type: Literal['inf', 'l2'] = 'inf',
        device: Union[str, torch.device] = 'cpu',
        model_type: Literal['q_network', 'policy'] = 'q_network',
        track_perturbations: bool = False
    ):
        super().__init__(env)

        # Accept an agent container exposing ``.actor`` as well as a bare
        # network (e.g. a Q-network), as the class documentation promises.
        self.agent_model = getattr(agent_model, 'actor', agent_model)
        self.epsilon = epsilon
        self.norm_type = norm_type.lower()
        self.device = torch.device(device)
        self.model_type = model_type.lower()
        self.track_perturbations = track_perturbations

        # Tracking variables
        self.perturbation_history = []
        self.observation_history = []
        self.current_perturbation = None

        # Validate inputs
        if self.model_type not in ['q_network', 'policy']:
            raise ValueError(f"Unsupported model_type: {model_type}. Choose 'q_network' or 'policy'.")

        if self.norm_type not in ['inf', 'l2']:
            raise ValueError(f"Unsupported norm_type: {norm_type}. Choose 'inf' or 'l2'.")

        # Ensure observation space is Box
        if not isinstance(self.observation_space, gym.spaces.Box):
            raise TypeError("FGSMAttackWrapper only supports Box observation spaces.")

        # Move model to device and set to eval mode
        self.agent_model.to(self.device)
        self.agent_model.eval()

        # Store observation bounds for clipping
        self.obs_low = torch.tensor(self.observation_space.low, dtype=torch.float32, device=self.device)
        self.obs_high = torch.tensor(self.observation_space.high, dtype=torch.float32, device=self.device)

        # Replace infinite bounds with large finite values so clamp stays well defined
        self.clip_low = torch.where(
            torch.isneginf(self.obs_low),
            torch.tensor(-1e10, device=self.device),
            self.obs_low
        )
        self.clip_high = torch.where(
            torch.isposinf(self.obs_high),
            torch.tensor(1e10, device=self.device),
            self.obs_high
        )

    def step(self, action):
        """
        Execute one step in the environment and apply the attack to the observation.

        Args:
            action: The action to take in the environment

        Returns:
            observation: The adversarially perturbed observation
            reward: The reward from the environment
            terminated: Whether the episode has terminated
            truncated: Whether the episode was truncated
            info: Additional information from the environment
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        adversarial_obs = self.observation(obs)
        return adversarial_obs, reward, terminated, truncated, info

    def observation(self, obs: np.ndarray) -> np.ndarray:
        """Attack ``obs`` and, if tracking is enabled, record it with its perturbation."""
        adv_obs, perturbation = self._apply_fgsm(obs)

        if self.track_perturbations:
            self.observation_history.append(np.array(obs, copy=True))
            self.perturbation_history.append(perturbation)
            self.current_perturbation = perturbation

        return adv_obs

    def _compute_loss(self, model_output: torch.Tensor) -> torch.Tensor:
        """
        Compute the surrogate loss the attack tries to increase.

        Args:
            model_output: Tensor of shape (batch, num_actions) holding
                Q-values ('q_network') or action logits ('policy').

        Returns:
            Mean negative log-probability of the model's own preferred action
            under ``softmax(model_output)``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported value.
        """
        if self.model_type not in ('q_network', 'policy'):
            raise ValueError(f"Unsupported model_type: {self.model_type}")

        # The target is the action the agent currently prefers.
        best_action_idx = torch.argmax(model_output, dim=1)

        # F.cross_entropy applies log_softmax itself, so it must receive the
        # raw Q-values / logits. Passing softmax probabilities would apply
        # softmax twice and flatten the gradient. For 'policy' this equals
        # -log_softmax(logits)[best].mean().
        return F.cross_entropy(model_output, best_action_idx)

    def _apply_fgsm(self, obs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Apply a single-step FGSM attack to an observation.

        Args:
            obs: Observation, with or without a leading batch axis.

        Returns:
            Tuple of (adversarial_observation, perturbation). The perturbation
            is the difference actually applied after clipping to the
            observation bounds (float32), with the same batch layout as ``obs``.
        """
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32).to(self.device)

        # Add a batch axis only when the input has none, and remember that so
        # only that axis is removed again at the end.
        added_batch = obs_tensor.dim() == len(self.observation_space.shape)
        if added_batch:
            obs_tensor = obs_tensor.unsqueeze(0)

        # Work on a private leaf tensor so the caller's data is never marked
        # as requiring grad.
        obs_tensor = obs_tensor.detach().clone().requires_grad_(True)

        with torch.enable_grad():
            model_output = self.agent_model(obs_tensor)
            loss = self._compute_loss(model_output)
            # Differentiate w.r.t. the input only; this leaves the agent's
            # parameter gradients untouched (no zero_grad / backward).
            (grad,) = torch.autograd.grad(loss, obs_tensor, allow_unused=True)

        if grad is None:
            # The model output does not depend on the input
            print("Warning: No gradient computed. Returning original observation.")
            return obs, np.zeros_like(obs)

        clean = obs_tensor.detach()
        batch_view = (grad.shape[0],) + (1,) * (grad.dim() - 1)

        if self.norm_type == 'inf':
            # L-infinity norm: epsilon * sign(gradient)
            perturbation = self.epsilon * grad.sign()
        else:  # 'l2' (validated in __init__)
            # L2 norm: epsilon * gradient / ||gradient||_2 (per sample)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), p=2, dim=1).reshape(batch_view)
            perturbation = self.epsilon * grad / (grad_norm + 1e-10)  # avoid division by zero

        # Gradient *ascent* on the loss: moving along +grad lowers the
        # probability of the agent's preferred action.
        adv_obs_tensor = torch.clamp(clean + perturbation, self.clip_low, self.clip_high)
        effective_perturbation = adv_obs_tensor - clean

        if added_batch:
            adv_obs_tensor = adv_obs_tensor.squeeze(0)
            effective_perturbation = effective_perturbation.squeeze(0)

        adv_obs = adv_obs_tensor.cpu().numpy().astype(self.observation_space.dtype)
        perturbation_np = effective_perturbation.cpu().numpy()

        return adv_obs, perturbation_np

    def reset(self, **kwargs):
        """Clear tracking (if enabled) and reset the environment."""
        # Clear before resetting so that anything recorded while the base
        # class processes the initial observation is not thrown away.
        # NOTE(review): whether the reset observation is attacked depends on
        # the project's ObservationWrapper.reset; confirm.
        if self.track_perturbations:
            self.perturbation_history = []
            self.observation_history = []
            self.current_perturbation = None

        obs, info = super().reset(**kwargs)
        return obs, info

    def get_attack_statistics(self) -> dict:
        """
        Get statistics about the attacks performed.

        Returns:
            Empty dict when tracking is disabled or nothing was recorded;
            otherwise counts and magnitude statistics of the recorded
            perturbations, plus the mean norm matching ``norm_type``.
        """
        if not self.track_perturbations or not self.perturbation_history:
            return {}

        perturbations = np.array(self.perturbation_history)
        magnitudes = np.abs(perturbations)

        stats = {
            'num_attacks': len(perturbations),
            'mean_perturbation': np.mean(magnitudes),
            'max_perturbation': np.max(magnitudes),
            'min_perturbation': np.min(magnitudes),
            'std_perturbation': np.std(perturbations)
        }

        if self.norm_type == 'inf':
            stats['mean_l_inf'] = np.mean(np.max(magnitudes, axis=tuple(range(1, perturbations.ndim))))
        elif self.norm_type == 'l2':
            stats['mean_l2'] = np.mean(np.linalg.norm(perturbations.reshape(len(perturbations), -1), axis=1))

        return stats

    def visualize_attack(self, save_path: Optional[str] = None):
        """
        Visualize the attack perturbations (for image observations).

        Plots up to nine evenly spaced recorded steps in three rows:
        original observation, perturbation, and their sum.

        Args:
            save_path: If given, the figure is saved there; otherwise it is shown.
        """
        if not self.track_perturbations or not self.perturbation_history:
            print("No perturbation history to visualize. Enable track_perturbations=True")
            return

        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("Matplotlib required for visualization. Install with: pip install matplotlib")
            return

        # Get a sample of perturbations
        num_samples = min(9, len(self.perturbation_history))
        indices = np.linspace(0, len(self.perturbation_history) - 1, num_samples, dtype=int)

        fig, axes = plt.subplots(3, num_samples, figsize=(num_samples * 3, 9))

        if num_samples == 1:
            # subplots returns a 1-D array for a single column; make it 2-D
            axes = axes.reshape(-1, 1)

        for i, idx in enumerate(indices):
            obs = self.observation_history[idx]
            pert = self.perturbation_history[idx]

            # Original observation
            if obs.ndim == 3 and obs.shape[-1] in [1, 3, 4]:
                # Image observation (H, W, C)
                axes[0, i].imshow(obs.squeeze(), cmap='gray' if obs.shape[-1] == 1 else None)
            else:
                # Non-image observation
                axes[0, i].plot(obs.flatten())
            axes[0, i].set_title(f'Original (t={idx})')
            axes[0, i].axis('off' if obs.ndim >= 2 else 'on')

            # Perturbation
            if pert.ndim >= 2:
                axes[1, i].imshow(pert.squeeze(), cmap='RdBu', vmin=-self.epsilon, vmax=self.epsilon)
            else:
                axes[1, i].plot(pert.flatten())
            axes[1, i].set_title('Perturbation')
            axes[1, i].axis('off' if pert.ndim >= 2 else 'on')

            # Adversarial observation
            adv_obs = obs + pert
            if adv_obs.ndim == 3 and adv_obs.shape[-1] in [1, 3, 4]:
                axes[2, i].imshow(adv_obs.squeeze(), cmap='gray' if adv_obs.shape[-1] == 1 else None)
            else:
                axes[2, i].plot(adv_obs.flatten())
            axes[2, i].set_title('Adversarial')
            axes[2, i].axis('off' if adv_obs.ndim >= 2 else 'on')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        else:
            plt.show()
        plt.close()


class PGDAttackWrapper(FGSMAttackWrapper):
    """
    Projected Gradient Descent (PGD) attack wrapper.

    Extends FGSM to perform multiple iterations with a smaller step size,
    projecting back onto the epsilon ball around the clean observation (and
    onto the observation bounds) after every step.

    Args:
        env: The environment to wrap
        agent_model: The agent's model (see ``FGSMAttackWrapper``)
        epsilon: Radius of the allowed perturbation ball
        step_size: Size of each gradient step
        num_steps: Number of gradient steps
        norm_type: Norm constraint ('inf' for L-infinity, 'l2' for L2 norm)
        device: Device for computations ('cpu' or 'cuda')
        model_type: Type of model ('q_network' for DQN, 'policy' for PPO)
        track_perturbations: Whether to track perturbation history
        random_start: Whether to start from a random point of the epsilon ball
    """

    def __init__(
        self,
        env: gym.Env,
        agent_model: nn.Module,
        epsilon: float = 0.01,
        step_size: float = 0.003,
        num_steps: int = 10,
        norm_type: Literal['inf', 'l2'] = 'inf',
        device: Union[str, torch.device] = 'cpu',
        model_type: Literal['q_network', 'policy'] = 'q_network',
        track_perturbations: bool = False,
        random_start: bool = True
    ):
        super().__init__(
            env=env,
            agent_model=agent_model,
            epsilon=epsilon,
            norm_type=norm_type,
            device=device,
            model_type=model_type,
            track_perturbations=track_perturbations
        )

        self.step_size = step_size
        self.num_steps = num_steps
        self.random_start = random_start

    def _apply_fgsm(self, obs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Override to implement the PGD attack.

        Args:
            obs: Observation, with or without a leading batch axis.

        Returns:
            Tuple of (adversarial_observation, perturbation), where the
            perturbation is the final difference to the clean observation.
        """
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32).to(self.device)

        # Add a batch axis only when missing and remember it, so that a real
        # batch of size 1 is not squeezed away at the end.
        added_batch = obs_tensor.dim() == len(self.observation_space.shape)
        if added_batch:
            obs_tensor = obs_tensor.unsqueeze(0)
        obs_tensor = obs_tensor.detach()

        batch_size = obs_tensor.shape[0]
        # Shape used to broadcast per-sample scalars over the observation axes
        batch_view = (batch_size,) + (1,) * (obs_tensor.dim() - 1)

        # Initialize adversarial example
        adv_obs = obs_tensor.clone()

        # Random start within epsilon ball
        if self.random_start:
            if self.norm_type == 'inf':
                random_noise = torch.empty_like(adv_obs).uniform_(-self.epsilon, self.epsilon)
            else:  # l2
                random_noise = torch.randn_like(adv_obs)
                noise_norm = torch.norm(random_noise.reshape(batch_size, -1), p=2, dim=1).reshape(batch_view)
                # Guard against a zero-norm draw, as done for the gradient below
                random_noise = random_noise / (noise_norm + 1e-10) * self.epsilon

            adv_obs = torch.clamp(adv_obs + random_noise, self.clip_low, self.clip_high)

        # PGD iterations
        for _ in range(self.num_steps):
            adv_obs = adv_obs.detach().requires_grad_(True)

            with torch.enable_grad():
                loss = self._compute_loss(self.agent_model(adv_obs))
                # Input gradient only; the agent's parameter gradients are
                # neither zeroed nor accumulated.
                (grad,) = torch.autograd.grad(loss, adv_obs, allow_unused=True)

            if grad is None:
                # Model output does not depend on the input; nothing to follow
                print("Warning: No gradient computed. Stopping PGD iterations.")
                adv_obs = adv_obs.detach()
                break

            if self.norm_type == 'inf':
                stepped = adv_obs.detach() + self.step_size * grad.sign()

                # Project back to epsilon ball
                perturbation = torch.clamp(stepped - obs_tensor, -self.epsilon, self.epsilon)
            else:  # l2
                grad_norm = torch.norm(grad.reshape(batch_size, -1), p=2, dim=1).reshape(batch_view)
                stepped = adv_obs.detach() + self.step_size * grad / (grad_norm + 1e-10)

                # Project back to epsilon ball: shrink only if outside
                perturbation = stepped - obs_tensor
                pert_norm = torch.norm(perturbation.reshape(batch_size, -1), p=2, dim=1).reshape(batch_view)
                scale = torch.clamp(self.epsilon / (pert_norm + 1e-10), max=1.0)
                perturbation = perturbation * scale

            adv_obs = torch.clamp(obs_tensor + perturbation, self.clip_low, self.clip_high)

        adv_obs = adv_obs.detach()
        final_perturbation = adv_obs - obs_tensor
        if added_batch:
            adv_obs = adv_obs.squeeze(0)
            final_perturbation = final_perturbation.squeeze(0)

        # Convert back to numpy with the observation space's dtype
        adv_obs_np = adv_obs.cpu().numpy().astype(self.observation_space.dtype)
        perturbation_np = final_perturbation.cpu().numpy()

        return adv_obs_np, perturbation_np