import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Ellipse
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import os
import time
import traceback

# Import the robot environment
from kidnapped.kidnapped_robot_env import KidnappedRobotEnv

class KidnappedRobotRunner:
    """
    Runner for the Kidnapped Robot POMDP problem that evaluates different
    belief approximation methods on various metrics.
    
    This runner handles:
    1. Environment creation and episode progression
    2. Belief update for each method
    3. Evaluation using multiple metrics
    4. Visualization of beliefs and performances
    """
    
    def __init__(self, env_params=None, save_dir="results"):
        """
        Build the runner, its environment, and the output directories.

        Args:
            env_params: Optional dict of KidnappedRobotEnv keyword arguments.
                Its entries override the defaults (map_size=20,
                n_landmarks=15, sensor_range=5, noise_level=0.1).
            save_dir: Directory for result files; a ``plots`` subdirectory
                is created inside it.
        """
        merged_params = dict(
            map_size=20,
            n_landmarks=15,
            sensor_range=5,
            noise_level=0.1,
        )
        if env_params is not None:
            merged_params.update(env_params)
        self.env_params = merged_params

        self.env = KidnappedRobotEnv(**self.env_params)

        self.save_dir = save_dir
        plots_dir = os.path.join(save_dir, "plots")
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(plots_dir, exist_ok=True)

        # The state layout used throughout this class has 20 components.
        self.state_dim = 20
        # Latest ground-truth state; refreshed during run_experiment.
        self.true_state = None
    
    def transition_model(self, state, action):
        """
        Transition model for particle belief updates.

        Propagates one particle through one action. The state layout read
        here is: 0-1 position (x, y), 2 orientation (radians), 3 velocity,
        4 steering, 5-9 sensor calibration, 10-19 feature descriptor.

        Nominal per-action changes:
            0 (forward): position advances by velocity * 0.1 along the
                heading; velocity gets N(0, 0.05) jitter.
            1 (turn left): heading +0.1; steering gets N(0.05, 0.02).
            2 (turn right): heading -0.1; steering gets N(-0.05, 0.02).
            any other action: no kinematic change.
        The sensor calibration then always takes an N(0, 0.01) random-walk step.

        Correlated noise from ``_get_correlation_matrix`` is applied to the
        nominal change afterwards. Each component becomes
        ``state + sign(delta) * (|delta| + 0.1 * noise)``, so any component
        whose nominal change is exactly zero stays unchanged.

        Finally, the position is clipped to [0, map_size], the heading is
        wrapped to [0, 2*pi), and the feature block is rescaled to unit norm.

        Args:
            state: Current state particle (1-D array of length state_dim)
            action: Action taken (integer action id)

        Returns:
            Updated state particle (new array; ``state`` is not modified)
        """
        # Extract state components
        pos_x, pos_y = state[0], state[1]
        orientation = state[2]
        velocity, steering = state[3], state[4]

        # State update based on action
        # This should match the environment's dynamics
        # (NOTE(review): the environment's dynamics are not visible here; confirm).
        new_state = state.copy()

        # Update based on action
        if action == 0:  # Move forward
            # Update position based on orientation and velocity
            # (0.1 looks like a fixed time step — TODO confirm against the env)
            new_state[0] = pos_x + velocity * np.cos(orientation) * 0.1
            new_state[1] = pos_y + velocity * np.sin(orientation) * 0.1
            # Small changes to velocity
            new_state[3] = velocity + np.random.normal(0, 0.05)

        elif action == 1:  # Turn left
            # Update orientation (counter-clockwise)
            new_state[2] = orientation + 0.1
            # Update steering (biased drift to the left)
            new_state[4] = steering + np.random.normal(0.05, 0.02)

        elif action == 2:  # Turn right
            # Update orientation (clockwise)
            new_state[2] = orientation - 0.1
            # Update steering (biased drift to the right)
            new_state[4] = steering + np.random.normal(-0.05, 0.02)

        # For action 3 (stay), minimal changes

        # Small random changes to sensor calibration (applied for every action)
        new_state[5:10] = state[5:10] + np.random.normal(0, 0.01, 5)

        # Apply correlation structure similar to environment
        # This is simplified for example purposes.
        # Per-dimension noise scale is the nominal change plus a 0.01 floor.
        scales = np.abs(new_state - state) + 0.01
        corr_matrix = self._get_correlation_matrix()
        # NOTE(review): multivariate_normal expects a positive semi-definite
        # covariance; that holds only if _get_correlation_matrix returns a PSD matrix.
        cov_matrix = np.outer(scales, scales) * corr_matrix

        # Generate correlated noise
        correlated_noise = np.random.multivariate_normal(
            mean=np.zeros(self.state_dim), cov=cov_matrix)

        # Scale the noise and add to state.
        # sign() is 0 where the nominal change is 0, so those dimensions get no noise.
        direction = np.sign(new_state - state)
        magnitude = np.abs(new_state - state)
        correlated_delta = direction * (magnitude + 0.1 * correlated_noise)

        # Apply correlated update (replaces the nominal new_state computed above)
        new_state = state + correlated_delta

        # Ensure the robot stays within the map boundaries
        new_state[0] = np.clip(new_state[0], 0, self.env_params["map_size"])
        new_state[1] = np.clip(new_state[1], 0, self.env_params["map_size"])

        # Normalize orientation to [0, 2π)
        new_state[2] = new_state[2] % (2*np.pi)

        # Normalize feature descriptors to unit length (skip an all-zero vector)
        feature_norm = np.linalg.norm(new_state[10:20])
        if feature_norm > 0:
            new_state[10:20] = new_state[10:20] / feature_norm

        return new_state
    
    def observation_model(self, state, observation):
        """
        Observation likelihood model for particle belief updates.
        
        Args:
            state: Current state particle
            observation: Observation received
            
        Returns:
            Likelihood of the observation given the state
        """
        # Generate expected observation from the state
        expected_obs = self._get_expected_observation(state)
        
        # Noise level adaptation based on sensor calibration
        noise_level = self.env_params["noise_level"] * (1 + 0.2 * np.sum(np.abs(state[5:10])))
        
        # Compute likelihood using Gaussian noise model
        # Focus on non-zero observations (visible landmarks)
        obs_likelihood = 1.0
        
        # For each landmark observation (distance and feature similarity pairs)
        for i in range(len(observation) // 2):
            # Distance observation
            dist_obs = observation[i*2]
            dist_exp = expected_obs[i*2]
            
            # Feature similarity observation
            feat_obs = observation[i*2 + 1]
            feat_exp = expected_obs[i*2 + 1]
            
            # If landmark is visible in both observations
            if dist_obs > 0 and dist_exp > 0:
                # Compute distance likelihood
                dist_likelihood = np.exp(-0.5 * ((dist_obs - dist_exp) / noise_level)**2)
                
                # Compute feature likelihood
                feat_likelihood = np.exp(-0.5 * ((feat_obs - feat_exp) / noise_level)**2)
                
                # Combine likelihoods
                landmark_likelihood = dist_likelihood * feat_likelihood
                obs_likelihood *= landmark_likelihood
            elif dist_obs > 0 or dist_exp > 0:
                # Penalize if landmark visibility doesn't match
                obs_likelihood *= 0.1
        
        return obs_likelihood
    
    def _get_expected_observation(self, state):
        """
        Generate expected observation from a state.
        This mimics the environment's observation generation.
        
        Args:
            state: State to generate observation from
            
        Returns:
            Expected observation
        """
        # Extract state components
        pos_x, pos_y = state[0], state[1]
        orientation = state[2]
        sensor_calibration = state[5:10]
        
        # Initialize observation vector
        n_landmarks = self.env_params["n_landmarks"]
        observation = np.zeros(n_landmarks * 2)
        
        # Calculate field of view edges
        fov_angle = np.pi / 2  # 90-degree field of view
        fov_start = orientation - fov_angle / 2
        fov_end = orientation + fov_angle / 2
        
        # For each landmark, check if it's within sensor range and field of view
        landmark_idx = 0
        for (lm_x, lm_y), feature in self.env.map.items():
            # Calculate distance to landmark
            dx = lm_x - pos_x
            dy = lm_y - pos_y
            distance = np.sqrt(dx**2 + dy**2)
            
            # Calculate angle to landmark
            angle = np.arctan2(dy, dx)
            
            # Normalize angle to [0, 2π)
            angle = angle % (2*np.pi)
            
            # Check if angle is within field of view
            in_fov = False
            if fov_start <= fov_end:
                in_fov = fov_start <= angle <= fov_end
            else:  # Field of view crosses the 0 angle
                in_fov = (angle >= fov_start) or (angle <= fov_end)
            
            # If landmark is within range and field of view
            if distance <= self.env_params["sensor_range"] and in_fov:
                # Add noise to distance based on sensor calibration
                noise_scale = self.env_params["noise_level"] * (1 + 0.2 * np.sum(sensor_calibration))
                noisy_distance = distance + np.random.normal(0, noise_scale)
                
                # Normalize distance observation
                normalized_distance = noisy_distance / self.env_params["sensor_range"]
                
                # Calculate feature similarity with the robot's internal representation
                feature_similarity = np.dot(feature, state[10:20])
                
                # Add noise to feature similarity
                noisy_similarity = feature_similarity + np.random.normal(0, noise_scale)
                
                # Store the observation if within limits
                if landmark_idx < n_landmarks:
                    observation[landmark_idx*2] = normalized_distance
                    observation[landmark_idx*2 + 1] = noisy_similarity
                    landmark_idx += 1
        
        return observation
    
    def _get_correlation_matrix(self):
        """
        Get correlation matrix for the state variables.
        
        Returns:
            20x20 correlation matrix
        """
        # This should match the environment's correlation matrix
        corr_matrix = np.eye(self.state_dim)
        
        # Position (x, y) and orientation are correlated
        corr_matrix[0, 2] = corr_matrix[2, 0] = 0.6  # x and orientation
        corr_matrix[1, 2] = corr_matrix[2, 1] = 0.6  # y and orientation
        
        # Position and velocity are correlated
        corr_matrix[0, 3] = corr_matrix[3, 0] = 0.8  # x and velocity_x
        corr_matrix[1, 4] = corr_matrix[4, 1] = 0.8  # y and velocity_y
        
        # Orientation and steering are correlated
        corr_matrix[2, 4] = corr_matrix[4, 2] = 0.7  # orientation and steering
        
        # Sensor calibration parameters have internal correlations
        for i in range(5, 10):
            for j in range(5, 10):
                if i != j:
                    corr_matrix[i, j] = corr_matrix[j, i] = 0.4
        
        # Sensor calibration affects feature descriptors
        for i in range(5, 10):
            for j in range(10, 20):
                corr_matrix[i, j] = corr_matrix[j, i] = 0.3
        
        # Feature descriptors have internal correlations based on similarities
        for i in range(10, 20):
            for j in range(i+1, 20):
                corr_matrix[i, j] = corr_matrix[j, i] = 0.2
        
        return corr_matrix
    
    def run_experiment(self, methods, n_episodes=10, max_steps=100, n_particles=100):
        """
        Run experiment to compare different belief approximation methods.

        Each episode proceeds as follows:
        1. Reset the environment.
        2. Give every method a fresh uninformed particle set.
        3. Drive the robot with uniformly random actions for at most
           ``max_steps`` steps, stopping early when the environment reports
           ``done``. After every step, each method updates its belief.
        4. Score the final beliefs.

        Args:
            methods: Dictionary mapping method names to implementation
                objects. How each object is driven depends on its interface:
                - has ``update``: call
                  ``update(action, obs, transition_model, observation_model)``,
                  then ``get_belief_estimate()``;
                - otherwise has ``fit_transform``: call
                  ``fit_transform(particles, likelihood, None)``;
                - otherwise: call ``method(particles, likelihood)``.
            n_episodes: Number of episodes to run
            max_steps: Maximum steps per episode
            n_particles: Number of particles for belief representation

        Returns:
            DataFrame with one row per (method, episode). "Steps" is the
            number of environment steps actually taken (0 if
            ``max_steps == 0``). "Runtime" is the total wall-clock time in
            seconds spent in that method's belief updates during the episode.

        Side effects:
            Writes ``belief_results.csv``, ``ground_truth_states.npy`` and
            one ``<method>_particles.npy`` per method into ``self.save_dir``,
            and calls the plotting helpers. Entry t of
            ``ground_truth_states.npy`` is the true state after step t. That
            is the state the particles stored at entry t of each particle
            file were updated towards.
        """
        results = {
            "Method": [],
            "Episode": [],
            "Steps": [],
            "Final Position Error": [],
            "Mean Position Error": [],
            "Max Position Error": [],
            "MMD": [],
            "Sliced Wasserstein": [],
            "Correlation Error": [],
            "Mode Coverage": [],
            "ESS": [],
            "Runtime": [],
        }

        # Per-step belief snapshots (all episodes concatenated) for later analysis.
        all_particles = {method_name: [] for method_name in methods}
        ground_truth_states = []

        for episode in range(n_episodes):
            print(f"Episode {episode+1}/{n_episodes}")

            self.env.reset()
            self.true_state = self.env.state.copy()

            # Each method starts from its own independent uninformed particle set.
            particles = {method_name: self._initialize_particles(n_particles)
                         for method_name in methods}
            runtimes = {method_name: 0.0 for method_name in methods}
            steps_taken = 0

            for step in range(max_steps):
                # Random action selection for simplicity; a real agent would
                # pick actions based on its belief.
                action = np.random.randint(0, 4)

                next_obs, _, done, _ = self.env.step(action)
                self.true_state = self.env.state.copy()
                steps_taken = step + 1
                # Recorded after the step so it lines up with the posterior
                # particles stored below for this same step.
                ground_truth_states.append(self.true_state.copy())

                for method_name, method in methods.items():
                    start_time = time.perf_counter()
                    try:
                        particles[method_name] = self._update_method_belief(
                            method, particles[method_name], action, next_obs)
                    except Exception as e:
                        # Best effort: report the failure and keep the previous
                        # particles so the remaining methods and steps still run.
                        print(f"Error updating {method_name}: {e}")
                        traceback.print_exc()
                    runtimes[method_name] += time.perf_counter() - start_time

                    all_particles[method_name].append(particles[method_name].copy())

                if step % 10 == 0 or step == max_steps - 1:
                    self._visualize_beliefs(particles, episode, step)

                if done:
                    break

            # Score the final belief of each method for this episode.
            for method_name, method_particles in particles.items():
                position_error = self._compute_position_error(method_particles, self.true_state)
                belief_metrics = self._evaluate_belief_quality(method_particles, self.true_state)

                results["Method"].append(method_name)
                results["Episode"].append(episode)
                results["Steps"].append(steps_taken)
                results["Final Position Error"].append(position_error["final"])
                results["Mean Position Error"].append(position_error["mean"])
                results["Max Position Error"].append(position_error["max"])
                results["MMD"].append(belief_metrics["mmd"])
                results["Sliced Wasserstein"].append(belief_metrics["sliced_wasserstein"])
                results["Correlation Error"].append(belief_metrics["correlation_error"])
                results["Mode Coverage"].append(belief_metrics["mode_coverage"])
                results["ESS"].append(belief_metrics["ess"])
                results["Runtime"].append(runtimes[method_name])

                print(f"{method_name}: Final Pos Error = {position_error['final']:.2f}, "
                      f"MMD = {belief_metrics['mmd']:.4f}, "
                      f"Mode Coverage = {belief_metrics['mode_coverage']:.2f}")

        results_df = pd.DataFrame(results)
        results_df.to_csv(os.path.join(self.save_dir, "belief_results.csv"), index=False)

        self._visualize_results_summary(results_df)

        np.save(os.path.join(self.save_dir, "ground_truth_states.npy"), np.array(ground_truth_states))
        for method_name, method_particles in all_particles.items():
            np.save(os.path.join(self.save_dir, f"{method_name}_particles.npy"),
                    np.array(method_particles))

        return results_df

    def _initialize_particles(self, n_particles):
        """
        Draw an uninformed (kidnapped) particle set of shape (n_particles, state_dim).

        Components are drawn independently:
        - position: uniform over the map;
        - orientation: uniform in [0, 2*pi);
        - velocity/steering: N(0, 0.5);
        - calibration: N(0, 0.1);
        - features: random unit vectors.
        """
        init_particles = np.zeros((n_particles, self.state_dim))
        init_particles[:, 0:2] = np.random.uniform(
            0, self.env_params["map_size"], (n_particles, 2))
        init_particles[:, 2] = np.random.uniform(0, 2*np.pi, n_particles)
        init_particles[:, 3:5] = np.random.normal(0, 0.5, (n_particles, 2))
        init_particles[:, 5:10] = np.random.normal(0, 0.1, (n_particles, 5))

        features = np.random.normal(0, 1, (n_particles, 10))
        norms = np.linalg.norm(features, axis=1, keepdims=True)
        # Normalize each row; an all-zero row is left unchanged.
        init_particles[:, 10:20] = np.divide(
            features, norms, out=features.copy(), where=norms > 0)
        return init_particles

    def _update_method_belief(self, method, current_particles, action, observation):
        """Advance one method's belief by one step and return its new particles."""
        def likelihood(x):
            return self.observation_model(x, observation)

        if hasattr(method, 'update'):
            # Stateful method: it keeps its own belief internally.
            method.update(action, observation, self.transition_model, self.observation_model)
            return method.get_belief_estimate()
        if hasattr(method, 'fit_transform'):
            return method.fit_transform(current_particles, likelihood, None)
        # Fallback: assume the method object is callable.
        return method(current_particles, likelihood)
    
    def _compute_position_error(self, particles, true_state):
        """
        Compute position estimation error from belief particles.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            Dictionary with error metrics
        """
        # Extract position components
        true_pos = true_state[0:2]
        particle_positions = particles[:, 0:2]
        
        # Compute weighted mean position from particles
        mean_pos = np.mean(particle_positions, axis=0)
        
        # Compute error
        final_error = np.linalg.norm(true_pos - mean_pos)
        
        # Compute errors for all particles
        all_errors = np.linalg.norm(particle_positions - true_pos, axis=1)
        mean_error = np.mean(all_errors)
        max_error = np.max(all_errors)
        
        return {
            "final": final_error,
            "mean": mean_error,
            "max": max_error
        }
    
    def _evaluate_belief_quality(self, particles, true_state, proxy_scale=0.1):
        """
        Evaluate quality of belief approximation using multiple metrics.

        No samples from the true posterior are available. The distributional
        metrics (MMD, sliced Wasserstein) therefore compare the particles
        against a proxy sample of the same size. The proxy is drawn from a
        Gaussian centred on ``true_state`` with covariance
        ``proxy_scale * self._get_correlation_matrix()``, then projected onto
        the state constraints:
        - position clipped to [0, map_size];
        - heading wrapped to [0, 2*pi);
        - features (indices 10-19) rescaled to unit norm.

        Args:
            particles: Belief particles, indexed as particles[i, dim]
            true_state: True robot state
            proxy_scale: Multiplier on the correlation matrix used as the
                proxy covariance (default 0.1, the previously fixed value)

        Returns:
            Dictionary with keys "mmd", "sliced_wasserstein",
            "correlation_error", "mode_coverage", "ess" and "runtime".
            "runtime" is always 0.0: this method does not time anything and
            keeps the key only for callers that expect it.
        """
        n_reference = len(particles)
        proxy_samples = np.random.multivariate_normal(
            mean=true_state,
            cov=proxy_scale * self._get_correlation_matrix(),
            size=n_reference,
        )

        # Position bounds (x and y share the same limits)
        map_size = self.env_params["map_size"]
        proxy_samples[:, 0:2] = np.clip(proxy_samples[:, 0:2], 0, map_size)

        # Orientation normalization
        proxy_samples[:, 2] %= 2 * np.pi

        # Feature normalization (rows with zero norm are left unchanged)
        feats = proxy_samples[:, 10:20]
        norms = np.linalg.norm(feats, axis=1, keepdims=True)
        proxy_samples[:, 10:20] = np.divide(feats, norms, out=feats.copy(), where=norms > 0)

        # Metrics are evaluated in this order; several of them consume
        # np.random, so the order also fixes the random stream.
        return {
            "mmd": self._compute_mmd(particles, proxy_samples),
            "sliced_wasserstein": self._compute_sliced_wasserstein(particles, proxy_samples),
            "correlation_error": self._compute_correlation_error(particles, true_state),
            "mode_coverage": self._compute_mode_coverage(particles, true_state),
            "ess": self._compute_ess(particles, true_state),
            "runtime": 0.0,
        }
    
    def _compute_mmd(self, particles, reference_samples, bandwidth=None):
        """
        Compute Maximum Mean Discrepancy between particles and reference samples.
        
        Args:
            particles: Belief particles
            reference_samples: Reference distribution samples
            bandwidth: Kernel bandwidth (if None, median heuristic is used)
            
        Returns:
            MMD value
        """
        # Use median heuristic if bandwidth not provided
        if bandwidth is None:
            # Compute pairwise distances for a subset of particles
            n_subset = min(100, len(particles))
            subset_p = particles[:n_subset]
            
            dists = []
            for i in range(min(20, len(subset_p))):
                xi = subset_p[i]
                diff = subset_p - xi
                dists.extend(np.sum(diff**2, axis=1).tolist())
                
            bandwidth = np.median(dists) if dists else 1.0
        
        # RBF kernel function
        def kernel(x, y):
            # Add this safety check to ensure bandwidth is never zero
            safe_bandwidth = max(bandwidth, 1e-10)  # Use a small positive value as the minimum

            # Compute squared distance
            squared_dist = np.sum((x - y)**2)
            
            # Apply exponential with safety measures
            return np.exp(np.clip(-squared_dist / safe_bandwidth, -700, 700))
        
        # Compute MMD terms
        n_p = min(500, len(particles))  # Limit computation for efficiency
        n_r = min(500, len(reference_samples))
        
        # Sample if needed
        if len(particles) > n_p:
            p_indices = np.random.choice(len(particles), n_p, replace=False)
            particles_sub = particles[p_indices]
        else:
            particles_sub = particles
            
        if len(reference_samples) > n_r:
            r_indices = np.random.choice(len(reference_samples), n_r, replace=False)
            reference_sub = reference_samples[r_indices]
        else:
            reference_sub = reference_samples
        
        # Compute terms
        pp_sum = 0
        for i in range(n_p):
            for j in range(i+1, n_p):
                pp_sum += kernel(particles_sub[i], particles_sub[j])
        pp_sum = 2 * pp_sum / (n_p * (n_p - 1)) if n_p > 1 else 0
        
        rr_sum = 0
        for i in range(n_r):
            for j in range(i+1, n_r):
                rr_sum += kernel(reference_sub[i], reference_sub[j])
        rr_sum = 2 * rr_sum / (n_r * (n_r - 1)) if n_r > 1 else 0
        
        pr_sum = 0
        for i in range(n_p):
            for j in range(n_r):
                pr_sum += kernel(particles_sub[i], reference_sub[j])
        pr_sum = pr_sum / (n_p * n_r)
        
        mmd = pp_sum + rr_sum - 2 * pr_sum
        return max(0, mmd)  # Ensure non-negative
    
    def _compute_sliced_wasserstein(self, particles, reference_samples, n_projections=20):
        """
        Compute Sliced Wasserstein Distance between particles and reference samples.
        
        Args:
            particles: Belief particles
            reference_samples: Reference distribution samples
            n_projections: Number of random projections
            
        Returns:
            Sliced Wasserstein Distance
        """
        try:
            # Generate random projection directions
            directions = np.random.randn(n_projections, self.state_dim)
            directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)
            
            # Limit sample size for efficiency
            n_p = min(500, len(particles))
            n_r = min(500, len(reference_samples))
            
            # Sample if needed
            if len(particles) > n_p:
                p_indices = np.random.choice(len(particles), n_p, replace=False)
                particles_sub = particles[p_indices]
            else:
                particles_sub = particles
                
            if len(reference_samples) > n_r:
                r_indices = np.random.choice(len(reference_samples), n_r, replace=False)
                reference_sub = reference_samples[r_indices]
            else:
                reference_sub = reference_samples
            
            # Compute Sliced Wasserstein Distance
            swd = 0.0
            
            for direction in directions:
                # Project samples onto this direction
                particles_proj = particles_sub @ direction
                reference_proj = reference_sub @ direction
                
                # Sort projections
                particles_proj = np.sort(particles_proj)
                reference_proj = np.sort(reference_proj)
                
                # Compute 1-Wasserstein distance for this projection
                if len(particles_proj) != len(reference_proj):
                    # Interpolate to match lengths
                    if len(particles_proj) > len(reference_proj):
                        indices = np.linspace(0, len(reference_proj)-1, len(particles_proj))
                        reference_proj_interp = np.interp(indices, np.arange(len(reference_proj)), reference_proj)
                        w_dist = np.mean(np.abs(particles_proj - reference_proj_interp))
                    else:
                        indices = np.linspace(0, len(particles_proj)-1, len(reference_proj))
                        particles_proj_interp = np.interp(indices, np.arange(len(particles_proj)), particles_proj)
                        w_dist = np.mean(np.abs(particles_proj_interp - reference_proj))
                else:
                    w_dist = np.mean(np.abs(particles_proj - reference_proj))
                
                swd += w_dist
            
            return swd / n_projections
        except Exception as e:
            print(f"Error computing sliced Wasserstein distance: {e}")
            return float('inf')
    
    def _compute_correlation_error(self, particles, true_state):
        """
        Compute error in capturing correlation structure.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            Correlation error metric
        """
        # Compute empirical covariance from particles
        if len(particles) < 10:  # Need enough particles
            return 1.0
            
        empirical_cov = np.cov(particles, rowvar=False)
        
        # Get true correlation matrix (approximation for this task)
        true_corr = self._get_correlation_matrix()
        
        # Normalize empirical covariance to correlation
        d = np.sqrt(np.diag(empirical_cov))
        d_mat = np.outer(d, d)
        # Avoid division by zero
        d_mat[d_mat < 1e-10] = 1.0
        empirical_corr = empirical_cov / d_mat
        
        # Compute Frobenius norm of the difference
        diff_norm = np.linalg.norm(empirical_corr - true_corr, 'fro')
        true_norm = np.linalg.norm(true_corr, 'fro')
        
        # Normalize the error
        if true_norm > 1e-10:
            return diff_norm / true_norm
        else:
            return 1.0
    
    def _compute_mode_coverage(self, particles, true_state):
        """
        Compute mode coverage for the robotics task.
        
        In the Kidnapped Robot problem, there can be multiple plausible
        locations (modes) due to perceptual aliasing.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            Mode coverage ratio [0, 1]
        """
        # For this task, we consider position modes:
        # Focus on x, y only for clustering
        pos_particles = particles[:, 0:2]
        
        # Cluster to find modes
        clustering = DBSCAN(eps=2.0, min_samples=3).fit(pos_particles)
        labels = clustering.labels_
        
        # Number of detected modes
        n_modes = len(set(labels)) - (1 if -1 in labels else 0)
        
        # Check if true position is covered
        true_pos = true_state[0:2]
        
        # Find distances from true position to each particle
        dists = np.sqrt(np.sum((pos_particles - true_pos)**2, axis=1))
        
        # Check if at least some particles are close to true position
        close_particles = np.sum(dists < 3.0)  # Within 3 units
        
        # At least 3% of particles should be close to true position
        true_pos_covered = close_particles >= max(3, 0.03 * len(particles))
        
        # Compute mode coverage score:
        # - If true position is covered and there are multiple modes: excellent
        # - If true position is covered but only 1 mode: good
        # - If true position is not covered: poor
        
        # Need at least 2 modes to handle perceptual aliasing correctly
        if n_modes >= 2 and true_pos_covered:
            return 1.0
        elif n_modes == 1 and true_pos_covered:
            return 0.8
        elif true_pos_covered:
            return 0.5
        else:
            return 0.0
    
    def _compute_ess(self, particles, true_state):
        """
        Compute Effective Sample Size ratio.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            ESS ratio [0, 1]
        """
        # Compute observation likelihood for each particle
        obs = self._get_expected_observation(true_state)
        log_weights = np.zeros(len(particles))
        
        for i, particle in enumerate(particles):
            likelihood = self.observation_model(particle, obs)
            log_weights[i] = np.log(max(likelihood, 1e-10))
        
        # Normalize log weights
        log_weights -= np.max(log_weights)
        weights = np.exp(log_weights)
        sum_weights = np.sum(weights)
        
        if sum_weights < 1e-10:
            return 0.0
            
        weights /= sum_weights
        
        # Compute ESS
        ess = 1.0 / np.sum(weights**2)
        
        # Normalize by number of particles
        return ess / len(particles)
    
    def _visualize_beliefs(self, particles_dict, episode, step):
        """
        Visualize current belief particles for all methods and save a PNG.

        Each method gets one subplot. The subplot shows:
          - the landmarks;
          - the true pose, with its sensor-range circle and a drawn
            90-degree field-of-view wedge;
          - the particle cloud, heading arrows for up to 20 particles, and a
            2D density overlay;
          - a text box with belief metrics.
        The figure is written to
        ``<save_dir>/plots/belief_ep{episode+1}_step{step+1}.png``.

        Args:
            particles_dict: Mapping of method name -> particle array; columns
                0, 1 and 2 are read as x, y and heading theta.
            episode: Current episode number (0-based; displayed as episode+1).
            step: Current step number (0-based; displayed as step+1).
        """
        import warnings  # local: only used to report a skipped density overlay

        n_methods = len(particles_dict)
        map_size = self.env_params["map_size"]
        sensor_range = self.env_params["sensor_range"]

        # Ground-truth pose, shared by every subplot.
        true_x, true_y = self.true_state[0], self.true_state[1]
        true_theta = self.true_state[2]

        # The drawn field of view is a fixed 90-degree wedge centred on the
        # heading. NOTE(review): confirm it matches the environment's sensor.
        fov_angle = np.pi / 2
        fov_edges = (true_theta - fov_angle / 2, true_theta + fov_angle / 2)

        # Landmark positions are the keys of env.map, unpacked as (x, y).
        landmark_xy = np.asarray(
            [(lm_x, lm_y) for (lm_x, lm_y), _ in self.env.map.items()],
            dtype=float,
        ).reshape(-1, 2)

        fig = plt.figure(figsize=(15, 5 * n_methods))
        try:
            fig.suptitle(f"Episode {episode+1}, Step {step+1} - Belief Visualization", fontsize=16)

            for i, (method_name, particles) in enumerate(particles_dict.items()):
                ax = fig.add_subplot(n_methods, 1, i + 1)

                # Landmarks in a single scatter call (one legend entry).
                if len(landmark_xy):
                    ax.scatter(landmark_xy[:, 0], landmark_xy[:, 1],
                               c='blue', marker='^', s=100, label='Landmark')

                # True position and heading arrow.
                ax.scatter(true_x, true_y, c='red', marker='o', s=200, label='True Position')
                arrow_length = 0.5
                ax.arrow(true_x, true_y,
                         arrow_length * np.cos(true_theta),
                         arrow_length * np.sin(true_theta),
                         head_width=0.3, head_length=0.3, fc='red', ec='red')

                # Sensor range circle and the two field-of-view edges.
                ax.add_patch(plt.Circle((true_x, true_y), sensor_range,
                                        fill=False, linestyle='--', color='green'))
                for edge in fov_edges:
                    ax.plot([true_x, true_x + sensor_range * np.cos(edge)],
                            [true_y, true_y + sensor_range * np.sin(edge)],
                            'g--')

                # Particle cloud.
                particles_x = particles[:, 0]
                particles_y = particles[:, 1]
                ax.scatter(particles_x, particles_y, c='orange', marker='.',
                           s=10, alpha=0.5, label='Particles')

                # Heading arrows for at most 20 evenly spaced particles. This is
                # deterministic on purpose: drawing the subset with
                # np.random.choice advanced the global RNG, so turning plots on
                # changed the experiment's random stream.
                n_arrows = min(20, len(particles))
                arrow_indices = np.linspace(0, len(particles) - 1, n_arrows).astype(int)
                for idx in arrow_indices:
                    p_theta = particles[idx, 2]
                    ax.arrow(particles[idx, 0], particles[idx, 1],
                             0.3 * np.cos(p_theta), 0.3 * np.sin(p_theta),
                             head_width=0.1, head_length=0.1,
                             fc='orange', ec='orange', alpha=0.5)

                # Particle density overlay (best-effort: skipped with a warning).
                try:
                    heatmap, xedges, yedges = np.histogram2d(
                        particles_x, particles_y, bins=20,
                        range=[[0, map_size], [0, map_size]])
                except (ValueError, TypeError) as exc:
                    warnings.warn(f"Skipping density heatmap for {method_name}: {exc}")
                else:
                    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                    ax.imshow(heatmap.T, extent=extent, origin='lower',
                              cmap='YlOrRd', alpha=0.3, interpolation='bilinear')

                # Metrics text box.
                error = self._compute_position_error(particles, self.true_state)
                metrics = self._evaluate_belief_quality(particles, self.true_state)
                metrics_text = (
                    f"Method: {method_name}\n"
                    f"Position Error: {error['final']:.2f}\n"
                    f"MMD: {metrics['mmd']:.4f}\n"
                    f"Corr. Error: {metrics['correlation_error']:.2f}\n"
                    f"Mode Coverage: {metrics['mode_coverage']:.2f}"
                )
                ax.text(0.02, 0.98, metrics_text, transform=ax.transAxes,
                        verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

                # Axes limits/labels; set after imshow, which autoscales.
                ax.set_xlim(0, map_size)
                ax.set_ylim(0, map_size)
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                ax.set_title(f"{method_name} Belief", fontsize=14)

                # Legend without duplicate labels.
                handles, labels = ax.get_legend_handles_labels()
                by_label = dict(zip(labels, handles))
                ax.legend(by_label.values(), by_label.keys(), loc='upper right')

                ax.grid(True)

            fig.tight_layout()
            fig.savefig(os.path.join(self.save_dir, "plots", f"belief_ep{episode+1}_step{step+1}.png"))
        finally:
            # Always release the figure, even if a metric or plot call raises,
            # so long runs do not accumulate open figures.
            plt.close(fig)
    
    def _visualize_results_summary(self, results_df):
        """
        Visualize summary of experiment results.
        
        Args:
            results_df: DataFrame with evaluation results
        """
        # Aggregate results across episodes
        summary = results_df.groupby('Method').mean().reset_index()
        
        # Plot metrics comparison
        metrics = [
            'Final Position Error', 'MMD', 'Sliced Wasserstein', 
            'Correlation Error', 'Mode Coverage', 'ESS'
        ]
        
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Performance Comparison of Belief Approximation Methods", fontsize=16)
        
        for i, metric in enumerate(metrics):
            ax = fig.add_subplot(2, 3, i+1)
            
            # Sort by metric (asc or desc depending on metric)
            if metric in ['Mode Coverage', 'ESS']:
                # Higher is better
                sorted_df = summary.sort_values(metric, ascending=False)
            else:
                # Lower is better
                sorted_df = summary.sort_values(metric)
            
            # Create bar plot
            bars = ax.bar(sorted_df['Method'], sorted_df[metric])
            
            # Add value annotations
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height, 
                       f'{height:.3f}', ha='center', va='bottom', rotation=0)
            
            # Set labels and title
            ax.set_xlabel('Method')
            ax.set_ylabel(metric)
            
            if metric in ['Mode Coverage', 'ESS']:
                ax.set_title(f"{metric} (higher is better)")
                ax.set_ylim(0, 1.1)
            else:
                ax.set_title(f"{metric} (lower is better)")
                
            ax.grid(axis='y', alpha=0.3)
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "metrics_comparison.png"))
        plt.close(fig)
        
        # Create combined metrics plot
        # Normalize metrics for combined view
        normalized_df = summary.copy()
        
        for metric in metrics:
            if metric in ['Mode Coverage', 'ESS']:
                # Higher is better, already normalized
                normalized_df[f"{metric}_norm"] = normalized_df[metric]
            else:
                # Lower is better, invert normalization
                if normalized_df[metric].max() != normalized_df[metric].min():
                    normalized_df[f"{metric}_norm"] = 1 - (normalized_df[metric] - normalized_df[metric].min()) / (normalized_df[metric].max() - normalized_df[metric].min())
                else:
                    normalized_df[f"{metric}_norm"] = 1.0
        
        # Plot radar chart for each method
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Normalized Performance Metrics (higher is better)", fontsize=16)
        
        # Number of metrics
        N = len(metrics)
        
        # Angle of each axis
        angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
        angles += angles[:1]  # Close the loop
        
        # Position for each metric label
        label_angles = np.linspace(0, 2*np.pi, N, endpoint=False)
        label_positions = []
        for angle in label_angles:
            label_positions.append((np.cos(angle), np.sin(angle)))
        
        # Plot each method as a separate radar chart
        n_methods = len(normalized_df)
        n_cols = min(3, n_methods)
        n_rows = (n_methods + n_cols - 1) // n_cols
        
        for i, (_, row) in enumerate(normalized_df.iterrows()):
            ax = fig.add_subplot(n_rows, n_cols, i+1, projection='polar')
            
            # Plot metrics
            values = [row[f"{metric}_norm"] for metric in metrics]
            values += values[:1]  # Close the loop
            
            ax.plot(angles, values, linewidth=2, linestyle='solid')
            ax.fill(angles, values, alpha=0.25)
            
            # Set labels
            ax.set_xticks(angles[:-1])
            ax.set_xticklabels(metrics)
            
            # Adjust label positions for readability
            for label, angle in zip(ax.get_xticklabels(), label_angles):
                x, y = label.get_position()
                lab = ax.get_xticklabels()[list(label_angles).index(angle)]
                lab.set_horizontalalignment('center')
                lab.set_verticalalignment('center')
                lab.set_rotation(np.degrees(angle))
            
            ax.set_ylim(0, 1)
            ax.set_title(row['Method'])
            
            # Add gridlines
            ax.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "radar_comparison.png"))
        plt.close(fig)
        
        # Create a summary table
        fig, ax = plt.subplots(figsize=(12, 7))
        ax.axis('off')
        
        table_data = []
        table_headers = ['Method'] + metrics + ['Runtime']
        table_data.append(table_headers)
        
        for _, row in summary.iterrows():
            table_row = [row['Method']]
            for metric in metrics:
                table_row.append(f"{row[metric]:.3f}")
            table_row.append(f"{row['Runtime']:.3f}")
            table_data.append(table_row)
        
        table = ax.table(cellText=table_data, loc='center', cellLoc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1, 1.5)
        
        plt.title("Summary of Belief Approximation Performance", fontsize=16, pad=20)
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "summary_table.png"))
        plt.close(fig)

