import warnings

import numpy as np
import torch
from scipy.spatial.distance import cdist, pdist, squareform
from sklearn.cluster import KMeans

class ImprovedRBFKernel:
    """
    RBF (Gaussian) kernel with input sanitisation and an optional adaptive bandwidth.

    The kernel is ``K(x, y) = exp(-||x - y||^2 / (2 * h))``. ``h`` is the fixed
    ``bandwidth`` unless ``adaptive`` is True, in which case it is re-estimated
    from the median pairwise distance whenever the kernel is evaluated on a single
    point set (``y`` omitted), and that estimate is reused for cross-set calls.

    NOTE(review): the adaptive ``h`` is built from (unsquared) distances but
    divides *squared* distances. Kept as-is because downstream step sizes may be
    tuned to it -- confirm before changing.
    """

    def __init__(self, bandwidth=1.0, adaptive=True, bandwidth_scale=0.5,
                 min_bandwidth=1e-5, max_bandwidth=100.0):
        """
        Initialize the RBF kernel.

        Args:
            bandwidth: Fixed bandwidth ``h``; used when ``adaptive`` is False
            adaptive: Whether to derive ``h`` from the median pairwise distance
            bandwidth_scale: Multiplier applied to the median distance
            min_bandwidth: Lower clip for the adaptive bandwidth
            max_bandwidth: Upper clip for the adaptive bandwidth
        """
        self.bandwidth = bandwidth
        self.adaptive = adaptive
        self.bandwidth_scale = bandwidth_scale
        self.min_bandwidth = min_bandwidth
        self.max_bandwidth = max_bandwidth
        # Most recent adaptive estimate; None until the first evaluation.
        self.adaptive_bandwidth = None

    @staticmethod
    def _as_points(a):
        """Return ``a`` as a float64 array of shape (n, d); 1-D input is read as n scalar points."""
        a = np.asarray(a, dtype=np.float64)
        return a.reshape(-1, 1) if a.ndim == 1 else a

    def _compute_pairwise_distances(self, x, y=None):
        """
        Compute pairwise squared Euclidean distances with numerical safeguards.

        NaN coordinates are replaced by 0 and +/-inf by +/-1e10 before the
        distances are computed; the result is clipped to ``[0, 1e10]``.

        Args:
            x: First set of points of shape (n, d) (1-D input is treated as (n, 1))
            y: Second set of points of shape (m, d), defaults to x

        Returns:
            Pairwise squared distances of shape (n, m), dtype float64
        """
        x = np.nan_to_num(self._as_points(x), nan=0.0, posinf=1e10, neginf=-1e10)

        if y is None:
            # pdist evaluates each unordered pair once; squareform mirrors it.
            dists = squareform(pdist(x, 'sqeuclidean'))
        else:
            y = np.nan_to_num(self._as_points(y), nan=0.0, posinf=1e10, neginf=-1e10)
            dists = cdist(x, y, 'sqeuclidean')

        # Huge coordinate differences may overflow to inf; clipping maps them
        # onto the same 1e10 cap, so no special large-value branch is needed.
        return np.clip(dists, 0.0, 1e10)

    def evaluate(self, x, y=None):
        """
        Evaluate the kernel matrix K(x, y).

        When ``adaptive`` is True and ``y`` is omitted, the bandwidth is
        re-estimated from ``x`` as
        ``median(off-diagonal distances) * bandwidth_scale * sqrt(d) / 2``,
        clipped to ``[min_bandwidth, max_bandwidth]`` (1.0 for a single point).
        When ``y`` is given, the previous estimate is reused (1.0 if none exists).

        Args:
            x: First set of points of shape (n, d)
            y: Second set of points of shape (m, d), defaults to x

        Returns:
            Kernel matrix of shape (n, m); non-finite entries trigger a warning
            and a clamped recomputation with NaN/Inf replaced by 0.
        """
        x = self._as_points(x)
        squared_dists = self._compute_pairwise_distances(x, y)

        if self.adaptive:
            if y is None:
                n = squared_dists.shape[0]
                # Skip the diagonal (always zero). Duplicate points still
                # contribute zero distances to the median.
                off_diag = squared_dists[~np.eye(n, dtype=bool)]
                if off_diag.size > 0:
                    med_dist = np.median(np.sqrt(off_diag))
                    # Widen the bandwidth in higher dimensions by sqrt(d) / 2.
                    estimate = med_dist * self.bandwidth_scale * (np.sqrt(x.shape[1]) / 2)
                    self.adaptive_bandwidth = np.clip(
                        estimate, self.min_bandwidth, self.max_bandwidth
                    )
                else:
                    # A single point has no pairwise distances to estimate from.
                    self.adaptive_bandwidth = 1.0
            elif self.adaptive_bandwidth is None:
                # Cross-set evaluation before any self-evaluation: neutral default.
                self.adaptive_bandwidth = 1.0

        bandwidth = (self.adaptive_bandwidth
                     if self.adaptive and self.adaptive_bandwidth is not None
                     else self.bandwidth)

        K = np.exp(-squared_dists / (2 * bandwidth))

        if not np.all(np.isfinite(K)):
            warnings.warn("Kernel matrix contains NaN or Inf. Using fallback computation.")
            K = np.exp(-np.minimum(squared_dists, 1e6) / (2 * bandwidth))
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)

        return K

    def gradient(self, x, y=None):
        """
        Compute the gradient of K(x_i, y_j) with respect to x_i.

        For ``K = exp(-||x - y||^2 / (2h))`` this is ``-K(x_i, y_j) * (x_i - y_j) / h``.
        With ``y`` omitted the kernel is evaluated on ``x`` alone, so the adaptive
        bandwidth is re-estimated from ``x`` exactly as ``evaluate(x)`` would.

        Args:
            x: First set of points of shape (n, d)
            y: Second set of points of shape (m, d), defaults to x

        Returns:
            Gradient of shape (n, m, d) (``m == n`` when ``y`` is omitted);
            non-finite entries trigger a warning and are replaced by 0.
        """
        x = self._as_points(x)
        y_pts = x if y is None else self._as_points(y)

        # Forward the caller's ``y`` unchanged (None stays None). Passing x as y
        # would take evaluate()'s cross-set branch and reuse a stale bandwidth.
        K = self.evaluate(x, y)

        bandwidth = (self.adaptive_bandwidth
                     if self.adaptive and self.adaptive_bandwidth is not None
                     else self.bandwidth)

        diff = x[:, None, :] - y_pts[None, :, :]  # (n, m, d)
        grad_K = K[:, :, None] * (-diff / bandwidth)

        if not np.all(np.isfinite(grad_K)):
            warnings.warn("Kernel gradient contains NaN or Inf. Using fallback computation.")
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)

        return grad_K


class ImprovedGSWD:
    """
    Sliced Wasserstein-1 distance with weighted, optionally optimised projections.

    Both point clouds are projected onto ``n_projections`` unit directions, the
    exact 1-D Wasserstein-1 distance is computed along each direction, and the
    results are combined with per-direction weights that sum to one. Directions
    come from random sampling, PCA of the pooled samples, or finite-difference
    gradient ascent that maximises the projected distance. Source and target may
    contain different numbers of samples.
    """

    def __init__(self, n_projections=10, projection_method='optimized',
                 optimization_steps=5, correlation_aware=True,
                 learning_rate=0.01, momentum=0.9):
        """
        Initialize the sliced Wasserstein estimator.

        Args:
            n_projections: Number of projection directions
            projection_method: Method for generating projections ('random', 'optimized', 'pca')
            optimization_steps: Number of gradient-ascent steps for optimizing projections
            correlation_aware: Whether to estimate the pooled covariance and
                rescale projection gradients along its principal axes
            learning_rate: Learning rate for projection optimization
            momentum: Momentum factor for projection optimization
        """
        self.n_projections = n_projections
        self.projection_method = projection_method
        self.optimization_steps = optimization_steps
        self.correlation_aware = correlation_aware
        self.learning_rate = learning_rate
        self.momentum = momentum

        # Projection directions (n_projections, dim) and their weights (n_projections,).
        self.projections = None
        self.projection_weights = None

        # Pooled covariance and its eigendecomposition (descending eigenvalues,
        # eigenvectors stored as columns).
        self.covariance = None
        self.eigenvalues = None
        self.eigenvectors = None

        self.fitted = False

    @staticmethod
    def _wasserstein_1d(a, b):
        """
        Exact 1-D Wasserstein-1 distance between two empirical distributions.

        Equal-sized samples use the sorted coupling ``mean(|sort(a) - sort(b)|)``.
        For unequal sizes the absolute difference of the two empirical quantile
        functions is integrated exactly; both are piecewise constant with
        breakpoints at ``i / len(a)`` and ``j / len(b)``.

        Args:
            a: 1-D array of projected source samples
            b: 1-D array of projected target samples

        Returns:
            Non-negative distance

        Raises:
            ValueError: if exactly one of the samples is empty
        """
        a = np.sort(a)
        b = np.sort(b)
        na, nb = a.shape[0], b.shape[0]
        if na == nb:
            return np.mean(np.abs(a - b))
        if na == 0 or nb == 0:
            raise ValueError("Cannot compare an empty sample with a non-empty one.")

        # Right endpoints of intervals on which both quantile functions are constant.
        cuts = np.union1d(np.arange(1, na + 1) / na, np.arange(1, nb + 1) / nb)
        widths = np.diff(cuts, prepend=0.0)
        mids = cuts - 0.5 * widths
        ia = np.minimum((mids * na).astype(np.int64), na - 1)
        ib = np.minimum((mids * nb).astype(np.int64), nb - 1)
        return np.sum(widths * np.abs(a[ia] - b[ib]))

    def _init_projections(self, dim):
        """
        Create initial unit projection directions and uniform weights.

        'pca' starts from the coordinate axes (padded with random unit vectors
        when ``n_projections > dim``); every other method draws Gaussian
        directions and normalises them.

        Args:
            dim: Dimensionality of the space

        Returns:
            (projections, weights) of shapes (n_projections, dim) and (n_projections,)
        """
        if self.projection_method == 'pca':
            projections = np.eye(dim)[:self.n_projections]
            if self.n_projections > dim:
                extra = np.random.randn(self.n_projections - dim, dim)
                extra = extra / np.sqrt(np.sum(extra**2, axis=1, keepdims=True))
                projections = np.vstack([projections, extra])
        else:
            projections = np.random.randn(self.n_projections, dim)
            norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
            projections = projections / (norms + 1e-10)

        weights = np.ones(self.n_projections) / self.n_projections
        return projections, weights

    def _estimate_covariance(self, samples):
        """
        Estimate a regularised, symmetric sample covariance matrix.

        Args:
            samples: Sample points of shape (n, d)

        Returns:
            Covariance matrix of shape (d, d) with 1e-6 added to the diagonal
        """
        samples = np.nan_to_num(samples, nan=0.0, posinf=1e10, neginf=-1e10)
        centered = samples - np.mean(samples, axis=0, keepdims=True)

        # Unbiased estimator; guard the single-sample case against 0 / 0.
        denom = max(samples.shape[0] - 1, 1)
        cov = np.dot(centered.T, centered) / denom

        cov = cov + 1e-6 * np.eye(samples.shape[1])
        # Enforce exact symmetry for eigh.
        return (cov + cov.T) / 2

    def _pca_projections(self, dim):
        """
        Build PCA projections and variance-proportional weights.

        Requires ``self.covariance``, ``self.eigenvalues`` and ``self.eigenvectors``
        (columns sorted by descending eigenvalue). Surplus projections beyond
        ``dim`` are random unit directions (the eigenvectors already span the
        space). Every direction ``u`` is weighted by the variance ``u^T C u``,
        which equals the eigenvalue for an eigenvector.

        Returns:
            (projections, weights) of shapes (n_projections, dim) and (n_projections,)
        """
        n_pc = min(self.n_projections, dim)
        # eigh stores eigenvectors as COLUMNS; take the leading ones as rows.
        projections = self.eigenvectors[:, :n_pc].T.copy()
        weights = self.eigenvalues[:n_pc]

        n_extra = self.n_projections - n_pc
        if n_extra > 0:
            extra = np.random.randn(n_extra, dim)
            extra = extra / (np.sqrt(np.sum(extra**2, axis=1, keepdims=True)) + 1e-10)
            extra_var = np.einsum('ij,jk,ik->i', extra, self.covariance, extra)
            projections = np.vstack([projections, extra])
            weights = np.concatenate([weights, np.maximum(extra_var, 1e-10)])

        return projections, weights / np.sum(weights)

    def _optimize_projections(self, source, target):
        """
        Refine projection directions by gradient ascent on the projected distance.

        Gradients are forward finite differences (step 1e-6, renormalised);
        when correlation-aware eigen data is available they are rescaled by
        ``sqrt(eigenvalue)`` along each principal axis. Momentum and a
        ``1 / (1 + 0.1 * step)`` learning-rate decay are applied, and directions
        are renormalised after every step.

        Args:
            source: Source samples of shape (n, d)
            target: Target samples of shape (m, d)

        Returns:
            (projections, weights). Weights are proportional to each returned
            direction's distance (floored at 1e-3); with zero steps the stored
            weights are returned unchanged.
        """
        dim = source.shape[1]

        if (self.projections is None or self.projections.shape[1] != dim
                or self.projections.shape[0] != self.n_projections):
            self.projections, self.projection_weights = self._init_projections(dim)

        projections = self.projections.copy()
        weights = self.projection_weights.copy()
        velocity = np.zeros_like(projections)
        eps = 1e-6

        for step in range(self.optimization_steps):
            gradients = np.zeros_like(projections)

            for i in range(self.n_projections):
                proj_dir = projections[i]
                distance = self._wasserstein_1d(np.dot(source, proj_dir), np.dot(target, proj_dir))

                for j in range(dim):
                    perturbed = proj_dir.copy()
                    perturbed[j] += eps
                    perturbed = perturbed / np.sqrt(np.sum(perturbed**2))
                    dist_pert = self._wasserstein_1d(np.dot(source, perturbed), np.dot(target, perturbed))
                    gradients[i, j] = (dist_pert - distance) / eps

            learning_rate = self.learning_rate * (1.0 / (1.0 + 0.1 * step))

            if (self.correlation_aware and self.covariance is not None
                    and self.eigenvectors is not None and self.eigenvalues is not None):
                # Rotate into the principal basis, scale by sqrt(eigenvalue), rotate back.
                gradients = ((gradients @ self.eigenvectors) * np.sqrt(self.eigenvalues)) @ self.eigenvectors.T

            # Ascent (maximise distance) with momentum.
            velocity = self.momentum * velocity + learning_rate * gradients
            projections = projections + velocity

            norms = np.sqrt(np.sum(projections**2, axis=1, keepdims=True))
            projections = projections / (norms + 1e-10)

        if self.optimization_steps > 0:
            # Weight the directions actually returned (not the pre-update ones).
            distances = np.array([
                self._wasserstein_1d(np.dot(source, p), np.dot(target, p)) for p in projections
            ])
            weights = np.maximum(distances, 1e-3)
            weights = weights / np.sum(weights)

        return projections, weights

    def fit(self, source, target):
        """
        Fit projection directions and weights to two point sets.

        With ``correlation_aware`` the pooled covariance and its eigendecomposition
        are computed (and, for 'pca', used to initialise the projections).
        Projections are then optimised unless the method is 'pca' with zero
        optimisation steps.

        Args:
            source: Source samples of shape (n, d)
            target: Target samples of shape (m, d)
        """
        source = np.nan_to_num(source, nan=0.0, posinf=1e10, neginf=-1e10)
        target = np.nan_to_num(target, nan=0.0, posinf=1e10, neginf=-1e10)
        dim = source.shape[1]

        if self.correlation_aware:
            self.covariance = self._estimate_covariance(np.vstack([source, target]))
            try:
                eigenvalues, eigenvectors = np.linalg.eigh(self.covariance)
            except np.linalg.LinAlgError:
                warnings.warn("Eigendecomposition failed. Using random projections.")
                # Do not keep eigen data from a previous (possibly different) fit.
                self.eigenvalues = None
                self.eigenvectors = None
                self.projections, self.projection_weights = self._init_projections(dim)
            else:
                eigenvalues = np.maximum(eigenvalues, 1e-10)
                order = np.argsort(eigenvalues)[::-1]
                self.eigenvalues = eigenvalues[order]
                self.eigenvectors = eigenvectors[:, order]
                if self.projection_method == 'pca':
                    self.projections, self.projection_weights = self._pca_projections(dim)

        if self.projection_method != 'pca' or self.optimization_steps > 0:
            self.projections, self.projection_weights = self._optimize_projections(source, target)
        elif (self.projections is None or self.projections.shape[1] != dim
              or self.projections.shape[0] != self.n_projections):
            # PCA without optimisation and without eigen data: axis-aligned start.
            self.projections, self.projection_weights = self._init_projections(dim)

        self.fitted = True

    def compute_distance(self, source, target, return_per_projection=False):
        """
        Compute the weighted sliced Wasserstein-1 distance between two point sets.

        Fits the projections first if needed (not fitted yet, or dimension changed).

        Args:
            source: Source samples of shape (n, d)
            target: Target samples of shape (m, d); m may differ from n
            return_per_projection: Whether to also return per-projection distances

        Returns:
            Weighted distance, or ``(distance, per_projection_distances)``
        """
        source = np.nan_to_num(source, nan=0.0, posinf=1e10, neginf=-1e10)
        target = np.nan_to_num(target, nan=0.0, posinf=1e10, neginf=-1e10)

        if not self.fitted or self.projections is None or self.projections.shape[1] != source.shape[1]:
            self.fit(source, target)

        per_proj_distances = np.array([
            self._wasserstein_1d(np.dot(source, proj_dir), np.dot(target, proj_dir))
            for proj_dir in self.projections
        ], dtype=np.float64)

        total_distance = np.sum(per_proj_distances * self.projection_weights)

        if return_per_projection:
            return total_distance, per_proj_distances
        return total_distance

    def get_regularizer(self, source, target, lambda_reg=0.1):
        """
        Compute a per-target-point pull toward the source along each projection.

        For each projection, target points are ranked by their projection and
        matched to the source point at the same quantile (the monotone 1-D
        optimal coupling; identical to rank-to-rank matching for equal sizes).
        Row ``t`` of the result accumulates
        ``lambda_reg * weight * (source[matched] - target[t])``.

        Args:
            source: Source samples of shape (n, d)
            target: Target samples of shape (m, d)
            lambda_reg: Regularization strength

        Returns:
            Array of shape (m, d); NaN/Inf entries are replaced by 0
        """
        source = np.nan_to_num(source, nan=0.0, posinf=1e10, neginf=-1e10)
        target = np.nan_to_num(target, nan=0.0, posinf=1e10, neginf=-1e10)

        if not self.fitted or self.projections is None or self.projections.shape[1] != source.shape[1]:
            self.fit(source, target)

        n_particles, dim = target.shape
        n_source = source.shape[0]
        reg_term = np.zeros((n_particles, dim))

        # Target rank j -> source rank at the same quantile.
        source_ranks = (np.arange(n_particles) * n_source) // max(n_particles, 1)

        for proj_dir, weight in zip(self.projections, self.projection_weights):
            source_order = np.argsort(np.dot(source, proj_dir))
            target_order = np.argsort(np.dot(target, proj_dir))
            matched = source_order[source_ranks]
            # target_order is a permutation, so each row is updated exactly once.
            reg_term[target_order] += lambda_reg * ((source[matched] - target[target_order]) * weight)

        return np.nan_to_num(reg_term, nan=0.0, posinf=0.0, neginf=0.0)


class ImprovedESCORT:
    """
    Enhanced implementation of ESCORT with improved numerical stability
    and better handling of high-dimensional, multi-modal distributions.
    """
    def __init__(self, n_particles=100, state_dim=2, kernel_bandwidth=0.1,
                 step_size=0.01, lambda_corr=0.1, lambda_temp=0.1,
                 n_projections=10, learning_rate=1e-3, device=None,
                 discrete_actions=True, verbose=False):
        """
        Build an ESCORT agent.

        Creates the adaptive RBF kernel and the GSWD metric, wires both into
        the SVGD updater, draws the initial particle set, and records the
        settings needed to build the policy network later (once the action
        dimension is known).

        Args:
            n_particles: Number of particles for belief representation
            state_dim: Dimensionality of the state space
            kernel_bandwidth: Initial bandwidth passed to the RBF kernel
            step_size: Step size for belief updates
            lambda_corr: Weight for correlation-aware regularization
            lambda_temp: Weight for temporal consistency regularization
            n_projections: Requested number of GSWD projections (raised to state_dim if smaller)
            learning_rate: Learning rate for policy optimization
            device: PyTorch device; defaults to CUDA when available, else CPU
            discrete_actions: Whether actions are discrete or continuous
            verbose: Whether to display progress information
        """
        # Scalar hyper-parameters.
        self.n_particles = n_particles
        self.state_dim = state_dim
        self.lambda_corr = lambda_corr
        self.lambda_temp = lambda_temp
        self.verbose = verbose

        # Fall back to the GPU when one is visible.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Belief-update machinery; kernel and GSWD are shared with the SVGD updater.
        self.kernel = ImprovedRBFKernel(
            bandwidth=kernel_bandwidth,
            adaptive=True,
            bandwidth_scale=0.5 * state_dim ** 0.5,  # grows with sqrt(state_dim)
        )
        self.gswd = ImprovedGSWD(
            n_projections=max(n_projections, state_dim),  # never fewer slices than dimensions
            projection_method='optimized',
            optimization_steps=5,
            correlation_aware=True,
            learning_rate=0.01,
            momentum=0.9,
        )
        self.svgd = ImprovedSVGD(
            kernel=self.kernel,
            gswd=self.gswd,
            step_size=step_size,
            lambda_corr=lambda_corr,
            max_iter=50,
            tol=1e-5,
            verbose=verbose,
        )

        # Initial belief; a copy is kept as the reference for temporal consistency.
        self.particles = self._initialize_particles()
        self.prev_particles = self.particles.copy()
        self.prev_action = None
        self.prev_observation = None

        # Policy network is built lazily once action_dim is known.
        self.action_dim = None
        self.policy_initialized = False
        self.hidden_dim = 128
        self.learning_rate = learning_rate
        self.discrete_actions = discrete_actions
        self.policy_network = None
        self.optimizer = None

        # Loss history.
        self.training_loss = []
    
    def _initialize_particles(self):
        """
        Draw the initial particle set.

        Particles are first sampled from N(0.5, 0.1^2). For state spaces with at
        most 3 dimensions that draw is overwritten: particles are split into up
        to 4 equal groups, group ``k`` centred on the unit-hypercube corner whose
        coordinate ``j`` is bit ``j`` of ``k``, with Gaussian noise of scale
        ``0.1 * (1 + 0.5 * state_dim)``. Any leftover particles (the last ones)
        are drawn uniformly from [0, 1)^state_dim.

        Returns:
            np.ndarray: Initial particle states of shape (n_particles, state_dim)
        """
        dim = self.state_dim
        count = self.n_particles
        particles = np.random.normal(0.5, 0.1, (count, dim))

        if dim > 3:
            return particles

        n_groups = min(4, 2**dim)
        group_size, leftover = divmod(count, n_groups)
        spread = 0.1 * (1.0 + 0.5 * dim)

        for group in range(n_groups):
            corner = np.array([1.0 if group & (1 << axis) else 0.0 for axis in range(dim)])
            lo = group * group_size
            particles[lo:lo + group_size] = corner + np.random.randn(group_size, dim) * spread

        if leftover:
            particles[-leftover:] = np.random.rand(leftover, dim)

        return particles
    
    def _build_policy_network(self):
        """
        Create the particle-set policy network for the current configuration.

        Architecture: a per-particle MLP (Linear/LayerNorm/ReLU, twice), a
        multi-head self-attention layer over the particle set followed by a
        Linear/ReLU projection, mean pooling over particles, a belief MLP, and
        either a softmax action head (discrete) or mean / log-std heads with
        log-std clamped to [-20, 2] (continuous).

        Returns:
            torch.nn.Module: The network, moved to ``self.device``

        Raises:
            ValueError: If ``self.action_dim`` has not been set
        """
        if self.action_dim is None:
            raise ValueError("Action dimension must be set before building policy network")

        class ParticleSelfAttention(torch.nn.Module):
            """Self-attention over particles, followed by a Linear + ReLU projection."""

            def __init__(self, hidden_dim, num_heads=4):
                super().__init__()
                # hidden_dim must be divisible by num_heads.
                self.attention = torch.nn.MultiheadAttention(
                    embed_dim=hidden_dim,
                    num_heads=num_heads,
                    batch_first=True,
                )
                self.output_projection = torch.nn.Sequential(
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                """Map features of shape [batch_size, n_particles, hidden_dim] to the same shape."""
                attended = self.attention(x, x, x)[0]
                return self.output_projection(attended)

        class ImprovedParticleEncoder(torch.nn.Module):
            """Encode a batch of particle sets into action probabilities or Gaussian parameters."""

            def __init__(self, state_dim, hidden_dim, action_dim, discrete_actions):
                super().__init__()
                # Submodules are registered in a fixed order (this fixes
                # parameter initialisation and state_dict layout).
                self.particle_encoder = torch.nn.Sequential(
                    torch.nn.Linear(state_dim, hidden_dim),
                    torch.nn.LayerNorm(hidden_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.LayerNorm(hidden_dim),
                    torch.nn.ReLU(),
                )
                self.attention = ParticleSelfAttention(hidden_dim)
                self.belief_encoder = torch.nn.Sequential(
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.LayerNorm(hidden_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.ReLU(),
                )
                if discrete_actions:
                    self.action_head = torch.nn.Linear(hidden_dim, action_dim)
                else:
                    self.action_mean = torch.nn.Linear(hidden_dim, action_dim)
                    self.action_log_std = torch.nn.Linear(hidden_dim, action_dim)

                self.discrete_actions = discrete_actions
                self.hidden_dim = hidden_dim

            def forward(self, particles):
                """
                Run the network on particles of shape [batch_size, n_particles, state_dim].

                Returns:
                    Softmax probabilities [batch_size, action_dim] (discrete), or
                    a (mean, log_std) pair of that shape (continuous)
                """
                batch, count, width = particles.shape

                # Encode every particle independently, then restore the set axis.
                feats = self.particle_encoder(particles.reshape(-1, width))
                feats = self.attention(feats.view(batch, count, self.hidden_dim))

                # Plain (unweighted) mean pooling over the particle axis.
                encoded = self.belief_encoder(torch.mean(feats, dim=1))

                if not self.discrete_actions:
                    mean = self.action_mean(encoded)
                    log_std = torch.clamp(self.action_log_std(encoded), -20, 2)
                    return mean, log_std

                return torch.nn.functional.softmax(self.action_head(encoded), dim=-1)

        network = ImprovedParticleEncoder(
            self.state_dim,
            self.hidden_dim,
            self.action_dim,
            self.discrete_actions,
        )
        return network.to(self.device)
    
    def _initialize_policy(self, action_dim):
        """
        Create the policy network and its optimizer for a given action size.

        Args:
            action_dim: Dimension of the action space

        Side effects:
            Sets ``self.action_dim`` (read by ``_build_policy_network``),
            ``self.policy_network``, ``self.optimizer`` and marks the policy
            as initialized.
        """
        self.action_dim = action_dim
        network = self._build_policy_network()
        self.policy_network = network

        # AdamW: Adam with decoupled weight decay, used as light regularization.
        optimizer_kwargs = {"lr": self.learning_rate, "weight_decay": 1e-4}
        self.optimizer = torch.optim.AdamW(network.parameters(), **optimizer_kwargs)

        self.policy_initialized = True

        if not self.verbose:
            return
        print(f"Initialized ImprovedESCORT policy network with action_dim={action_dim}")
    
    def _compute_score_function(self, particles, observation, observation_model):
        """
        Build a finite-difference score function for the observation likelihood.

        The returned closure estimates, for each particle ``x``, the gradient of
        ``log observation_model(x, observation)`` with central differences.

        Args:
            particles: Current particle states (not used directly; the closure
                is evaluated on whatever particles it is later called with)
            observation: Current observation
            observation_model: Callable ``(state, observation) -> likelihood``

        Returns:
            Callable ``score_fn(particles, return_logp=False)`` returning an array
            of shape ``(n_particles, self.state_dim)`` and, when ``return_logp``
            is true, also the per-particle log-likelihoods.
        """
        min_likelihood = 1e-15  # floor that keeps log() finite
        max_score = 100.0       # symmetric clip applied to every score entry
        failed_log_prob = -30   # log-probability reported when the model raises

        def safe_log_likelihood(state):
            """Return log-likelihood of ``state`` with NaN / tiny values floored."""
            likelihood = observation_model(state, observation)
            # max(nan, floor) returns NaN, so map NaN onto the floor explicitly.
            likelihood = np.nan_to_num(likelihood, nan=min_likelihood)
            return np.log(np.maximum(likelihood, min_likelihood))

        def score_fn(particles, return_logp=False):
            """
            Score function that computes gradient of log-likelihood.

            Args:
                particles: Particle states of shape (n_particles, state_dim)
                return_logp: Whether to return log probabilities

            Returns:
                Score values (and optionally log probabilities)
            """
            n_particles = particles.shape[0]

            scores = np.zeros((n_particles, self.state_dim))
            log_probs = np.zeros(n_particles) if return_logp else None

            for i, particle in enumerate(particles):
                # Work in float64: with an integer dtype the +/- eps perturbations
                # below would be truncated away and every score would be zero.
                particle = np.nan_to_num(
                    np.asarray(particle, dtype=float), nan=0.0, posinf=1e10, neginf=-1e10
                )

                try:
                    log_likelihood = safe_log_likelihood(particle)
                    if return_logp:
                        log_probs[i] = log_likelihood

                    # Step size adapts to the particle's magnitude.
                    scale = np.median(np.abs(particle)) if np.any(particle != 0) else 1.0
                    eps = max(1e-6, 1e-4 * scale)

                    for d in range(self.state_dim):
                        particle_plus = particle.copy()
                        particle_plus[d] += eps

                        particle_minus = particle.copy()
                        particle_minus[d] -= eps

                        # Central difference (second-order accurate in eps).
                        log_plus = safe_log_likelihood(particle_plus)
                        log_minus = safe_log_likelihood(particle_minus)
                        scores[i, d] = (log_plus - log_minus) / (2 * eps)

                except Exception as e:
                    # Best effort: a failing likelihood must not abort the update.
                    warnings.warn(f"Error computing score: {e}")
                    # Discard any partially filled gradient for this particle.
                    scores[i] = 0.0
                    if return_logp:
                        log_probs[i] = failed_log_prob

            scores = np.clip(scores, -max_score, max_score)

            if return_logp:
                return scores, log_probs
            return scores

        return score_fn
    
    def update(self, action, observation, transition_model, observation_model):
        """
        Update the belief with one predict/refine step of the improved ESCORT framework.

        Each particle is propagated through ``transition_model``, jittered with
        Gaussian noise, and then refined by ``self.svgd.update`` using a score
        function built from ``observation_model``. If anything raises, the
        particles from before this call are restored.

        Args:
            action: Action taken
            observation: Observation received
            transition_model: Callable ``(state, action) -> next_state``
            observation_model: Callable ``(state, observation) -> likelihood``
        """
        # Store previous action/observation/particles for temporal consistency.
        self.prev_action = action
        self.prev_observation = observation
        self.prev_particles = self.particles.copy()

        # The prediction buffer must be floating point. An integer buffer would
        # truncate transition outputs, and the in-place noise addition below would
        # raise a casting error (swallowed by the except, freezing the belief).
        buffer_dtype = self.particles.dtype
        if not np.issubdtype(buffer_dtype, np.inexact):
            buffer_dtype = np.float64
        predicted_particles = np.zeros(self.particles.shape, dtype=buffer_dtype)

        try:
            for i in range(self.n_particles):
                predicted_particles[i] = transition_model(self.particles[i], action)

            # Small jitter to prevent particle collapse; grows with dimensionality.
            noise_scale = 0.01 * (1.0 + 0.1 * self.state_dim)
            predicted_particles += np.random.randn(*predicted_particles.shape) * noise_scale

            if np.any(~np.isfinite(predicted_particles)):
                warnings.warn("NaN or Inf detected in predicted particles. Applying fix.")
                # NaN entries take the per-dimension mean of the previous
                # particles; infinities are clamped to +/-1e10 (and any NaN left
                # over because that mean itself was NaN becomes 0).
                column_means = np.mean(self.particles, axis=0)
                predicted_particles = np.where(
                    np.isnan(predicted_particles), column_means, predicted_particles
                )
                predicted_particles = np.nan_to_num(
                    predicted_particles, nan=0.0, posinf=1e10, neginf=-1e10
                )

            score_fn = self._compute_score_function(predicted_particles, observation, observation_model)

            if self.lambda_temp > 0 and self.prev_particles is not None:
                # Correlation-aware SVGD plus temporal regularization.
                self.particles = self.svgd.update(
                    predicted_particles,
                    score_fn,
                    self.prev_particles,
                    self.lambda_temp
                )
            else:
                self.particles = self.svgd.update(
                    predicted_particles,
                    score_fn
                )

            if np.any(~np.isfinite(self.particles)):
                warnings.warn("NaN or Inf detected in updated particles. Applying fix.")
                # Roll back only the rows that went non-finite.
                bad_mask = ~np.all(np.isfinite(self.particles), axis=1)
                self.particles[bad_mask] = self.prev_particles[bad_mask]

        except Exception as e:
            warnings.warn(f"Error in belief update: {e}")
            # Best effort: keep the belief from before this update.
            self.particles = self.prev_particles.copy()
    
    def get_belief_estimate(self):
        """
        Return the particle set that currently represents the belief.

        Returns:
            np.ndarray: ``self.particles`` itself (not a copy), so mutating
            the result mutates the belief.
        """
        current_particles = self.particles
        return current_particles
    
    def get_mode_statistics(self):
        """
        Summarise the modes of the belief by clustering the particles.

        Particles are clustered with KMeans using ``min(8, max(2, state_dim + 1))``
        clusters. Only clusters that own at least one particle are reported, so
        all returned per-mode entries have the same length and entries at the
        same index describe the same cluster. If clustering raises, a single
        mode fitted to all particles is returned instead.

        Returns:
            dict with keys:
                'num_modes': number of non-empty clusters
                'mode_weights': fraction of particles in each mode
                'mode_centers': array of cluster centers, one row per mode
                'mode_covariances': list of (state_dim, state_dim) matrices
                    (population covariance about the center plus 1e-6 * I,
                    or 0.01 * I for single-particle clusters)
        """
        particles = self.particles
        try:
            # Cluster count grows with dimensionality, bounded to [2, 8].
            n_clusters = min(8, max(2, self.state_dim + 1))

            kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
            labels = kmeans.fit_predict(particles)

            # KMeans may leave some clusters without members (possible when the
            # particles sit on fewer distinct points than n_clusters). Select
            # centers by label -- not by position in unique_labels -- and drop
            # empty clusters so centers, weights and covariances stay aligned.
            unique_labels, counts = np.unique(labels, return_counts=True)
            centers = kmeans.cluster_centers_[unique_labels]
            weights = counts / len(particles)

            covariances = []
            for label, center in zip(unique_labels, centers):
                cluster_particles = particles[labels == label]

                if len(cluster_particles) > 1:
                    centered = cluster_particles - center
                    cov = centered.T @ centered / len(cluster_particles)
                    # Small ridge keeps the matrix positive definite.
                    cov = cov + 1e-6 * np.eye(self.state_dim)
                    # Remove round-off asymmetry.
                    cov = (cov + cov.T) / 2
                else:
                    cov = np.eye(self.state_dim) * 0.01
                covariances.append(cov)

            return {
                'num_modes': len(unique_labels),
                'mode_weights': weights,
                'mode_centers': centers,
                'mode_covariances': covariances
            }

        except Exception as e:
            warnings.warn(f"Error in mode detection: {e}. Using fallback.")

            # Fallback: a single mode described by the overall mean/covariance.
            mean = np.mean(particles, axis=0)
            if len(particles) > 1:
                cov = np.atleast_2d(np.cov(particles, rowvar=False))
                cov = cov + 1e-6 * np.eye(self.state_dim)
            else:
                # np.cov of a single sample is NaN; use the same convention as
                # single-particle clusters above.
                cov = np.eye(self.state_dim) * 0.01

            return {
                'num_modes': 1,
                'mode_weights': np.array([1.0]),
                'mode_centers': np.array([mean]),
                'mode_covariances': [cov]
            }
    
    def select_action(self, action_space=None, deterministic=False):
        """
        Choose an action from the current belief particles via the policy network.

        On the first call the action dimension is inferred from ``action_space``
        (discrete: ``len()`` of a list/ndarray or its ``.n``; continuous:
        ``.shape[0]``) and the policy network is built.

        Args:
            action_space: The action space (required while the policy is uninitialized)
            deterministic: If True, take the argmax (discrete) or the mean
                (continuous) instead of sampling

        Returns:
            numpy array with the chosen action (discrete: action index;
            continuous: action vector clipped to [-10, 10])

        Raises:
            ValueError: If the action space is missing or its size cannot be inferred.
        """
        if not self.policy_initialized:
            if action_space is None:
                raise ValueError("Action space must be provided on first call to select_action")

            if self.discrete_actions and isinstance(action_space, (list, np.ndarray)):
                inferred_dim = len(action_space)
            elif self.discrete_actions and hasattr(action_space, 'n'):
                inferred_dim = action_space.n
            elif not self.discrete_actions and hasattr(action_space, 'shape'):
                inferred_dim = action_space.shape[0]
            else:
                raise ValueError("Cannot determine action dimension from provided action space")

            self.action_dim = inferred_dim
            self._initialize_policy(self.action_dim)

        network_input = self.particles
        if isinstance(network_input, np.ndarray):
            # Sanitize non-finite values before they reach the network.
            network_input = np.nan_to_num(network_input, nan=0.0, posinf=1e10, neginf=-1e10)
            # A bare (n_particles, state_dim) array gets a leading batch axis.
            if network_input.ndim == 2:
                network_input = network_input[None, ...]
            network_input = torch.FloatTensor(network_input).to(self.device)

        self.policy_network.eval()
        with torch.no_grad():
            if self.discrete_actions:
                probabilities = self.policy_network(network_input)
                if deterministic:
                    chosen = torch.argmax(probabilities, dim=-1)
                else:
                    chosen = torch.distributions.Categorical(probabilities).sample()
                return chosen.cpu().numpy().squeeze()

            mean, log_std = self.policy_network(network_input)
            if deterministic:
                chosen = mean
            else:
                chosen = torch.distributions.Normal(mean, torch.exp(log_std)).sample()
            # Bound continuous actions to a fixed safety range.
            return np.clip(chosen.cpu().numpy().squeeze(), -10.0, 10.0)
    
    def reset(self, initial_particles=None):
        """
        Reset the belief state.

        Args:
            initial_particles: Optional array-like of shape
                ``(self.n_particles, self.state_dim)``. It is copied; input
                with a non-floating dtype (e.g. integers) is converted to
                ``float64`` so later in-place updates are not truncated. If
                None, ``self._initialize_particles()`` provides the particles.

        Raises:
            ValueError: If ``initial_particles`` does not have the expected shape.
        """
        if initial_particles is None:
            self.particles = self._initialize_particles()
        else:
            candidate = np.asarray(initial_particles)
            if candidate.shape != (self.n_particles, self.state_dim):
                raise ValueError(f"Initial particles must have shape ({self.n_particles}, {self.state_dim})")
            if np.issubdtype(candidate.dtype, np.inexact):
                self.particles = candidate.copy()
            else:
                # Integer/bool particles would truncate the float updates
                # applied by the belief update.
                self.particles = candidate.astype(np.float64)

        # Clear history so the next update starts from this belief.
        self.prev_particles = self.particles.copy()
        self.prev_action = None
        self.prev_observation = None


class ImprovedSVGD:
    """
    Improved SVGD implementation with better multi-modal handling,
    correlation awareness, and numerical stability.

    Collaborators, as used below:
        kernel: object with ``evaluate(particles)`` returning an (n, n) matrix
            and ``gradient(particles)`` returning an (n, n, d) array. The
            update assumes ``gradient[i, j]`` is the gradient of ``k(x_i, x_j)``
            with respect to ``x_i`` -- the layout the built-in fallback kernel
            produces (TODO confirm against the kernel implementation).
        gswd: object with ``get_regularizer(prev_particles, particles, lambda_temp)``
            returning an array shaped like ``particles``.

    Failures in either collaborator are caught: the kernel falls back to a
    fixed-bandwidth RBF kernel and the temporal term to a nearest-neighbour pull.
    """
    def __init__(self, kernel, gswd, step_size=0.01, lambda_corr=0.1, lambda_temp=0.1,
                max_iter=50, tol=1e-5, verbose=False):
        """
        Initialize the improved SVGD.

        Args:
            kernel: Kernel object providing ``evaluate`` and ``gradient``
            gswd: GSWD instance providing ``get_regularizer`` for the temporal term
            step_size: Step size for updates
            lambda_corr: Weight for correlation-aware regularization (stored;
                not read by the methods of this class)
            lambda_temp: Default weight for temporal consistency regularization
            max_iter: Maximum number of iterations
            tol: Convergence tolerance on ``||new - old||_F / n_particles``
            verbose: Whether to print a message on convergence
        """
        self.kernel = kernel
        self.gswd = gswd
        self.step_size = step_size
        self.lambda_corr = lambda_corr
        self.lambda_temp = lambda_temp
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    @staticmethod
    def _fallback_kernel(particles, bandwidth):
        """
        Fixed-bandwidth RBF kernel used when the configured kernel fails.

        Returns:
            (K, grad_K) with ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * bandwidth))``
            and ``grad_K[i, j] = -K[i, j] * (x_i - x_j) / bandwidth``.
        """
        diffs = particles[:, np.newaxis, :] - particles[np.newaxis, :, :]
        sq_dists = np.sum(diffs * diffs, axis=-1)
        K = np.exp(-sq_dists / (2 * bandwidth))
        grad_K = K[:, :, np.newaxis] * (-diffs / bandwidth)
        return K, grad_K

    def _compute_svgd_update(self, particles, score_fn):
        """
        Compute the SVGD update direction with numerical safeguards.

        For particle ``i`` the direction is
        ``(1/n) * sum_j (K[i, j] * score(x_j) + r * grad_K[j, i])``,
        where the repulsive term is amplified by ``r = 1 + 0.1 * dim``.

        Args:
            particles: Current particle states of shape (n_particles, dim)
            score_fn: Callable ``score_fn(particles)`` returning the gradient of
                the log-likelihood with shape (n_particles, dim)

        Returns:
            Update directions of shape (n_particles, dim)
        """
        n_particles, dim = particles.shape

        try:
            K = self.kernel.evaluate(particles)
            grad_K = self.kernel.gradient(particles)
            kernel_ok = bool(np.all(np.isfinite(K)) and np.all(np.isfinite(grad_K)))
            if not kernel_ok:
                warnings.warn("NaN or Inf detected in kernel computation. Using fallback.")
        except Exception as e:
            warnings.warn(f"Error in kernel computation: {e}. Using fallback.")
            kernel_ok = False
        if not kernel_ok:
            K, grad_K = self._fallback_kernel(particles, 0.1 * dim)

        try:
            # Only the scores are needed here. Requesting return_logp=True and
            # unpacking the result silently splits a plain 2-row score array
            # into two rows when score_fn accepts but ignores that flag.
            score_values = np.asarray(score_fn(particles), dtype=float)
            if score_values.shape != (n_particles, dim):
                raise ValueError(
                    f"score_fn returned shape {score_values.shape}, expected {(n_particles, dim)}"
                )

            if not np.all(np.isfinite(score_values)):
                warnings.warn("NaN or Inf detected in score computation. Fixing.")
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
                score_values = np.clip(score_values, -100.0, 100.0)
        except Exception as e:
            warnings.warn(f"Error in score computation: {e}. Using zeros.")
            score_values = np.zeros((n_particles, dim))

        # Attractive term: row i is sum_j K[i, j] * score(x_j).
        attractive = K @ score_values
        # Repulsive term: row i is sum_j grad_K[j, i], i.e. sum_j grad_{x_j} k(x_j, x_i).
        repulsive = np.sum(grad_K, axis=0)
        # Stronger repulsion in high dimensions counters particle collapse.
        repulsion_factor = 1.0 + 0.1 * dim

        update = (attractive + repulsion_factor * repulsive) / n_particles

        if not np.all(np.isfinite(update)):
            warnings.warn("NaN or Inf detected in update. Fixing.")
            update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
            # Clipping is applied only when non-finite values were found.
            update = np.clip(update, -10.0, 10.0)

        return update

    @staticmethod
    def _nearest_neighbour_pull(particles, prev_particles, lambda_temp):
        """Pull each particle toward its nearest previous particle, scaled by lambda_temp."""
        prev = np.asarray(prev_particles)
        diffs = prev[np.newaxis, :, :] - particles[:, np.newaxis, :]
        closest = np.argmin(np.sum(diffs ** 2, axis=-1), axis=1)
        return lambda_temp * (prev[closest] - particles)

    def _compute_temporal_consistency(self, particles, prev_particles, lambda_temp):
        """
        Compute temporal consistency regularization.

        Args:
            particles: Current particle states
            prev_particles: Previous particle states
            lambda_temp: Weight for temporal consistency

        Returns:
            Regularization term shaped like ``particles`` (zeros when disabled)
        """
        if prev_particles is None or lambda_temp <= 0:
            return np.zeros_like(particles)

        try:
            reg_term = self.gswd.get_regularizer(prev_particles, particles, lambda_temp)

            if not np.all(np.isfinite(reg_term)):
                warnings.warn("NaN or Inf detected in temporal consistency. Fixing.")
                reg_term = np.nan_to_num(reg_term, nan=0.0, posinf=0.0, neginf=0.0)
                reg_term = np.clip(reg_term, -10.0, 10.0)

            return reg_term
        except Exception as e:
            warnings.warn(f"Error in temporal consistency: {e}. Using fallback.")
            return self._nearest_neighbour_pull(particles, prev_particles, lambda_temp)

    def _relative_densities(self, particles, score_fn):
        """
        Return per-particle densities scaled so the largest equals 1.

        Uses ``score_fn(particles, return_logp=True)`` when it works, otherwise
        a kernel density estimate. Returns all ones when no usable density
        information exists (which disables resampling).
        """
        n_particles = particles.shape[0]

        try:
            _, log_probs = score_fn(particles, return_logp=True)
            log_probs = np.asarray(log_probs, dtype=float).reshape(n_particles)
        except Exception:
            # score_fn without log-probability support: use the kernel instead.
            log_probs = None

        if log_probs is not None:
            finite = np.isfinite(log_probs)
            if not np.any(finite):
                return np.ones(n_particles)
            # Normalise in log space: exp() of realistic log-likelihoods (below
            # about -745) underflows to 0, which would make every particle look
            # equally dense and silently disable resampling.
            shifted = np.where(finite, log_probs - np.max(log_probs[finite]), -np.inf)
            return np.exp(shifted)

        K = self.kernel.evaluate(particles)
        densities = np.sum(K, axis=1) / n_particles
        max_density = np.max(densities)
        if max_density > 0:
            return densities / max_density
        return np.ones(n_particles)

    def _detect_and_resample_particles(self, particles, score_fn):
        """
        Detect and resample particles in low-density regions.

        Particles whose relative density is below 0.1 are replaced, in place,
        by noisy copies of particles whose relative density is above the median.

        Args:
            particles: Current particle states of shape (n_particles, dim);
                modified in place
            score_fn: Score function, optionally supporting ``return_logp``

        Returns:
            The (possibly modified) particle array
        """
        n_particles, dim = particles.shape

        # Too few particles for meaningful density statistics.
        if n_particles < 10:
            return particles

        try:
            normalized_densities = self._relative_densities(particles, score_fn)

            low_density = normalized_densities < 0.1
            if np.any(low_density):
                high_indices = np.flatnonzero(
                    normalized_densities > np.median(normalized_densities)
                )
                if high_indices.size == 0:
                    high_indices = np.random.choice(n_particles, size=n_particles // 2, replace=False)

                noise_scale = 0.1 * (1.0 + 0.05 * dim)  # grows with dimensionality
                for idx in np.flatnonzero(low_density):
                    source_idx = np.random.choice(high_indices)
                    particles[idx] = particles[source_idx] + np.random.randn(dim) * noise_scale

            return particles
        except Exception as e:
            warnings.warn(f"Error in particle resampling: {e}")
            return particles

    def update(self, particles, score_fn, prev_particles=None, lambda_temp=None):
        """
        Update particles using improved SVGD.

        Args:
            particles: Initial particle states of shape (n_particles, dim); not modified
            score_fn: Score function (gradient of log-likelihood)
            prev_particles: Previous particle states for temporal consistency
            lambda_temp: Weight for temporal consistency (overrides self.lambda_temp)

        Returns:
            Updated particles
        """
        particles = np.array(particles, copy=True)

        if lambda_temp is None:
            lambda_temp = self.lambda_temp
        use_temporal = prev_particles is not None and lambda_temp > 0

        for t in range(self.max_iter):
            direction = self._compute_svgd_update(particles, score_fn)
            if use_temporal:
                direction = direction + self._compute_temporal_consistency(
                    particles, prev_particles, lambda_temp
                )

            new_particles = particles + self.step_size * direction

            # Every 10th iteration (after the first), rescue stranded particles.
            if t > 0 and t % 10 == 0:
                new_particles = self._detect_and_resample_particles(new_particles, score_fn)

            diff = np.linalg.norm(new_particles - particles) / particles.shape[0]
            particles = new_particles

            if diff < self.tol:
                if self.verbose:
                    print(f"SVGD converged after {t+1} iterations")
                break

        return particles
