import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import pairwise_distances
import warnings

class RobustSVGD:
    """
    Robust SVGD implementation with improved numerical stability and better
    handling of multi-modal distributions in high dimensions.
    
    This class serves as a baseline comparison to ESCORT, implementing standard
    SVGD with enhanced robustness features but without the correlation-aware 
    projections and temporal consistency constraints of ESCORT.
    """
    
    def __init__(self, n_particles=100, state_dim=2, kernel_bandwidth=0.1, 
                step_size=0.01, max_iter=50, tol=1e-5, verbose=False,
                adaptive_bandwidth=True, enhanced_repulsion=True,
                resampling_threshold=0.1):
        """
        Initialize the robust SVGD agent.
        
        Args:
            n_particles: Number of particles for belief representation
            state_dim: Dimensionality of the state space
            kernel_bandwidth: Initial bandwidth for RBF kernel
            step_size: Step size for belief updates
            max_iter: Maximum number of iterations
            tol: Convergence tolerance
            verbose: Whether to display progress information
            adaptive_bandwidth: Whether to adapt kernel bandwidth
            enhanced_repulsion: Whether to enhance repulsive forces
            resampling_threshold: Threshold for particle resampling
        """
        self.n_particles = n_particles
        self.state_dim = state_dim
        self.kernel_bandwidth = kernel_bandwidth
        self.step_size = step_size
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.adaptive_bandwidth = adaptive_bandwidth
        self.enhanced_repulsion = enhanced_repulsion
        self.resampling_threshold = resampling_threshold
        
        # Initialize particles with better coverage
        self.particles = self._initialize_particles()
        
        # Initialize action dimension
        self.action_dim = None
        
        # Policy network settings
        self.policy_initialized = False
        self.hidden_dim = 128
        self.learning_rate = 1e-3
        self.discrete_actions = True
        
        # Store policy network
        self.policy_network = None
        self.optimizer = None
        
        # For tracking training performance
        self.training_loss = []
    
    def _initialize_particles(self):
        """
        Initialize particles with better coverage of the state space.
        
        Returns:
            np.ndarray: Initial particle states
        """
        # Base initialization with appropriate scale
        particles = np.random.normal(0.5, 0.1, (self.n_particles, self.state_dim))
        
        # For low-dimensional cases, ensure better coverage
        if self.state_dim <= 3:
            # Split particles to cover different regions
            n_groups = min(4, 2**self.state_dim)
            particles_per_group = self.n_particles // n_groups
            remainder = self.n_particles - particles_per_group * n_groups
            
            # Create group centers
            group_centers = []
            for i in range(n_groups):
                center = np.zeros(self.state_dim)
                # Create binary pattern for corners
                for j in range(self.state_dim):
                    center[j] = 1.0 if (i & (1 << j)) > 0 else 0.0
                group_centers.append(center)
            
            # Place particles around centers
            for i in range(n_groups):
                start_idx = i * particles_per_group
                end_idx = start_idx + particles_per_group
                
                # Add noise scaled to dimensionality
                noise_scale = 0.1 * (1.0 + 0.1 * self.state_dim)
                particles[start_idx:end_idx] = group_centers[i] + np.random.randn(particles_per_group, self.state_dim) * noise_scale
            
            # Randomly place remaining particles
            if remainder > 0:
                particles[-remainder:] = np.random.rand(remainder, self.state_dim)
        
        return particles
    
    def _compute_rbf_kernel(self, x, bandwidth=None):
        """
        Compute RBF kernel matrix with numerical safeguards.
        
        Args:
            x: Input points of shape (n, d)
            bandwidth: Kernel bandwidth (if None, use adaptive or default)
            
        Returns:
            Kernel matrix and its gradient
        """
        n, d = x.shape
        
        # Use adaptive bandwidth if enabled
        if bandwidth is None:
            if self.adaptive_bandwidth:
                # Compute pairwise distances
                try:
                    dists = pairwise_distances(x, metric='euclidean')
                    
                    # Use median heuristic for bandwidth
                    # Scale with dimensionality to prevent collapse in high dimensions
                    bandwidth = np.median(dists) * np.sqrt(d) / 2
                    
                    # Ensure bandwidth is not too small or too large
                    bandwidth = max(0.05, min(bandwidth, 10.0))
                except Exception as e:
                    warnings.warn(f"Error computing adaptive bandwidth: {e}")
                    # Fallback to default bandwidth scaled with dimensionality
                    bandwidth = self.kernel_bandwidth * np.sqrt(d)
            else:
                # Use default bandwidth scaled with dimensionality
                bandwidth = self.kernel_bandwidth * np.sqrt(d)
        
        # Ensure bandwidth is positive
        bandwidth = max(1e-6, bandwidth)
        
        # Compute pairwise squared distances
        sq_dists = np.zeros((n, n))
        
        for i in range(n):
            for j in range(i, n):
                # Compute squared distance safely
                diff = x[i] - x[j]
                sq_dist = np.sum(diff * diff)
                sq_dists[i, j] = sq_dist
                sq_dists[j, i] = sq_dist
        
        # Compute kernel matrix
        K = np.exp(-sq_dists / (2 * bandwidth))
        
        # Compute kernel gradient
        grad_K = np.zeros((n, n, d))
        
        for i in range(n):
            for j in range(n):
                diff = x[i] - x[j]
                grad_K[i, j] = K[i, j] * (-diff / bandwidth)
        
        # Check for NaN or Inf
        if not np.all(np.isfinite(K)) or not np.all(np.isfinite(grad_K)):
            warnings.warn("NaN or Inf detected in kernel computation. Fixing.")
            K = np.nan_to_num(K, nan=0.0, posinf=0.0, neginf=0.0)
            grad_K = np.nan_to_num(grad_K, nan=0.0, posinf=0.0, neginf=0.0)
        
        return K, grad_K, bandwidth
    
    def _compute_svgd_update(self, particles, score_fn):
        """
        Compute SVGD update with improved robustness.
        
        Args:
            particles: Particle states of shape (n_particles, state_dim)
            score_fn: Score function (gradient of log likelihood)
            
        Returns:
            Update directions for particles
        """
        n_particles, dim = particles.shape
        
        # Compute kernel matrix and gradient
        K, grad_K, bandwidth = self._compute_rbf_kernel(particles)
        
        # Compute score function at particles
        try:
            # Check if score_fn supports returning log probs
            try:
                score_values, log_probs = score_fn(particles, return_logp=True)
            except:
                score_values = score_fn(particles)
            
            # Check for NaN or Inf
            if not np.all(np.isfinite(score_values)):
                warnings.warn("NaN or Inf detected in score computation. Fixing.")
                score_values = np.nan_to_num(score_values, nan=0.0, posinf=0.0, neginf=0.0)
                
                # Clip extreme values
                max_score = 100.0
                score_values = np.clip(score_values, -max_score, max_score)
        except Exception as e:
            warnings.warn(f"Error in score computation: {e}. Using zeros.")
            score_values = np.zeros((n_particles, dim))
        
        # Initialize update
        update = np.zeros((n_particles, dim))
        
        # Compute update for each particle
        for i in range(n_particles):
            # Compute attractive term: K(x_j, x_i) ∇log p(x_j)
            attractive = np.sum(K[i, :, np.newaxis] * score_values, axis=0)
            
            # Compute repulsive term: ∇_x_j K(x_j, x_i)
            repulsive = np.sum(grad_K[:, i, :], axis=0)
            
            # Enhance repulsive forces in high dimensions if enabled
            if self.enhanced_repulsion:
                # Scale with dimensionality to maintain effectiveness
                repulsion_factor = 1.0 + 0.1 * dim
                repulsive = repulsive * repulsion_factor
            
            # Combine attractive and repulsive terms
            update[i] = attractive + repulsive
        
        # Normalize by number of particles
        update /= n_particles
        
        # Check for NaN or Inf
        if not np.all(np.isfinite(update)):
            warnings.warn("NaN or Inf detected in update. Fixing.")
            update = np.nan_to_num(update, nan=0.0, posinf=0.0, neginf=0.0)
            
            # Clip extreme values
            max_update = 10.0
            update = np.clip(update, -max_update, max_update)
        
        return update
    
    def _detect_and_resample_particles(self, particles, score_fn):
        """
        Detect and resample particles in low-density regions.
        
        Args:
            particles: Current particle states
            score_fn: Score function
            
        Returns:
            Updated particles
        """
        n_particles, dim = particles.shape
        
        # Skip if too few particles
        if n_particles < 10:
            return particles
        
        try:
            # Get density information
            try:
                _, log_probs = score_fn(particles, return_logp=True)
                densities = np.exp(log_probs)
            except:
                # Fallback to kernel density estimation
                K, _, _ = self._compute_rbf_kernel(particles)
                densities = np.sum(K, axis=1) / n_particles
            
            # Normalize densities
            max_density = np.max(densities)
            if max_density > 0:
                normalized_densities = densities / max_density
            else:
                normalized_densities = np.ones(n_particles)
            
            # Find particles with low density
            threshold = self.resampling_threshold
            low_density = normalized_densities < threshold
            n_low = np.sum(low_density)
            
            # Resample if enough low-density particles found
            if n_low > 0:
                # Identify high-density particles
                high_density = normalized_densities > np.median(normalized_densities)
                high_indices = np.where(high_density)[0]
                
                # If no high-density particles, use random selection
                if len(high_indices) == 0:
                    high_indices = np.random.choice(n_particles, size=n_particles//2, replace=False)
                
                # Resample low-density particles
                low_indices = np.where(low_density)[0]
                for idx in low_indices:
                    # Sample from high-density particles
                    source_idx = np.random.choice(high_indices)
                    
                    # Copy with noise scaled to dimensionality
                    noise_scale = 0.1 * (1.0 + 0.05 * dim)
                    particles[idx] = particles[source_idx] + np.random.randn(dim) * noise_scale
            
            return particles
        except Exception as e:
            warnings.warn(f"Error in particle resampling: {e}")
            return particles
    
    def _compute_score_function(self, particles, observation, observation_model):
        """
        Compute score function with improved gradient estimation.
        
        Args:
            particles: Current particle states
            observation: Current observation
            observation_model: Function that computes observation likelihoods
            
        Returns:
            Function that computes score values
        """
        def score_fn(particles, return_logp=False):
            """
            Compute gradient of log-likelihood.
            
            Args:
                particles: Particle states
                return_logp: Whether to return log probabilities
                
            Returns:
                Score values (and optionally log probabilities)
            """
            n_particles = particles.shape[0]
            
            # Initialize arrays
            scores = np.zeros((n_particles, self.state_dim))
            log_probs = np.zeros(n_particles) if return_logp else None
            
            # Compute gradients
            for i, particle in enumerate(particles):
                # Clean particle from NaN or Inf
                particle = np.nan_to_num(particle, nan=0.0, posinf=1e10, neginf=-1e10)
                
                try:
                    # Compute likelihood
                    likelihood = observation_model(particle, observation)
                    
                    # Handle small likelihoods
                    likelihood = max(likelihood, 1e-15)
                    
                    if return_logp:
                        log_probs[i] = np.log(likelihood)
                    
                    # Adaptive step size for finite difference
                    scale = np.median(np.abs(particle)) if np.any(particle != 0) else 1.0
                    eps = max(1e-6, 1e-4 * scale)
                    
                    # Use central difference for better accuracy
                    for d in range(self.state_dim):
                        # Create perturbed particles
                        particle_plus = particle.copy()
                        particle_plus[d] += eps
                        
                        particle_minus = particle.copy()
                        particle_minus[d] -= eps
                        
                        # Compute likelihoods
                        likelihood_plus = observation_model(particle_plus, observation)
                        likelihood_minus = observation_model(particle_minus, observation)
                        
                        # Handle small likelihoods
                        likelihood_plus = max(likelihood_plus, 1e-15)
                        likelihood_minus = max(likelihood_minus, 1e-15)
                        
                        # Central difference
                        log_derivative = (np.log(likelihood_plus) - np.log(likelihood_minus)) / (2 * eps)
                        scores[i, d] = log_derivative
                except Exception as e:
                    # Keep score at zero for this particle
                    if return_logp:
                        log_probs[i] = -30  # Very low log probability
            
            # Clip extreme values
            max_score = 100.0
            scores = np.clip(scores, -max_score, max_score)
            
            if return_logp:
                return scores, log_probs
            else:
                return scores
        
        return score_fn
    
    def update(self, action, observation, transition_model, observation_model):
        """
        Update belief using robust SVGD.
        
        Args:
            action: Action taken
            observation: Observation received
            transition_model: Function that simulates state transitions
            observation_model: Function that computes observation likelihoods
        """
        # Apply transition model to predict particle movements
        predicted_particles = np.zeros_like(self.particles)
        
        try:
            # Apply transition for each particle
            for i in range(self.n_particles):
                predicted_particles[i] = transition_model(self.particles[i], action)
            
            # Add small noise to prevent collapse
            noise_scale = 0.01 * (1.0 + 0.1 * self.state_dim)
            predicted_particles += np.random.randn(*predicted_particles.shape) * noise_scale
            
            # Fix any NaN or Inf values
            if np.any(~np.isfinite(predicted_particles)):
                warnings.warn("NaN or Inf detected in predicted particles. Fixing.")
                predicted_particles = np.nan_to_num(
                    predicted_particles, 
                    nan=np.mean(self.particles, axis=0),
                    posinf=1e10, 
                    neginf=-1e10
                )
            
            # Create score function
            score_fn = self._compute_score_function(predicted_particles, observation, observation_model)
            
            # Initialize iteration with predicted particles
            particles = predicted_particles.copy()
            
            # Apply SVGD iterations
            for t in range(self.max_iter):
                # Compute update
                update = self._compute_svgd_update(particles, score_fn)
                
                # Apply update
                new_particles = particles + self.step_size * update
                
                # Periodically check for stuck particles
                if t > 0 and t % 10 == 0:
                    new_particles = self._detect_and_resample_particles(new_particles, score_fn)
                
                # Check convergence
                diff = np.linalg.norm(new_particles - particles) / self.n_particles
                
                # Update particles
                particles = new_particles
                
                # Break if converged
                if diff < self.tol:
                    if self.verbose:
                        print(f"SVGD converged after {t+1} iterations")
                    break
            
            # Update belief particles
            self.particles = particles
            
            # Final check for numerical issues
            if np.any(~np.isfinite(self.particles)):
                warnings.warn("NaN or Inf detected in updated particles. Fixing.")
                # Replace problematic particles with predicted ones
                bad_mask = ~np.all(np.isfinite(self.particles), axis=1)
                self.particles[bad_mask] = predicted_particles[bad_mask]
        
        except Exception as e:
            warnings.warn(f"Error in belief update: {e}")
            # In case of error, use predicted particles
            self.particles = predicted_particles
    
    def _build_policy_network(self):
        """
        Build robust policy network with improved belief encoding.
        
        Returns:
            torch.nn.Module: Policy network
        """
        class RobustParticleEncoder(torch.nn.Module):
            def __init__(self, state_dim, hidden_dim, action_dim, discrete_actions):
                super(RobustParticleEncoder, self).__init__()
                
                # Particle encoder with batch normalization for stability
                self.particle_encoder = torch.nn.Sequential(
                    torch.nn.Linear(state_dim, hidden_dim),
                    torch.nn.BatchNorm1d(hidden_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.BatchNorm1d(hidden_dim),
                    torch.nn.ReLU()
                )
                
                # Belief encoder
                self.belief_encoder = torch.nn.Sequential(
                    torch.nn.Linear(hidden_dim, hidden_dim),
                    torch.nn.LayerNorm(hidden_dim),
                    torch.nn.ReLU()
                )
                
                # Output layers
                if discrete_actions:
                    self.action_head = torch.nn.Linear(hidden_dim, action_dim)
                else:
                    self.action_mean = torch.nn.Linear(hidden_dim, action_dim)
                    self.action_log_std = torch.nn.Linear(hidden_dim, action_dim)
                
                self.discrete_actions = discrete_actions
                self.hidden_dim = hidden_dim
            
            def forward(self, particles):
                """
                Forward pass through the network.
                
                Args:
                    particles: Belief particles of shape [batch_size, n_particles, state_dim]
                
                Returns:
                    Action probabilities or parameters
                """
                batch_size, n_particles, state_dim = particles.shape
                
                # Process each particle
                flat_particles = particles.reshape(-1, state_dim)
                
                # Handle batch norm with flattened batch dimension
                particle_features = self.particle_encoder(flat_particles)
                
                # Reshape back
                particle_features = particle_features.view(batch_size, n_particles, self.hidden_dim)
                
                # Weighted pooling based on particle distances to mean
                # This gives more importance to particles in dense regions
                mean_particle = torch.mean(particle_features, dim=1, keepdim=True)
                distances = torch.sum((particle_features - mean_particle) ** 2, dim=2, keepdim=True)
                weights = torch.softmax(-distances / 10.0, dim=1)
                belief_features = torch.sum(particle_features * weights, dim=1)
                
                # Process belief representation
                belief_encoded = self.belief_encoder(belief_features)
                
                if self.discrete_actions:
                    # For discrete actions
                    action_logits = self.action_head(belief_encoded)
                    return torch.nn.functional.softmax(action_logits, dim=-1)
                else:
                    # For continuous actions
                    action_mean = self.action_mean(belief_encoded)
                    action_log_std = self.action_log_std(belief_encoded)
                    action_log_std = torch.clamp(action_log_std, -20, 2)
                    return action_mean, action_log_std
        
        if self.action_dim is None:
            raise ValueError("Action dimension must be set before building policy network")
            
        return RobustParticleEncoder(
            self.state_dim, 
            self.hidden_dim, 
            self.action_dim, 
            self.discrete_actions
        ).to(self.device if hasattr(self, 'device') else torch.device('cpu'))
    
    def _initialize_policy(self, action_dim):
        """
        Initialize the policy network.
        
        Args:
            action_dim: Dimension of the action space
        """
        self.action_dim = action_dim
        self.policy_network = self._build_policy_network()
        
        # Use AdamW optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            self.policy_network.parameters(),
            lr=self.learning_rate,
            weight_decay=1e-4
        )
        
        self.policy_initialized = True
        
        if self.verbose:
            print(f"Initialized RobustSVGD policy network with action_dim={action_dim}")
    
    def get_belief_estimate(self):
        """
        Get the current belief particles.
        
        Returns:
            np.ndarray: Current belief particles
        """
        return self.particles
    
    def select_action(self, action_space=None, deterministic=False):
        """
        Select an action based on current belief.
        
        Args:
            action_space: The action space (for initialization)
            deterministic: Whether to select deterministically
            
        Returns:
            Selected action
        """
        # Initialize policy if needed
        if not self.policy_initialized:
            if action_space is None:
                raise ValueError("Action space must be provided on first call to select_action")
            
            # Determine action dimension
            if self.discrete_actions:
                if isinstance(action_space, (list, np.ndarray)):
                    self.action_dim = len(action_space)
                elif hasattr(action_space, 'n'):
                    self.action_dim = action_space.n
                else:
                    raise ValueError("Cannot determine action dimension from provided action space")
            else:
                if hasattr(action_space, 'shape'):
                    self.action_dim = action_space.shape[0]
                else:
                    raise ValueError("Cannot determine action dimension from provided action space")
            
            # Initialize policy
            self._initialize_policy(self.action_dim)
            
            # Set device if not already set
            if not hasattr(self, 'device'):
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Get belief particles
        belief_particles = self.particles
        
        # Clean any NaN or Inf
        belief_particles = np.nan_to_num(
            belief_particles, 
            nan=0.0, 
            posinf=1e10, 
            neginf=-1e10
        )
        
        # Convert to tensor
        if len(belief_particles.shape) == 2:
            belief_particles = belief_particles[np.newaxis, :, :]
            
        belief_particles = torch.FloatTensor(belief_particles).to(self.device)
        
        # Pass through policy network
        self.policy_network.eval()
        with torch.no_grad():
            if self.discrete_actions:
                # Discrete actions
                action_probs = self.policy_network(belief_particles)
                
                if deterministic:
                    # Select highest probability
                    action = torch.argmax(action_probs, dim=-1)
                else:
                    # Sample from distribution
                    action_dist = torch.distributions.Categorical(action_probs)
                    action = action_dist.sample()
                
                # Convert to numpy
                action = action.cpu().numpy().squeeze()
            else:
                # Continuous actions
                action_mean, action_log_std = self.policy_network(belief_particles)
                
                if deterministic:
                    # Use mean
                    action = action_mean
                else:
                    # Sample from distribution
                    action_std = torch.exp(action_log_std)
                    action_dist = torch.distributions.Normal(action_mean, action_std)
                    action = action_dist.sample()
                
                # Convert to numpy
                action = action.cpu().numpy().squeeze()
                
                # Clip extreme values
                action = np.clip(action, -10.0, 10.0)
        
        return action
    
    def reset(self, initial_particles=None):
        """
        Reset the belief state.
        
        Args:
            initial_particles: Initial particle states (if None, use random initialization)
        """
        if initial_particles is not None:
            if isinstance(initial_particles, np.ndarray) and initial_particles.shape == (self.n_particles, self.state_dim):
                self.particles = initial_particles.copy()
            else:
                raise ValueError(f"Initial particles must have shape ({self.n_particles}, {self.state_dim})")
        else:
            # Random initialization
            self.particles = self._initialize_particles()


# Visualizing beliefs for debugging
def visualize_belief(particles, true_state=None, title="Belief Distribution"):
    """
    Visualize the belief distribution represented by particles.

    One-dimensional beliefs are drawn as a histogram with a kernel density
    curve, two-dimensional beliefs as a scatter plot with density contours
    (when at least 20 particles are available), and higher-dimensional
    beliefs as scatter plots of the first two (and, if present, third and
    fourth) dimensions. Density overlays are best-effort: if they cannot be
    computed a warning is issued and the rest of the figure is still drawn.

    Args:
        particles: Particle states of shape (n_particles, dim)
        true_state: Optional ground truth state for reference
        title: Title for the plot
    """
    n_particles, dim = particles.shape
    
    if dim == 1:
        # 1D visualization
        plt.figure(figsize=(10, 4))
        
        # Plot histogram
        plt.hist(particles, bins=30, density=True, alpha=0.7)
        
        # Add kernel density estimate. Guarded like the 2D contour overlay:
        # the estimate fails for degenerate particle sets (e.g. a single
        # particle or all particles identical).
        try:
            from scipy.stats import gaussian_kde
            kde = gaussian_kde(particles.flatten())
            x = np.linspace(np.min(particles), np.max(particles), 200)
            plt.plot(x, kde(x), 'r-', linewidth=2)
        except Exception as e:
            warnings.warn(f"Error creating density curve: {e}")
        
        # Add true state if provided
        if true_state is not None:
            plt.axvline(x=true_state[0], color='g', linestyle='--', linewidth=2, label='True State')
            plt.legend()
        
        plt.title(title)
        plt.xlabel("State")
        plt.ylabel("Density")
        plt.grid(True, alpha=0.3)
        
    elif dim == 2:
        # 2D visualization
        plt.figure(figsize=(8, 8))
        
        # Scatter plot of particles
        plt.scatter(particles[:, 0], particles[:, 1], alpha=0.5, s=10)
        
        # Add contour plot if enough particles
        if n_particles >= 20:
            try:
                from scipy.stats import gaussian_kde
                # Bounding box of the particles
                x_min, x_max = np.min(particles[:, 0]), np.max(particles[:, 0])
                y_min, y_max = np.min(particles[:, 1]), np.max(particles[:, 1])
                
                # Pad the box by 10% of its larger side
                padding = 0.1 * max(x_max - x_min, y_max - y_min)
                x_min -= padding
                x_max += padding
                y_min -= padding
                y_max += padding
                
                # 100 x 100 evaluation grid
                xx, yy = np.mgrid[x_min:x_max:100j, y_min:y_max:100j]
                positions = np.vstack([xx.ravel(), yy.ravel()])
                
                # Compute KDE on the grid
                kde = gaussian_kde(particles.T)
                f = np.reshape(kde(positions), xx.shape)
                
                # Plot contours
                plt.contour(xx, yy, f, cmap='Blues', alpha=0.8)
            except Exception as e:
                warnings.warn(f"Error creating contour plot: {e}")
        
        # Add true state if provided
        if true_state is not None:
            plt.plot(true_state[0], true_state[1], 'g*', markersize=15, label='True State')
            plt.legend()
        
        plt.title(title)
        plt.xlabel("X")
        plt.ylabel("Y")
        plt.grid(True, alpha=0.3)
        plt.axis('equal')
        
    else:
        # For higher dimensions, show 2D projections of the leading dimensions
        plt.figure(figsize=(8, 8))
        
        # Scatter plot of particles (first two dimensions)
        plt.scatter(particles[:, 0], particles[:, 1], alpha=0.5, s=10, label='Dimensions 1-2')
        
        # Add true state if provided
        if true_state is not None:
            plt.plot(true_state[0], true_state[1], 'g*', markersize=15, label='True State (Dims 1-2)')
        
        # Also show dimensions 3-4 if available (overlaid on the same axes)
        if dim >= 4:
            plt.scatter(particles[:, 2], particles[:, 3], alpha=0.3, s=10, marker='x', label='Dimensions 3-4')
            if true_state is not None:
                plt.plot(true_state[2], true_state[3], 'r*', markersize=15, label='True State (Dims 3-4)')
        
        plt.title(f"{title} (Projection of First Dimensions)")
        plt.xlabel("Dimension 1/3")
        plt.ylabel("Dimension 2/4")
        plt.grid(True, alpha=0.3)
        plt.legend()
    
    plt.tight_layout()
    plt.show()


# Profile and debug belief updates
def profile_belief_update(algorithm, transition_model, observation_model, 
                         true_state=None, n_updates=5, visualize=True):
    """
    Profile and debug belief updates for a given algorithm.

    The algorithm is reset, a sequence of observations is generated (from a
    simulated true-state trajectory when ``true_state`` is given, otherwise
    at random), and ``algorithm.update`` is timed for each observation.
    After every update the particle statistics are printed and, when a true
    state is available, the squared error of the particle mean against the
    true state *at that step* of the simulated trajectory.

    Args:
        algorithm: Object providing ``reset()``, ``get_belief_estimate()`` and
            ``update(action, observation, transition_model, observation_model)``
        transition_model: Callable ``(state, action) -> next_state``
        observation_model: Callable computing likelihoods; if it also offers a
            ``sample(state)`` method, that is used to generate observations
        true_state: Optional initial ground truth state
        n_updates: Number of updates to perform
        visualize: Whether to visualize beliefs
    """
    # Reset algorithm
    algorithm.reset()
    
    # Get initial particles
    particles = algorithm.get_belief_estimate()
    dim = particles.shape[1]
    
    print(f"Initial particles shape: {particles.shape}")
    
    if visualize:
        visualize_belief(particles, true_state, "Initial Belief")
    
    # Dummy action (depends on action space)
    action = 0
    
    # Generate observations, remembering the true state that produced each
    # one so that later comparisons use the matching time step.
    observations = []
    true_states = []
    if true_state is not None:
        state = true_state
        for _ in range(n_updates):
            # Simulate transition
            state = transition_model(state, action)
            
            try:
                # If observation_model can generate observations
                observation = observation_model.sample(state)
            except Exception:
                # Otherwise use true state with noise
                observation = state + np.random.normal(0, 0.1, size=dim)
            
            observations.append(observation)
            # Snapshot, in case the transition model updates its state in place
            true_states.append(np.array(state))
    else:
        # Random observations, no ground truth
        observations = [np.random.rand(dim) for _ in range(n_updates)]
        true_states = [None] * n_updates
    
    # Perform updates
    update_times = []
    for i in range(n_updates):
        step_true_state = true_states[i]
        print(f"\nUpdate {i+1}/{n_updates}")
        
        # Time the update with a monotonic high-resolution clock
        start_time = time.perf_counter()
        algorithm.update(action, observations[i], transition_model, observation_model)
        update_time = time.perf_counter() - start_time
        
        update_times.append(update_time)
        print(f"Update time: {update_time:.4f} seconds")
        
        # Get updated particles
        particles = algorithm.get_belief_estimate()
        
        # Check for NaN or Inf
        if np.any(~np.isfinite(particles)):
            print("WARNING: NaN or Inf values detected in particles!")
            print(f"Number of particles with issues: {np.sum(~np.all(np.isfinite(particles), axis=1))}")
        
        # Print statistics
        means = np.mean(particles, axis=0)
        stds = np.std(particles, axis=0)
        
        print(f"Particle means: {means}")
        print(f"Particle stds: {stds}")
        
        if step_true_state is not None:
            mse = np.mean((means - step_true_state)**2)
            print(f"MSE to true state: {mse:.4f}")
        
        if visualize:
            visualize_belief(particles, step_true_state, f"Belief After Update {i+1}")
    
    print(f"\nAverage update time: {np.mean(update_times):.4f} seconds")
