import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import os
import time
import traceback

# Import the Light-Dark environment
from lightdark10d.lightdark10d_env import LightDark10DEnv

class LightDark10DRunner:
    """
    Runner for the 10D Light-Dark POMDP problem that evaluates different
    belief approximation methods on various metrics.
    
    This runner handles:
    1. Environment creation and episode progression
    2. Belief update for each method
    3. Evaluation using multiple metrics
    4. Visualization of beliefs and performances
    """
    
    def __init__(self, env_params=None, save_dir="results_ld"):
        """
        Build the environment and prepare the output directories.

        Args:
            env_params: Optional dict of keyword overrides for LightDark10DEnv;
                merged on top of the defaults (map_size=10, noise_level=0.5)
            save_dir: Directory that receives results; a "plots" subdirectory
                is created inside it
        """
        # Start from the defaults and layer any caller overrides on top
        merged_params = {"map_size": 10, "noise_level": 0.5}
        if env_params is not None:
            merged_params.update(env_params)
        self.env_params = merged_params

        # The environment is constructed from the merged parameter set
        self.env = LightDark10DEnv(**self.env_params)

        # Output locations (existing directories are reused)
        self.save_dir = save_dir
        for target in (save_dir, os.path.join(save_dir, "plots")):
            os.makedirs(target, exist_ok=True)

        # The state vector has a fixed size of 10 in this runner
        self.state_dim = 10

        # Ground-truth state used for evaluation; filled in by run_experiment
        self.true_state = None
    
    def transition_model(self, state, action):
        """
        Transition model for particle belief updates.
        
        Args:
            state: Current state particle
            action: Action taken
            
        Returns:
            Updated state particle
        """
        # Extract current state components
        position = state[:5]
        velocity = state[5:]
        
        # Initialize state delta
        state_delta = np.zeros(10)
        
        # Apply force based on action
        force = np.zeros(5)
        if action < 10:
            dim = action // 2
            direction = 1 if action % 2 == 0 else -1
            force[dim] = direction * 0.1
        
        # Update velocity based on force with some damping
        damping = 0.1
        velocity_delta = force - damping * velocity
        state_delta[5:] = velocity_delta
        
        # Update position based on velocity
        dt = 0.1
        position_delta = velocity * dt
        state_delta[:5] = position_delta
        
        # Apply correlation structure similar to environment
        scales = np.abs(state_delta) + 0.01
        corr_matrix = self._get_correlation_matrix()
        cov_matrix = np.outer(scales, scales) * corr_matrix
        
        # Generate correlated noise
        correlated_noise = np.random.multivariate_normal(
            mean=np.zeros(self.state_dim), cov=cov_matrix)
        
        # Scale the noise and add to state
        direction = np.sign(state_delta)
        magnitude = np.abs(state_delta)
        correlated_delta = direction * (magnitude + 0.1 * correlated_noise)
        
        # Apply correlated update
        new_state = state + correlated_delta
        
        # Ensure the agent stays within the map boundaries
        new_state[:5] = np.clip(new_state[:5], 0, self.env_params["map_size"])
        
        return new_state
    
    def observation_model(self, state, observation):
        """
        Observation likelihood model for particle belief updates.
        
        Args:
            state: Current state particle
            observation: Observation received
            
        Returns:
            Likelihood of the observation given the state
        """
        # Extract position from state
        position = state[:5]
        
        # Get light level at this position
        light_level = self._get_light_level(position)
        
        # Calculate observation noise level based on light
        noise_scale = self.env_params["noise_level"] * (1.0 - light_level) + 0.01
        
        # Expected observation is just the position in this environment
        expected_obs = position
        
        # Compute likelihood using Gaussian noise model
        if observation is None:
            # If no observation provided, just generate one
            noise = np.random.normal(0, noise_scale, 5)
            return expected_obs + noise
        else:
            # Compute probability of observation given the state
            likelihood = 1.0
            
            for i in range(5):
                # Calculate likelihood for each dimension
                diff = observation[i] - expected_obs[i]
                dim_likelihood = np.exp(-0.5 * (diff / noise_scale)**2) / (noise_scale * np.sqrt(2 * np.pi))
                likelihood *= dim_likelihood
            
            return likelihood
    
    def _get_light_level(self, position):
        """
        Get light level at a given position.
        This should match the environment's light level calculation.
        
        Args:
            position: 5D position vector
            
        Returns:
            Light level (0-1, higher is brighter)
        """
        # Default dark level
        light_level = 0.05
        
        # Check each light region
        for center, radius, intensity in self.env.light_regions:
            # Calculate distance to light region center (only in relevant dimensions)
            relevant_dims = center != 0
            if np.sum(relevant_dims) > 0:
                # Only calculate distance in dimensions where center != 0
                pos_subset = position[relevant_dims]
                center_subset = center[relevant_dims]
                
                # Scaled Euclidean distance
                distance = np.linalg.norm(pos_subset - center_subset) / np.sqrt(np.sum(relevant_dims))
            else:
                # If all center coordinates are 0, use full Euclidean distance
                distance = np.linalg.norm(position - center)
            
            # If within light region
            if distance < radius:
                # Calculate light contribution based on distance from center
                contribution = intensity * (1.0 - (distance / radius) ** 2)
                
                # Take maximum light level from all contributing regions
                light_level = max(light_level, contribution)
        
        return light_level
    
    def _get_correlation_matrix(self):
        """
        Get correlation matrix for the state variables.
        
        Returns:
            10x10 correlation matrix
        """
        # This should match the environment's correlation matrix
        # Initialize with identity matrix (no correlation)
        corr_matrix = np.eye(10)
        
        # Correlation between position and velocity in same dimension
        for i in range(5):
            corr_matrix[i, i+5] = corr_matrix[i+5, i] = 0.8
        
        # Correlation between adjacent positional dimensions
        for i in range(4):
            corr_matrix[i, i+1] = corr_matrix[i+1, i] = 0.5
        
        # Correlation between adjacent velocity dimensions
        for i in range(5, 9):
            corr_matrix[i, i+1] = corr_matrix[i+1, i] = 0.6
        
        # Non-obvious correlation between dimensions 1-3 and 2-4
        corr_matrix[0, 2] = corr_matrix[2, 0] = 0.4
        corr_matrix[1, 3] = corr_matrix[3, 1] = 0.4
        
        # Velocity coupling: v1 affects v2 and v3
        corr_matrix[5, 6] = corr_matrix[6, 5] = 0.7
        corr_matrix[5, 7] = corr_matrix[7, 5] = 0.5
        
        return corr_matrix
    
    def run_experiment(self, methods, n_episodes=10, max_steps=100, n_particles=100):
        """
        Run experiment to compare different belief approximation methods.

        For every episode the environment is reset, each method receives a
        fresh random particle set, and the episode is stepped with the
        heuristic policy. After the episode, each method's final particle set
        is scored and one row per method is appended to the results.

        Args:
            methods: Dictionary mapping method names to implementation objects.
                An object is driven through `update` + `get_belief_estimate`
                if it has `update`, else through `fit_transform`, else it is
                called directly
            n_episodes: Number of episodes to run
            max_steps: Maximum steps per episode
            n_particles: Number of particles for belief representation

        Returns:
            DataFrame with evaluation results (also written to
            "belief_results.csv" in save_dir). The "Runtime" column holds the
            mean wall-clock seconds spent per belief update in that episode.
        """
        columns = [
            "Method",
            "Episode",
            "Steps",
            "Final Position Error",
            "Mean Position Error",
            "Max Position Error",
            "MMD",
            "Sliced Wasserstein",
            "Correlation Error",
            "Mode Coverage",
            "ESS",
            "Runtime",
            "Success",
            "Final Distance",
        ]
        results = {column: [] for column in columns}

        # Per-step history kept for offline analysis
        all_particles = {method_name: [] for method_name in methods}
        ground_truth_states = []

        for episode in range(n_episodes):
            print(f"Episode {episode+1}/{n_episodes}")

            # Reset environment; the initial observation is not used
            self.env.reset()
            self.true_state = self.env.state.copy()

            # Fresh random particle set for each method
            particles = {}
            for method_name in methods:
                init_particles = np.zeros((n_particles, self.state_dim))

                # Position (uniformly distributed in map)
                init_particles[:, :5] = np.random.uniform(
                    0, self.env_params["map_size"], (n_particles, 5))

                # Velocity (normal distribution around zero)
                init_particles[:, 5:] = np.random.normal(0, 0.1, (n_particles, 5))

                particles[method_name] = init_particles

            # Accumulated belief-update time per method for this episode
            update_time = {method_name: 0.0 for method_name in methods}

            # Tracked explicitly so it is defined even when max_steps == 0
            steps_taken = 0

            for step in range(max_steps):
                # Ground truth is recorded BEFORE the step, so entry t holds
                # the state the action at step t was chosen from
                ground_truth_states.append(self.true_state.copy())

                # NOTE: the heuristic policy reads the TRUE state (oracle),
                # not any method's belief, so every method sees the same
                # trajectory
                action = self._select_action(self.true_state)

                next_obs, _reward, done, _info = self.env.step(action)
                self.true_state = self.env.state.copy()
                steps_taken = step + 1

                for method_name, method in methods.items():
                    start_time = time.perf_counter()

                    try:
                        # Different methods have different interfaces
                        if hasattr(method, 'update'):
                            # Standard update interface
                            method.update(action, next_obs, self.transition_model, self.observation_model)
                            updated_particles = method.get_belief_estimate()
                        elif hasattr(method, 'fit_transform'):
                            # Some methods use fit_transform pattern
                            updated_particles = method.fit_transform(
                                particles[method_name],
                                lambda x: self.observation_model(x, next_obs),
                                None
                            )
                        else:
                            # Fallback: assume callable method
                            updated_particles = method(
                                particles[method_name],
                                lambda x: self.observation_model(x, next_obs)
                            )

                        particles[method_name] = updated_particles
                    except Exception as e:
                        # Best effort: report and keep the previous particles
                        print(f"Error updating {method_name}: {e}")
                        traceback.print_exc()

                    # Failed updates still count towards the runtime
                    update_time[method_name] += time.perf_counter() - start_time

                    # Store belief for later analysis
                    all_particles[method_name].append(particles[method_name].copy())

                # Visualize beliefs periodically
                if step % 10 == 0 or step == max_steps - 1:
                    self._visualize_beliefs(particles, episode, step)

                if done:
                    break

            # Evaluate the final belief of each method
            for method_name, method_particles in particles.items():
                position_error = self._compute_position_error(method_particles, self.true_state)
                belief_metrics = self._evaluate_belief_quality(method_particles, self.true_state)

                # Success means the true position ended within 0.5 of the goal
                distance_to_goal = np.linalg.norm(self.true_state[:5] - self.env.goal)
                success = distance_to_goal < 0.5

                # Mean seconds per belief update; 0.0 if no step was taken
                mean_runtime = (
                    update_time[method_name] / steps_taken if steps_taken else 0.0
                )

                row = {
                    "Method": method_name,
                    "Episode": episode,
                    "Steps": steps_taken,
                    "Final Position Error": position_error["final"],
                    "Mean Position Error": position_error["mean"],
                    "Max Position Error": position_error["max"],
                    "MMD": belief_metrics["mmd"],
                    "Sliced Wasserstein": belief_metrics["sliced_wasserstein"],
                    "Correlation Error": belief_metrics["correlation_error"],
                    "Mode Coverage": belief_metrics["mode_coverage"],
                    "ESS": belief_metrics["ess"],
                    "Runtime": mean_runtime,
                    "Success": success,
                    "Final Distance": distance_to_goal,
                }
                for column, value in row.items():
                    results[column].append(value)

                print(f"{method_name}: Final Pos Error = {position_error['final']:.2f}, "
                      f"MMD = {belief_metrics['mmd']:.4f}, "
                      f"Mode Coverage = {belief_metrics['mode_coverage']:.2f}, "
                      f"Success = {success}")

        # Convert results to DataFrame
        results_df = pd.DataFrame(results)

        # Save results
        results_df.to_csv(os.path.join(self.save_dir, "belief_results.csv"), index=False)

        # Generate summary visualizations
        self._visualize_results_summary(results_df)

        # Save raw trajectories and particles for further analysis
        np.save(os.path.join(self.save_dir, "ground_truth_states.npy"), np.array(ground_truth_states))
        for method_name, method_particles in all_particles.items():
            np.save(os.path.join(self.save_dir, f"{method_name}_particles.npy"),
                   np.array(method_particles))

        return results_df
    
    def _select_action(self, state):
        """
        Simple policy to select action based on current state.
        This is just a heuristic - in practice, you'd train a policy.
        
        Args:
            state: Current state
            
        Returns:
            Action index
        """
        # Extract position and velocity
        position = state[:5]
        velocity = state[5:]
        
        # Calculate direction to goal
        goal_direction = self.env.goal - position
        
        # Find dimension with largest difference to goal
        max_dim = np.argmax(np.abs(goal_direction))
        
        # Determine if we need to increase or decrease in that dimension
        if goal_direction[max_dim] > 0:
            # Need to increase - action is 2*dim
            return 2 * max_dim
        else:
            # Need to decrease - action is 2*dim + 1
            return 2 * max_dim + 1
    
    def _compute_position_error(self, particles, true_state):
        """
        Compute position estimation error from belief particles.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            Dictionary with error metrics
        """
        # Extract position components (first 5 dimensions)
        true_pos = true_state[:5]
        particle_positions = particles[:, :5]
        
        # Compute weighted mean position from particles
        mean_pos = np.mean(particle_positions, axis=0)
        
        # Compute error
        final_error = np.linalg.norm(true_pos - mean_pos)
        
        # Compute errors for all particles
        all_errors = np.linalg.norm(particle_positions - true_pos, axis=1)
        mean_error = np.mean(all_errors)
        max_error = np.max(all_errors)
        
        return {
            "final": final_error,
            "mean": mean_error,
            "max": max_error
        }
    
    def _evaluate_belief_quality(self, particles, true_state):
        """
        Score a particle belief against a Gaussian proxy around the true state.

        Args:
            particles: Belief particles, one row per particle
            true_state: True 10D state used as the proxy mean

        Returns:
            Dictionary with keys "mmd", "sliced_wasserstein",
            "correlation_error", "mode_coverage", "ess" and "runtime"
            ("runtime" is a 0.0 placeholder; nothing is timed here)
        """
        # Reference cloud: one proxy sample per particle, drawn around the
        # true state with the model's correlation structure scaled by 0.1
        proxy_cov = 0.1 * self._get_correlation_matrix()
        reference = np.random.multivariate_normal(
            mean=true_state,
            cov=proxy_cov,
            size=len(particles),
        )

        # Proxy positions must stay inside the map like real states do
        reference[:, :5] = np.clip(reference[:, :5], 0, self.env_params["map_size"])

        # Values are evaluated top to bottom, which fixes the order in which
        # the random number generator is consumed
        return {
            "mmd": self._compute_mmd(particles, reference),
            "sliced_wasserstein": self._compute_sliced_wasserstein(particles, reference),
            "correlation_error": self._compute_correlation_error(particles, true_state),
            "mode_coverage": self._compute_mode_coverage(particles, true_state),
            "ess": self._compute_ess(particles, true_state),
            "runtime": 0.0,
        }
    
    def _compute_mmd(self, particles, reference_samples, bandwidth=None):
        """
        Compute Maximum Mean Discrepancy between particles and reference samples.
        
        Args:
            particles: Belief particles
            reference_samples: Reference distribution samples
            bandwidth: Kernel bandwidth (if None, median heuristic is used)
            
        Returns:
            MMD value
        """
        # Use median heuristic if bandwidth not provided
        if bandwidth is None:
            # Compute pairwise distances for a subset of particles
            n_subset = min(100, len(particles))
            subset_p = particles[:n_subset]
            
            dists = []
            for i in range(min(20, len(subset_p))):
                xi = subset_p[i]
                diff = subset_p - xi
                dists.extend(np.sum(diff**2, axis=1).tolist())
                
            bandwidth = np.median(dists) if dists else 1.0
        
        # RBF kernel function
        def kernel(x, y):
            # Add this safety check to ensure bandwidth is never zero
            safe_bandwidth = max(bandwidth, 1e-10)
            # Compute squared distance
            squared_dist = np.sum((x - y)**2)
            # Apply exponential with safety measures
            return np.exp(-squared_dist / safe_bandwidth)
        
        # Compute MMD terms
        n_p = min(500, len(particles))  # Limit computation for efficiency
        n_r = min(500, len(reference_samples))
        
        # Sample if needed
        if len(particles) > n_p:
            p_indices = np.random.choice(len(particles), n_p, replace=False)
            particles_sub = particles[p_indices]
        else:
            particles_sub = particles
            
        if len(reference_samples) > n_r:
            r_indices = np.random.choice(len(reference_samples), n_r, replace=False)
            reference_sub = reference_samples[r_indices]
        else:
            reference_sub = reference_samples
        
        # Compute terms
        pp_sum = 0
        for i in range(n_p):
            for j in range(i+1, n_p):
                pp_sum += kernel(particles_sub[i], particles_sub[j])
        pp_sum = 2 * pp_sum / (n_p * (n_p - 1)) if n_p > 1 else 0
        
        rr_sum = 0
        for i in range(n_r):
            for j in range(i+1, n_r):
                rr_sum += kernel(reference_sub[i], reference_sub[j])
        rr_sum = 2 * rr_sum / (n_r * (n_r - 1)) if n_r > 1 else 0
        
        pr_sum = 0
        for i in range(n_p):
            for j in range(n_r):
                pr_sum += kernel(particles_sub[i], reference_sub[j])
        pr_sum = pr_sum / (n_p * n_r)
        
        mmd = pp_sum + rr_sum - 2 * pr_sum
        return max(0, mmd)  # Ensure non-negative
    
    def _compute_sliced_wasserstein(self, particles, reference_samples, n_projections=20):
        """Compute Sliced Wasserstein Distance between particles and reference samples"""
        try:
            # Generate random projection directions
            directions = np.random.randn(n_projections, self.state_dim)
            directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)
            
            # Limit sample size for efficiency
            n_p = min(500, len(particles))
            n_r = min(500, len(reference_samples))
            
            # Sample if needed
            if len(particles) > n_p:
                p_indices = np.random.choice(len(particles), n_p, replace=False)
                particles_sub = particles[p_indices]
            else:
                particles_sub = particles
                
            if len(reference_samples) > n_r:
                r_indices = np.random.choice(len(reference_samples), n_r, replace=False)
                reference_sub = reference_samples[r_indices]
            else:
                reference_sub = reference_samples
            
            # Compute Sliced Wasserstein Distance
            swd = 0.0
            
            for direction in directions:
                # Project samples onto this direction
                particles_proj = particles_sub @ direction
                reference_proj = reference_sub @ direction
                
                # Sort projections
                particles_proj = np.sort(particles_proj)
                reference_proj = np.sort(reference_proj)
                
                # Compute 1-Wasserstein distance for this projection
                if len(particles_proj) != len(reference_proj):
                    # Interpolate to match lengths
                    if len(particles_proj) > len(reference_proj):
                        indices = np.linspace(0, len(reference_proj)-1, len(particles_proj))
                        reference_proj_interp = np.interp(indices, np.arange(len(reference_proj)), reference_proj)
                        w_dist = np.mean(np.abs(particles_proj - reference_proj_interp))
                    else:
                        indices = np.linspace(0, len(particles_proj)-1, len(reference_proj))
                        particles_proj_interp = np.interp(indices, np.arange(len(particles_proj)), particles_proj)
                        w_dist = np.mean(np.abs(particles_proj_interp - reference_proj))
                else:
                    w_dist = np.mean(np.abs(particles_proj - reference_proj))
                
                swd += w_dist
            
            return swd / n_projections
        except Exception as e:
            print(f"Error computing sliced Wasserstein distance: {e}")
            return float('inf')
    
    def _compute_correlation_error(self, particles, true_state):
        """Compute error in capturing correlation structure"""
        # Compute empirical covariance from particles
        if len(particles) < 10:  # Need enough particles
            return 1.0
            
        empirical_cov = np.cov(particles, rowvar=False)
        
        # Get true correlation matrix
        true_corr = self._get_correlation_matrix()
        
        # Normalize empirical covariance to correlation
        d = np.sqrt(np.diag(empirical_cov))
        d_mat = np.outer(d, d)
        # Avoid division by zero
        d_mat[d_mat < 1e-10] = 1.0
        empirical_corr = empirical_cov / d_mat
        
        # Compute Frobenius norm of the difference
        diff_norm = np.linalg.norm(empirical_corr - true_corr, 'fro')
        true_norm = np.linalg.norm(true_corr, 'fro')
        
        # Normalize the error
        if true_norm > 1e-10:
            return diff_norm / true_norm
        else:
            return 1.0
    
    def _compute_mode_coverage(self, particles, true_state):
        """
        Heuristic score for how well the belief's modes cover the truth.

        Particle positions are clustered with DBSCAN to count modes, and the
        true position counts as covered when enough particles lie within 1.5
        units of it. The score then depends on the light level at the true
        position.

        Args:
            particles: Belief particles; the first five columns are positions
            true_state: True state; the first five entries are the position

        Returns:
            0.0 if the true position is not covered. Otherwise, in the dark
            (light level < 0.3): 1.0 with at least two modes, else 0.7; in
            the light: 1.0 with at most two modes, else 0.8
        """
        cloud = particles[:, :5]
        target = true_state[:5]

        # Count clusters; DBSCAN labels noise points as -1, which is no mode
        labels = DBSCAN(eps=1.0, min_samples=3).fit(cloud).labels_
        n_modes = len(set(labels) - {-1})

        # Covered = at least max(3, 3% of the particles) within 1.5 units
        gaps = np.sqrt(np.sum((cloud - target) ** 2, axis=1))
        nearby = np.sum(gaps < 1.5)
        covered = nearby >= max(3, 0.03 * len(particles))

        brightness = self._get_light_level(target)

        if not covered:
            return 0.0

        if brightness < 0.3:
            # Dark: several modes are expected, a single mode scores lower
            return 1.0 if n_modes >= 2 else 0.7

        # Light: more than two modes is penalised
        return 1.0 if n_modes <= 2 else 0.8
    
    def _compute_ess(self, particles, true_state):
        """Compute Effective Sample Size ratio"""
        # Generate expected observation from true state
        expected_obs = self.observation_model(true_state, None)
        
        # Compute observation likelihood for each particle
        log_weights = np.zeros(len(particles))
        
        for i, particle in enumerate(particles):
            likelihood = self.observation_model(particle, expected_obs)
            log_weights[i] = np.log(max(likelihood, 1e-10))
        
        # Normalize log weights
        log_weights -= np.max(log_weights)
        weights = np.exp(log_weights)
        sum_weights = np.sum(weights)
        
        if sum_weights < 1e-10:
            return 0.0
            
        weights /= sum_weights
        
        # Compute ESS
        ess = 1.0 / np.sum(weights**2)
        
        # Normalize by number of particles
        return ess / len(particles)
    
    def _visualize_beliefs(self, particles_dict, episode, step):
        """
        Save a figure of the current belief particles for all methods.

        One row per method, each showing three 2D projections of the 5D
        position space (dims 1-2, 3-4 and 1-5, one-based) with the light
        map, goal, true agent position and velocity, particles, and a
        particle-density heatmap. The figure is written to
        "<save_dir>/plots/belief_ep<episode+1>_step<step+1>.png".

        Args:
            particles_dict: Mapping from method name to its particle array
            episode: Zero-based episode index (used in title and filename)
            step: Zero-based step index (used in title and filename)
        """
        n_methods = len(particles_dict)
        if n_methods == 0:
            # Nothing to draw; an empty subplot grid cannot be laid out
            return

        map_size = self.env_params["map_size"]
        projections = [(0, 1), (2, 3), (0, 4)]

        # The light map depends only on the projection, not on the method,
        # so it is computed once per projection and reused for every row.
        # All coordinates outside the projected pair are held at zero.
        grid = np.linspace(0, map_size, 20)
        light_maps = {}
        for dim1, dim2 in projections:
            light_map = np.zeros((20, 20))
            for x_idx, x in enumerate(grid):
                for y_idx, y in enumerate(grid):
                    pos_5d = np.zeros(5)
                    pos_5d[dim1] = x
                    pos_5d[dim2] = y
                    light_map[x_idx, y_idx] = self._get_light_level(pos_5d)
            light_maps[(dim1, dim2)] = light_map

        fig = plt.figure(figsize=(15, 5 * n_methods))
        try:
            fig.suptitle(f"Episode {episode+1}, Step {step+1} - Belief Visualization", fontsize=16)

            # Column 0 holds the method name, columns 1-3 the projections
            gridspec = fig.add_gridspec(n_methods, 4)

            for i, (method_name, particles) in enumerate(particles_dict.items()):
                # Row header for this method
                row_title = fig.add_subplot(gridspec[i, 0])
                row_title.text(0.5, 0.5, f"{method_name}", fontsize=14,
                               ha='center', va='center')
                row_title.axis('off')

                for j, (dim1, dim2) in enumerate(projections):
                    ax = fig.add_subplot(gridspec[i, j+1])

                    # Map boundary
                    ax.add_patch(Rectangle((0, 0), map_size, map_size,
                                           fill=False, edgecolor='black'))

                    # Light map; transposed so the first index runs along x
                    ax.imshow(light_maps[(dim1, dim2)].T,
                              extent=[0, map_size, 0, map_size],
                              origin='lower', cmap='YlGnBu', alpha=0.5)

                    # Goal position (projection)
                    ax.scatter(self.env.goal[dim1], self.env.goal[dim2], c='green',
                               marker='*', s=100, label='Goal')

                    # True agent position (projection)
                    ax.scatter(self.true_state[dim1], self.true_state[dim2], c='red',
                               marker='o', s=50, label='Agent')

                    # Velocity vector; velocities sit five entries after
                    # the matching position
                    arrow_scale = 1.0
                    ax.arrow(self.true_state[dim1], self.true_state[dim2],
                             self.true_state[dim1+5] * arrow_scale,
                             self.true_state[dim2+5] * arrow_scale,
                             head_width=0.2, head_length=0.3, fc='red', ec='red')

                    # Particles (projection)
                    ax.scatter(particles[:, dim1], particles[:, dim2], c='orange',
                               marker='.', s=10, alpha=0.5, label='Belief')

                    # Density heatmap is best effort: report a failure and
                    # carry on, without swallowing KeyboardInterrupt or
                    # SystemExit as a bare except would
                    try:
                        heatmap, xedges, yedges = np.histogram2d(
                            particles[:, dim1], particles[:, dim2], bins=20,
                            range=[[0, map_size], [0, map_size]])

                        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                        ax.imshow(heatmap.T, extent=extent, origin='lower',
                                  cmap='YlOrRd', alpha=0.3, interpolation='bilinear')
                    except Exception as exc:
                        print(f"Skipping density heatmap for {method_name}: {exc}")

                    # Labels use one-based dimension numbers
                    ax.set_xlim(0, map_size)
                    ax.set_ylim(0, map_size)
                    ax.set_xlabel(f'Dim {dim1+1}')
                    ax.set_ylabel(f'Dim {dim2+1}')
                    ax.set_title(f'Dims {dim1+1}-{dim2+1}')

                    ax.grid(True)

                    # Only add legend for the first projection of each row
                    if j == 0:
                        ax.legend(loc='upper right')

            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, "plots",
                                     f"belief_ep{episode+1}_step{step+1}.png"))
        finally:
            # Always release the figure, even if drawing or saving failed
            plt.close(fig)
    
    def _visualize_results_summary(self, results_df):
        """
        Visualize summary of experiment results.

        Averages every metric per method and writes two figures into
        ``self.save_dir``:

        * ``metrics_comparison.png`` - one bar chart per metric, with the
          methods ordered from best to worst for that metric.
        * ``radar_comparison.png`` - one radar (polar) chart per method
          showing all metrics rescaled so that higher is always better.

        Args:
            results_df: DataFrame with evaluation results. Must contain a
                'Method' column plus one numeric column for each metric
                listed in ``metrics`` below.

        Returns:
            None. If ``results_df`` yields no per-method rows, nothing is
            plotted and no file is written.
        """
        # Aggregate results across episodes. numeric_only=True keeps the
        # aggregation working when results_df carries non-numeric bookkeeping
        # columns; pandas >= 2.0 raises TypeError on those otherwise.
        summary = results_df.groupby('Method').mean(numeric_only=True).reset_index()

        # Nothing to plot; also avoids a division by zero in the radar grid
        # layout below (n_cols would be 0).
        if summary.empty:
            return

        metrics = [
            'Final Position Error', 'MMD', 'Sliced Wasserstein',
            'Correlation Error', 'Mode Coverage', 'ESS', 'Success'
        ]
        # Metrics where larger values indicate a better method; every other
        # metric is an error/distance where smaller is better.
        higher_is_better = {'Mode Coverage', 'ESS', 'Success'}

        # ---- Figure 1: per-metric bar charts ----
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Performance Comparison of Belief Approximation Methods", fontsize=16)

        for i, metric in enumerate(metrics):
            ax = fig.add_subplot(2, 4, i+1)
            better_when_high = metric in higher_is_better

            # Best method first: descending for higher-is-better metrics,
            # ascending for lower-is-better ones.
            sorted_df = summary.sort_values(metric, ascending=not better_when_high)

            bars = ax.bar(sorted_df['Method'], sorted_df[metric])

            # Annotate each bar with its value
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.3f}', ha='center', va='bottom', rotation=0)

            ax.set_xlabel('Method')
            ax.set_ylabel(metric)

            if better_when_high:
                ax.set_title(f"{metric} (higher is better)")
                if metric == 'Success':
                    ax.set_ylim(0, 1.1)
            else:
                ax.set_title(f"{metric} (lower is better)")

            ax.grid(axis='y', alpha=0.3)
            # Rotate the tick labels of this subplot explicitly rather than
            # relying on pyplot's implicit "current axes".
            ax.tick_params(axis='x', labelrotation=45)

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "metrics_comparison.png"))
        plt.close(fig)

        # ---- Figure 2: radar charts on a common "higher is better" scale ----
        normalized_df = summary.copy()

        for metric in metrics:
            column = summary[metric]
            lo, hi = column.min(), column.max()

            if metric in higher_is_better and lo >= 0 and hi <= 1:
                # Already on the [0, 1] radar scale: keep absolute values.
                normalized = column
            elif hi != lo:
                # Min-max scale to [0, 1]. This also covers higher-is-better
                # metrics whose raw values leave [0, 1], which would otherwise
                # be drawn outside the radar's fixed radial range.
                normalized = (column - lo) / (hi - lo)
                if metric not in higher_is_better:
                    # Lower is better: invert so the best method scores 1.
                    normalized = 1 - normalized
            else:
                # All methods tie on this metric.
                normalized = 1.0

            normalized_df[f"{metric}_norm"] = normalized

        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Normalized Performance Metrics (higher is better)", fontsize=16)

        # One evenly spaced axis per metric
        N = len(metrics)
        label_angles = np.linspace(0, 2*np.pi, N, endpoint=False)
        angles = label_angles.tolist()
        angles += angles[:1]  # Close the loop

        # Lay the per-method radar charts out on a grid of at most 3 columns
        n_methods = len(normalized_df)
        n_cols = min(3, n_methods)
        n_rows = (n_methods + n_cols - 1) // n_cols

        for i, (_, row) in enumerate(normalized_df.iterrows()):
            ax = fig.add_subplot(n_rows, n_cols, i+1, projection='polar')

            values = [row[f"{metric}_norm"] for metric in metrics]
            values += values[:1]  # Close the loop

            ax.plot(angles, values, linewidth=2, linestyle='solid')
            ax.fill(angles, values, alpha=0.25)

            ax.set_xticks(angles[:-1])
            ax.set_xticklabels(metrics)

            # Centre each metric label and rotate it to its axis angle
            for lab, angle in zip(ax.get_xticklabels(), label_angles):
                lab.set_horizontalalignment('center')
                lab.set_verticalalignment('center')
                lab.set_rotation(np.degrees(angle))

            ax.set_ylim(0, 1)
            ax.set_title(row['Method'])

            ax.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "radar_comparison.png"))
        plt.close(fig)