import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import pandas as pd
from tqdm import tqdm
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import os
import time
import traceback

# Import the MultiTargetTracking environment
from multi_target_tracking.multi_target_tracking_env import MultiTargetTracking20DEnv

class MultiTargetTracking20DRunner:
    """
    Runner for the 20D Multi-Target Tracking POMDP problem that evaluates different
    belief approximation methods on various metrics.
    
    This runner handles:
    1. Environment creation and episode progression
    2. Belief update for each method
    3. Evaluation using multiple metrics
    4. Visualization of beliefs and performances
    """
    
    def __init__(self, env_params=None, save_dir="results_mtt"):
        """
        Build the environment and prepare the output directories.

        Args:
            env_params: Optional dictionary of keyword arguments for
                MultiTargetTracking20DEnv. Entries override the defaults
                (map_size=10, n_targets=4, noise_level=0.5).
            save_dir: Directory that receives results; a "plots"
                sub-directory is created inside it as well.
        """
        # Defaults first, then let caller-supplied values win.
        self.env_params = dict(
            map_size=10,
            n_targets=4,  # Fixed for 20D state space
            noise_level=0.5,
        )
        if env_params is not None:
            self.env_params.update(env_params)

        # The environment is constructed from the merged parameter set.
        self.env = MultiTargetTracking20DEnv(**self.env_params)

        # Output locations (created eagerly so later saves cannot fail on a
        # missing directory).
        self.save_dir = save_dir
        for directory in (save_dir, os.path.join(save_dir, "plots")):
            os.makedirs(directory, exist_ok=True)

        # Dimensionality of the state vector; fixed for this environment.
        self.state_dim = 20

        # Ground-truth state of the current episode, filled in while running.
        self.true_state = None
    
    def transition_model(self, state, action):
        """
        Transition model for particle belief updates.
        
        Args:
            state: Current state particle
            action: Action taken
            
        Returns:
            Updated state particle
        """
        # Extract components
        agent_pos = state[0:2]
        agent_vel = state[2:4]
        
        # Calculate agent acceleration from action
        agent_acc = np.zeros(2)
        if action == 0:  # +x
            agent_acc[0] = 0.1
        elif action == 1:  # -x
            agent_acc[0] = -0.1
        elif action == 2:  # +y
            agent_acc[1] = 0.1
        elif action == 3:  # -y
            agent_acc[1] = -0.1
        
        # Update agent velocity with damping
        damping = 0.1
        agent_vel = agent_vel * (1 - damping) + agent_acc
        
        # Update agent position
        dt = 0.1
        agent_pos = agent_pos + agent_vel * dt
        
        # Update state
        new_state = state.copy()
        new_state[0:2] = agent_pos
        new_state[2:4] = agent_vel
        
        # Update each target with flow influence and collision avoidance
        for i in range(self.env_params["n_targets"]):
            # Extract target state
            pos_idx = 4 + 4*i
            vel_idx = pos_idx + 2
            
            target_pos = new_state[pos_idx:pos_idx+2]
            target_vel = new_state[vel_idx:vel_idx+2]
            
            # Get flow influence
            if hasattr(self.env, '_get_flow_influence'):
                flow = self.env._get_flow_influence(target_pos)
            else:
                flow = np.zeros(2)
            
            # Apply random acceleration
            target_acc = np.random.normal(0, 0.05, 2)
            
            # Update target velocity
            target_vel = target_vel * (1 - damping) + target_acc + flow
            
            # Update target position
            target_pos = target_pos + target_vel * dt
            
            # Collision avoidance
            for j in range(self.env_params["n_targets"]):
                if i != j:
                    other_pos_idx = 4 + 4*j
                    other_pos = new_state[other_pos_idx:other_pos_idx+2]
                    
                    # Calculate distance
                    dist = np.linalg.norm(target_pos - other_pos)
                    if dist < 1.0:  # Collision avoidance radius
                        # Repulsive force
                        direction = (target_pos - other_pos) / (dist + 0.1)
                        repulsive_force = direction * (1.0 - dist) * 0.1
                        target_vel += repulsive_force
            
            # Update state
            new_state[pos_idx:pos_idx+2] = target_pos
            new_state[vel_idx:vel_idx+2] = target_vel
        
        # Apply correlation structure
        if hasattr(self.env, '_apply_correlation'):
            # Generate random delta to be correlated
            state_delta = np.random.normal(0, 0.01, self.state_dim)
            correlated_delta = self.env._apply_correlation(state_delta)
            new_state += correlated_delta
        else:
            # Add some correlated noise using correlation matrix
            corr_matrix = self._get_correlation_matrix()
            noise = np.random.multivariate_normal(np.zeros(self.state_dim), 0.01 * corr_matrix)
            new_state += noise
        
        # Ensure states stay within bounds
        # Agent and targets positions
        for i in range(self.env_params["n_targets"] + 1):
            pos_idx = 4*i
            if pos_idx + 2 <= len(new_state):
                new_state[pos_idx:pos_idx+2] = np.clip(new_state[pos_idx:pos_idx+2], 
                                                     0, self.env_params["map_size"])
        
        return new_state
    
    def observation_model(self, state, observation):
        """
        Observation likelihood model for particle belief updates.
        
        Args:
            state: Current state particle
            observation: Observation received
            
        Returns:
            Likelihood of the observation given the state
        """
        # Extract state components
        agent_pos = state[0:2]
        
        # If observation is None, generate an observation
        if observation is None:
            return self._generate_observation(state)
        
        # Compute likelihood
        likelihood = 1.0
        
        # Agent position likelihood (small noise)
        agent_noise = 0.01
        for i in range(2):
            diff = observation[i] - agent_pos[i]
            likelihood *= np.exp(-0.5 * (diff / agent_noise)**2) / (agent_noise * np.sqrt(2 * np.pi))
        
        # Target position likelihoods
        for i in range(self.env_params["n_targets"]):
            # True target position
            pos_idx = 4 + 4*i
            target_pos = state[pos_idx:pos_idx+2]
            
            # Observed position
            obs_idx = 2 + 2*i
            obs_pos = observation[obs_idx:obs_idx+2]
            
            # Get visibility parameters
            if hasattr(self.env, '_get_visibility'):
                visibility, confusion, occlusion = self.env._get_visibility(agent_pos, target_pos)
            else:
                # Default values if method not available
                dist_to_agent = np.linalg.norm(target_pos - agent_pos)
                visibility = max(0.1, 1.0 - dist_to_agent / self.env_params["map_size"])
                confusion = False
                occlusion = False
            
            # Skip if occluded
            if occlusion:
                continue
                
            # Noise scale based on visibility
            noise_scale = self.env_params["noise_level"] * (1.0 - visibility) + 0.01
            
            # If observed position is zero (not visible), low likelihood
            if np.all(obs_pos == 0):
                likelihood *= 0.1  # Low probability for unseen targets
            else:
                # Normal likelihood based on distance
                for j in range(2):
                    diff = obs_pos[j] - target_pos[j]
                    likelihood *= np.exp(-0.5 * (diff / noise_scale)**2) / (noise_scale * np.sqrt(2 * np.pi))
                
                # Account for confusion
                if confusion:
                    # Add probability of seeing other targets
                    confusion_likelihood = 0.0
                    
                    for j in range(self.env_params["n_targets"]):
                        if i != j:
                            other_pos_idx = 4 + 4*j
                            other_pos = state[other_pos_idx:other_pos_idx+2]
                            
                            other_likelihood = 1.0
                            for k in range(2):
                                diff = obs_pos[k] - other_pos[k]
                                other_likelihood *= np.exp(-0.5 * (diff / noise_scale)**2) / (noise_scale * np.sqrt(2 * np.pi))
                            
                            confusion_likelihood += 0.3 * other_likelihood  # 30% chance of confusion
                    
                    likelihood = likelihood * 0.7 + confusion_likelihood  # Mix likelihoods
        
        return likelihood
    
    def _generate_observation(self, state):
        """Generate a synthetic observation from a state"""
        # Extract state components
        agent_pos = state[0:2]
        
        # Create observation
        obs_dim = 2 + 2*self.env_params["n_targets"]
        observation = np.zeros(obs_dim)
        
        # Agent position observed with small noise
        observation[0:2] = agent_pos + np.random.normal(0, 0.01, 2)
        
        # Target observations
        for i in range(self.env_params["n_targets"]):
            # Extract target position
            pos_idx = 4 + 4*i
            target_pos = state[pos_idx:pos_idx+2]
            
            # Get visibility parameters
            if hasattr(self.env, '_get_visibility'):
                visibility, confusion, occlusion = self.env._get_visibility(agent_pos, target_pos)
            else:
                # Default values if method not available
                dist_to_agent = np.linalg.norm(target_pos - agent_pos)
                visibility = max(0.1, 1.0 - dist_to_agent / self.env_params["map_size"])
                confusion = False
                occlusion = False
            
            # Observation indices
            obs_idx = 2 + 2*i
            
            # Apply visibility effects
            if occlusion or np.random.random() > visibility:
                # Target not visible
                observation[obs_idx:obs_idx+2] = np.zeros(2)
            else:
                # Apply confusion
                if confusion and np.random.random() < 0.3:
                    # Swap with another target
                    swap_options = [j for j in range(self.env_params["n_targets"]) if j != i]
                    if swap_options:
                        swap_idx = np.random.choice(swap_options)
                        swap_pos_idx = 4 + 4*swap_idx
                        swapped_pos = state[swap_pos_idx:swap_pos_idx+2]
                        observation[obs_idx:obs_idx+2] = swapped_pos
                else:
                    # Normal observation with noise
                    noise_scale = self.env_params["noise_level"] * (1.0 - visibility) + 0.01
                    observation[obs_idx:obs_idx+2] = target_pos + np.random.normal(0, noise_scale, 2)
        
        return observation
    
    def _get_correlation_matrix(self):
        """Get the correlation matrix for the state variables"""
        if hasattr(self.env, 'correlation_matrix'):
            return self.env.correlation_matrix
        else:
            # Create default correlation matrix if not available
            corr_matrix = np.eye(self.state_dim)
            
            # Agent position-velocity correlation
            corr_matrix[0, 2] = corr_matrix[2, 0] = 0.8
            corr_matrix[1, 3] = corr_matrix[3, 1] = 0.8
            
            # Target correlations
            for i in range(self.env_params["n_targets"]):
                # Position-velocity correlations
                pos_x_idx = 4 + 4*i
                pos_y_idx = 4 + 4*i + 1
                vel_x_idx = 4 + 4*i + 2
                vel_y_idx = 4 + 4*i + 3
                
                corr_matrix[pos_x_idx, vel_x_idx] = corr_matrix[vel_x_idx, pos_x_idx] = 0.8
                corr_matrix[pos_y_idx, vel_y_idx] = corr_matrix[vel_y_idx, pos_y_idx] = 0.8
            
            return corr_matrix
    
    def run_experiment(self, methods, n_episodes=10, max_steps=100, n_particles=100):
        """
        Run experiment to compare different belief approximation methods.
        
        Args:
            methods: Dictionary mapping method names to implementation objects
            n_episodes: Number of episodes to run
            max_steps: Maximum steps per episode
            n_particles: Number of particles for belief representation
            
        Returns:
            DataFrame with evaluation results
        """
        # Initialize results dictionary
        results = {
            "Method": [],
            "Episode": [],
            "Steps": [],
            "Final Position Error": [],
            "Mean Position Error": [],
            "Max Position Error": [],
            "MMD": [],
            "Sliced Wasserstein": [],
            "Correlation Error": [],
            "Mode Coverage": [],
            "ESS": [],
            "Runtime": [],
            "Success": [],
            "Final Distance": []
        }
        
        # Storage for belief particles
        all_particles = {method_name: [] for method_name in methods.keys()}
        ground_truth_states = []
        
        # Run episodes
        for episode in range(n_episodes):
            print(f"Episode {episode+1}/{n_episodes}")
            
            # Reset environment
            obs = self.env.reset()
            self.true_state = self.env.state.copy()
            
            # Initialize particles randomly for each method
            particles = {}
            for method_name in methods.keys():
                # Initialize particles with appropriate ranges
                init_particles = np.zeros((n_particles, self.state_dim))
                
                # Agent position (uniformly distributed in first quarter of map)
                init_particles[:, 0:2] = np.random.uniform(
                    0, self.env_params["map_size"], (n_particles, 2))
                
                # Agent velocity (normal distribution around zero)
                init_particles[:, 2:4] = np.random.normal(0, 0.2, (n_particles, 2))
                
                # Target positions (uniformly distributed in map)
                for i in range(self.env_params["n_targets"]):
                    pos_idx = 4 + 4*i
                    vel_idx = pos_idx + 2
                    
                    # Position
                    init_particles[:, pos_idx:pos_idx+2] = np.random.uniform(
                        0, self.env_params["map_size"], (n_particles, 2))
                    
                    # Velocity
                    init_particles[:, vel_idx:vel_idx+2] = np.random.normal(0, 0.2, (n_particles, 2))
                
                particles[method_name] = init_particles
            
            # Run episode
            for step in range(max_steps):
                # Save ground truth state for evaluation
                ground_truth_states.append(self.true_state.copy())
                
                # Select action based on current belief - here using a simple heuristic
                # In a real application, you'd use a policy trained on the belief
                action = self._select_action(self.true_state)
                
                # Take step in environment
                next_obs, reward, done, info = self.env.step(action)
                self.true_state = self.env.state.copy()  # Update true state reference
                
                # Update belief for each method
                for method_name, method in methods.items():
                    start_time = time.time()
                    
                    # Update belief using the method
                    try:
                        # Different methods have different interfaces
                        if hasattr(method, 'update'):
                            # Standard update interface
                            method.update(action, next_obs, self.transition_model, self.observation_model)
                            updated_particles = method.get_belief_estimate()
                        elif hasattr(method, 'fit_transform'):
                            # Some methods use fit_transform pattern
                            updated_particles = method.fit_transform(
                                particles[method_name],
                                lambda x: self.observation_model(x, next_obs),
                                None
                            )
                        else:
                            # Fallback: assume callable method
                            updated_particles = method(
                                particles[method_name],
                                lambda x: self.observation_model(x, next_obs)
                            )
                        
                        # Update particles
                        particles[method_name] = updated_particles
                    except Exception as e:
                        print(f"Error updating {method_name}: {e}")
                        traceback.print_exc()
                        # Keep previous particles if update fails
                    
                    # Record runtime
                    runtime = time.time() - start_time
                    
                    # Store belief for visualization
                    all_particles[method_name].append(particles[method_name].copy())
                
                # Visualize beliefs periodically
                if step % 10 == 0 or step == max_steps - 1:
                    self._visualize_beliefs(particles, episode, step)
                
                # Update observation
                obs = next_obs
                
                if done:
                    break
            
            # Evaluate performance for each method
            for method_name, method_particles in particles.items():
                # Compute metrics
                position_error = self._compute_position_error(method_particles, self.true_state)
                belief_metrics = self._evaluate_belief_quality(method_particles, self.true_state)
                
                # Check success - if we've reached the goal
                distance_to_goal = np.linalg.norm(self.true_state[0:2] - self.env.goal)
                success = distance_to_goal < 0.5
                
                # Record results
                results["Method"].append(method_name)
                results["Episode"].append(episode)
                results["Steps"].append(step + 1)
                results["Final Position Error"].append(position_error["final"])
                results["Mean Position Error"].append(position_error["mean"])
                results["Max Position Error"].append(position_error["max"])
                results["MMD"].append(belief_metrics["mmd"])
                results["Sliced Wasserstein"].append(belief_metrics["sliced_wasserstein"])
                results["Correlation Error"].append(belief_metrics["correlation_error"])
                results["Mode Coverage"].append(belief_metrics["mode_coverage"])
                results["ESS"].append(belief_metrics["ess"])
                results["Runtime"].append(belief_metrics["runtime"])
                results["Success"].append(success)
                results["Final Distance"].append(distance_to_goal)
                
                print(f"{method_name}: Final Pos Error = {position_error['final']:.2f}, "
                      f"MMD = {belief_metrics['mmd']:.4f}, "
                      f"Mode Coverage = {belief_metrics['mode_coverage']:.2f}, "
                      f"Success = {success}")
        
        # Convert results to DataFrame
        results_df = pd.DataFrame(results)
        
        # Save results
        results_df.to_csv(os.path.join(self.save_dir, "belief_results.csv"), index=False)
        
        # Generate summary visualizations
        self._visualize_results_summary(results_df)
        
        # Optionally, save particles for further analysis
        np.save(os.path.join(self.save_dir, "ground_truth_states.npy"), np.array(ground_truth_states))
        for method_name, method_particles in all_particles.items():
            np.save(os.path.join(self.save_dir, f"{method_name}_particles.npy"), 
                   np.array(method_particles))
        
        return results_df
    
    def _select_action(self, state):
        """
        Simple policy to select action based on current state.
        This is just a heuristic - in practice, you'd train a policy.
        
        Args:
            state: Current state
            
        Returns:
            Action index
        """
        # Extract agent position
        agent_pos = state[0:2]
        
        # Calculate direction to goal
        goal_direction = self.env.goal - agent_pos
        
        # Choose action based on largest goal direction component
        if abs(goal_direction[0]) > abs(goal_direction[1]):
            # Move in x direction
            return 0 if goal_direction[0] > 0 else 1  # +x or -x
        else:
            # Move in y direction
            return 2 if goal_direction[1] > 0 else 3  # +y or -y
    
    def _compute_position_error(self, particles, true_state):
        """
        Compute position estimation error from belief particles.
        
        Args:
            particles: Belief particles
            true_state: True robot state
            
        Returns:
            Dictionary with error metrics
        """
        # Extract agent position (first 2 dimensions)
        true_pos = true_state[0:2]
        particle_positions = particles[:, 0:2]
        
        # Compute weighted mean position from particles
        mean_pos = np.mean(particle_positions, axis=0)
        
        # Compute error
        final_error = np.linalg.norm(true_pos - mean_pos)
        
        # Compute errors for all particles
        all_errors = np.linalg.norm(particle_positions - true_pos, axis=1)
        mean_error = np.mean(all_errors)
        max_error = np.max(all_errors)
        
        return {
            "final": final_error,
            "mean": mean_error,
            "max": max_error
        }
    
    def _evaluate_belief_quality(self, particles, true_state):
        """Optimized version of belief quality evaluation using vectorized operations"""
        # Limit sample size for all metrics at once
        n_eval = min(200, len(particles))
        if len(particles) > n_eval:
            # Sample once for all metrics
            eval_indices = np.random.choice(len(particles), n_eval, replace=False)
            particles_subset = particles[eval_indices]
        else:
            particles_subset = particles
        
        # Generate proxy samples around true state (reused across metrics)
        corr_matrix = self._get_correlation_matrix()
        proxy_samples = np.random.multivariate_normal(
            mean=true_state, 
            cov=0.1 * corr_matrix,  # Scale for reasonable spread
            size=n_eval
        )
        
        # Ensure proxy samples respect constraints (all at once)
        for i in range(self.env_params["n_targets"] + 1):
            pos_idx = 4*i
            if pos_idx + 2 <= proxy_samples.shape[1]:
                proxy_samples[:, pos_idx:pos_idx+2] = np.clip(
                    proxy_samples[:, pos_idx:pos_idx+2], 0, self.env_params["map_size"])
        
        # Calculate Maximum Mean Discrepancy (optimized)
        mmd = self._compute_mmd(particles_subset, proxy_samples)
        
        # Calculate Sliced Wasserstein Distance (optimized with fewer projections)
        sliced_wasserstein = self._compute_sliced_wasserstein(
            particles_subset, proxy_samples, n_projections=10)
        
        # Calculate Correlation Error (simplified)
        correlation_error = self._compute_correlation_error(particles_subset)
        
        # Calculate Mode Coverage (with limited clustering)
        mode_coverage = self._compute_mode_coverage(particles, true_state)
        
        # Calculate Effective Sample Size (optimized)
        ess = self._compute_ess(particles_subset, true_state)
        
        # Measure runtime of the evaluation itself
        runtime = 0.0
        
        return {
            "mmd": mmd,
            "sliced_wasserstein": sliced_wasserstein,
            "correlation_error": correlation_error,
            "mode_coverage": mode_coverage,
            "ess": ess,
            "runtime": runtime
        }
    
    def _compute_mmd(self, particles, reference_samples, bandwidth=None):
        """Optimized MMD calculation using vectorized operations"""
        # Use a fixed bandwidth to avoid redundant computations
        bandwidth = 1.0
        
        # Calculate pairwise distances all at once
        def pairwise_distances(X):
            """Compute pairwise squared distances efficiently"""
            # (a-b)^2 = a^2 + b^2 - 2ab
            sum_X = np.sum(X**2, axis=1)
            # reshape for broadcasting
            sum_X_expand = sum_X[:, np.newaxis]
            cross_term = -2 * np.dot(X, X.T)
            sq_dists = sum_X_expand + sum_X + cross_term
            return np.maximum(sq_dists, 0)  # Ensure non-negative
        
        # Apply kernel to all pairs at once
        def kernel_matrix(X, Y=None):
            """Compute kernel matrix efficiently"""
            if Y is None:
                sq_dists = pairwise_distances(X)
            else:
                X_sum = np.sum(X**2, axis=1)
                Y_sum = np.sum(Y**2, axis=1)
                cross_term = -2 * np.dot(X, Y.T)
                sq_dists = X_sum[:, np.newaxis] + Y_sum + cross_term
            
            return np.exp(-sq_dists / bandwidth)
        
        # Compute all kernel terms at once
        K_pp = kernel_matrix(particles)
        K_rr = kernel_matrix(reference_samples)
        K_pr = kernel_matrix(particles, reference_samples)
        
        # MMD computation
        pp_mean = (np.sum(K_pp) - np.trace(K_pp)) / (len(particles) * (len(particles) - 1))
        rr_mean = (np.sum(K_rr) - np.trace(K_rr)) / (len(reference_samples) * (len(reference_samples) - 1))
        pr_mean = np.mean(K_pr)
        
        mmd = pp_mean + rr_mean - 2 * pr_mean
        return max(0, mmd)
    
    def _compute_sliced_wasserstein(self, particles, reference_samples, n_projections=20):
        """Optimized Sliced Wasserstein calculation with fewer projections"""
        # Generate random projection directions
        directions = np.random.randn(n_projections, self.state_dim)
        directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)
        
        # Project all samples onto all directions at once
        particle_projs = np.dot(particles, directions.T)
        reference_projs = np.dot(reference_samples, directions.T)
        
        # Calculate Wasserstein distances for all projections
        swd = 0.0
        for i in range(n_projections):
            # Sort projections
            p_sorted = np.sort(particle_projs[:, i])
            r_sorted = np.sort(reference_projs[:, i])
            
            # Interpolate if necessary
            if len(p_sorted) != len(r_sorted):
                if len(p_sorted) > len(r_sorted):
                    indices = np.linspace(0, len(r_sorted)-1, len(p_sorted))
                    r_interp = np.interp(indices, np.arange(len(r_sorted)), r_sorted)
                    w_dist = np.mean(np.abs(p_sorted - r_interp))
                else:
                    indices = np.linspace(0, len(p_sorted)-1, len(r_sorted))
                    p_interp = np.interp(indices, np.arange(len(p_sorted)), p_sorted)
                    w_dist = np.mean(np.abs(p_interp - r_sorted))
            else:
                w_dist = np.mean(np.abs(p_sorted - r_sorted))
            
            swd += w_dist
        
        return swd / n_projections
    
    
    def _compute_correlation_error(self, particles):
        """Simplified correlation error calculation"""
        # Only use a subset of state dimensions for correlation
        key_dims = [0, 1, 4, 5, 8, 9]  # Agent and first two targets positions
        
        # Compute empirical covariance for key dimensions only
        if len(particles) < 10:
            return 1.0
        
        particles_key = particles[:, key_dims]
        empirical_cov = np.cov(particles_key, rowvar=False)
        
        # Get corresponding true correlation submatrix
        true_corr_full = self._get_correlation_matrix()
        true_corr = true_corr_full[np.ix_(key_dims, key_dims)]
        
        # Normalize to correlation matrix
        d = np.sqrt(np.diag(empirical_cov))
        d_mat = np.outer(d, d)
        d_mat[d_mat < 1e-10] = 1.0
        empirical_corr = empirical_cov / d_mat
        
        # Compute difference
        diff_norm = np.linalg.norm(empirical_corr - true_corr, 'fro')
        true_norm = np.linalg.norm(true_corr, 'fro')
        
        return diff_norm / true_norm if true_norm > 1e-10 else 1.0
    
    def _compute_mode_coverage(self, particles, true_state):
        """Compute mode coverage ratio"""
        # Consider agent and target positions for mode detection
        pos_indices = []
        
        # Agent position (first 2 dimensions)
        pos_indices.extend([0, 1])
        
        # Target positions (every 4 dimensions starting at index 4)
        for i in range(self.env_params["n_targets"]):
            pos_idx = 4 + 4*i
            pos_indices.extend([pos_idx, pos_idx+1])
        
        # Extract position components
        pos_particles = particles[:, pos_indices]
        
        # Cluster to find modes using DBSCAN
        try:
            clustering = DBSCAN(eps=1.0, min_samples=3).fit(pos_particles)
            labels = clustering.labels_
            
            # Number of detected modes
            n_modes = len(set(labels)) - (1 if -1 in labels else 0)
            
            # Count particles in each mode
            mode_counts = {}
            for label in labels:
                if label != -1:  # Skip noise points
                    mode_counts[label] = mode_counts.get(label, 0) + 1
            
            # True position components
            true_pos = np.array([true_state[idx] for idx in pos_indices])
            
            # Find distances from true position to each particle
            dists = np.sqrt(np.sum((pos_particles - true_pos)**2, axis=1))
            
            # Check if at least some particles are close to true position
            close_particles = np.sum(dists < 2.0)  # Within 2 units (adjusted for high-dim space)
            
            # At least 3% of particles should be close to true position
            true_pos_covered = close_particles >= max(3, 0.03 * len(particles))
            
            # Calculate visibility at agent's true position to evaluate expected multi-modality
            agent_pos = true_state[0:2]
            
            # Check visibility of targets to determine expected multi-modality
            low_visibility_count = 0
            for i in range(self.env_params["n_targets"]):
                pos_idx = 4 + 4*i
                target_pos = true_state[pos_idx:pos_idx+2]
                
                if hasattr(self.env, '_get_visibility'):
                    visibility, confusion, occlusion = self.env._get_visibility(agent_pos, target_pos)
                    if visibility < 0.5 or occlusion or confusion:
                        low_visibility_count += 1
                else:
                    # Default logic if method not available
                    dist = np.linalg.norm(target_pos - agent_pos)
                    if dist > self.env_params["map_size"] / 3:
                        low_visibility_count += 1
            
            # In zones with poor visibility or occlusion, multiple modes are expected
            expected_modes = 1 + (low_visibility_count > 0)
            
            # If the number of modes matches expectations
            if true_pos_covered:
                if n_modes >= expected_modes:
                    # Good coverage with appropriate number of modes
                    return 1.0
                else:
                    # True position covered but insufficient modes
                    return 0.7
            else:
                # True position not covered
                return 0.0
        except Exception as e:
            print(f"Error computing mode coverage: {e}")
            return 0.0
    
    def _compute_ess(self, particles, true_state):
        """
        Fixed ESS calculation that doesn't require _current_step attribute
        
        Args:
            particles: Belief particles
            true_state: True state for comparison
            
        Returns:
            Effective Sample Size ratio
        """
        # Generate observation from true state (only once per call)
        expected_obs = self._generate_observation(true_state)
        
        # Compute observation likelihood for each particle
        log_weights = np.zeros(len(particles))
        
        for i, particle in enumerate(particles):
            try:
                # Always provide both particle and observation
                likelihood = self.observation_model(particle, expected_obs)
                # Ensure likelihood is valid
                likelihood = max(likelihood, 1e-10)
                log_weights[i] = np.log(likelihood)
            except Exception as e:
                # If there's an error, assign very low probability
                log_weights[i] = -1e10
        
        # Normalize log weights
        log_weights -= np.max(log_weights)
        weights = np.exp(log_weights)
        sum_weights = np.sum(weights)
        
        if sum_weights < 1e-10:
            return 0.0
            
        weights /= sum_weights
        
        # Compute ESS
        ess = 1.0 / np.sum(weights**2)
        
        # Normalize by number of particles
        return ess / len(particles)
    
    def _generate_observation(self, state):
        """Generate an observation from a state"""
        # If the environment has an observation function, use it
        if hasattr(self, 'observation_model'):
            return self.observation_model(state, None)
        
        # Otherwise, create a simple observation
        agent_pos = state[0:2]
        
        # Create observation vector
        obs_dim = 2 + 2*self.env_params["n_targets"]
        observation = np.zeros(obs_dim)
        
        # Agent position with small noise
        observation[0:2] = agent_pos + np.random.normal(0, 0.01, 2)
        
        # Target positions
        for i in range(self.env_params["n_targets"]):
            pos_idx = 4 + 4*i
            target_pos = state[pos_idx:pos_idx+2]
            
            # Observation indices
            obs_idx = 2 + 2*i
            
            # Simple distance-based visibility
            dist = np.linalg.norm(target_pos - agent_pos)
            visibility = max(0.1, 1.0 - dist / self.env_params["map_size"])
            
            # Apply visibility
            if np.random.random() > visibility:
                observation[obs_idx:obs_idx+2] = np.zeros(2)
            else:
                noise_scale = self.env_params["noise_level"] * (1.0 - visibility) + 0.01
                observation[obs_idx:obs_idx+2] = target_pos + np.random.normal(0, noise_scale, 2)
        
        return observation
    
    def _visualize_beliefs(self, particles_dict, episode, step):
        """Visualize current belief particles for all methods"""
        n_methods = len(particles_dict)
        
        # Create a single figure
        fig = plt.figure(figsize=(15, 5 * n_methods))
        fig.suptitle(f"Episode {episode+1}, Step {step+1} - Belief Visualization", fontsize=16)
        
        # Create a grid of subplots
        gridspec = fig.add_gridspec(n_methods, 4)
        
        for i, (method_name, particles) in enumerate(particles_dict.items()):
            # Create a row header for this method
            row_title = fig.add_subplot(gridspec[i, 0])
            row_title.text(0.5, 0.5, f"{method_name}", fontsize=14, 
                        ha='center', va='center')
            row_title.axis('off')
            
            # Define the projections to show
            # 1: Agent position (dims 0-1)
            # 2: First target position (dims 4-5)
            # 3: Second target position (dims 8-9)
            projections = [(0, 1), (4, 5), (8, 9)]
            
            for j, (dim1, dim2) in enumerate(projections):
                ax = fig.add_subplot(gridspec[i, j+1])
                
                # Draw environment boundaries
                ax.add_patch(Rectangle((0, 0), self.env_params["map_size"], 
                                    self.env_params["map_size"], 
                                    fill=False, edgecolor='black'))
                
                # Draw visibility zones if available
                if hasattr(self.env, 'visibility_zones'):
                    for zone in self.env.visibility_zones:
                        color = 'green' if zone.get('visibility', 1.0) > 0.7 else 'red'
                        alpha = 0.2 if not zone.get('occlusion', False) else 0.5
                        circle = Circle(zone['center'], zone['radius'], 
                                       color=color, alpha=alpha)
                        ax.add_patch(circle)
                
                # Draw goal for agent position projection
                if j == 0:
                    ax.scatter(self.env.goal[0], self.env.goal[1], c='green', marker='*', 
                             s=100, label='Goal')
                
                # Draw true state
                if dim1 < len(self.true_state) and dim2 < len(self.true_state):
                    if j == 0:
                        # Agent
                        ax.scatter(self.true_state[dim1], self.true_state[dim2], c='blue', marker='o', 
                                 s=50, label='Agent')
                    else:
                        # Target
                        target_idx = (dim1 - 4) // 4
                        ax.scatter(self.true_state[dim1], self.true_state[dim2], c='red', marker='d',
                                 s=50, label=f'Target {target_idx+1}')
                
                # Draw particles
                if dim1 < particles.shape[1] and dim2 < particles.shape[1]:
                    ax.scatter(particles[:, dim1], particles[:, dim2], c='orange', marker='.', 
                             s=10, alpha=0.5, label='Belief')
                    
                    # Create a 2D histogram for density visualization
                    try:
                        heatmap, xedges, yedges = np.histogram2d(
                            particles[:, dim1], particles[:, dim2], bins=20,
                            range=[[0, self.env_params["map_size"]], 
                                [0, self.env_params["map_size"]]])
                        
                        # Plot heatmap with transparency
                        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                        ax.imshow(heatmap.T, extent=extent, origin='lower', 
                                 cmap='YlOrRd', alpha=0.3, interpolation='bilinear')
                    except:
                        pass  # Skip heatmap if there's an error
                
                # Set plot labels and title
                ax.set_xlim(0, self.env_params["map_size"])
                ax.set_ylim(0, self.env_params["map_size"])
                
                if j == 0:
                    ax.set_title('Agent Position')
                else:
                    target_idx = (dim1 - 4) // 4
                    ax.set_title(f'Target {target_idx+1} Position')
                
                # Add grid
                ax.grid(True)
                
                # Only add legend for the first subplot
                if j == 0:
                    ax.legend(loc='upper right')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "plots", 
                                f"belief_ep{episode+1}_step{step+1}.png"))
        plt.close(fig)
    
    def _visualize_results_summary(self, results_df):
        """
        Visualize summary of experiment results.
        
        Args:
            results_df: DataFrame with evaluation results
        """
        # Aggregate results across episodes
        summary = results_df.groupby('Method').mean().reset_index()
        
        # Plot metrics comparison
        metrics = [
            'Final Position Error', 'MMD', 'Sliced Wasserstein', 
            'Correlation Error', 'Mode Coverage', 'ESS', 'Success'
        ]
        
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Performance Comparison of Belief Approximation Methods", fontsize=16)
        
        for i, metric in enumerate(metrics):
            ax = fig.add_subplot(2, 4, i+1)
            
            # Sort by metric (asc or desc depending on metric)
            if metric in ['Mode Coverage', 'ESS', 'Success']:
                # Higher is better
                sorted_df = summary.sort_values(metric, ascending=False)
            else:
                # Lower is better
                sorted_df = summary.sort_values(metric)
            
            # Create bar plot
            bars = ax.bar(sorted_df['Method'], sorted_df[metric])
            
            # Add value annotations
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height, 
                       f'{height:.3f}', ha='center', va='bottom', rotation=0)
            
            # Set labels and title
            ax.set_xlabel('Method')
            ax.set_ylabel(metric)
            
            if metric in ['Mode Coverage', 'ESS', 'Success']:
                ax.set_title(f"{metric} (higher is better)")
                if metric == 'Success':
                    ax.set_ylim(0, 1.1)
            else:
                ax.set_title(f"{metric} (lower is better)")
                
            ax.grid(axis='y', alpha=0.3)
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "metrics_comparison.png"))
        plt.close(fig)
        
        # Create combined metrics plot
        # Normalize metrics for combined view
        normalized_df = summary.copy()
        
        for metric in metrics:
            if metric in ['Mode Coverage', 'ESS', 'Success']:
                # Higher is better, already normalized
                normalized_df[f"{metric}_norm"] = normalized_df[metric]
            else:
                # Lower is better, invert normalization
                if normalized_df[metric].max() != normalized_df[metric].min():
                    normalized_df[f"{metric}_norm"] = 1 - (normalized_df[metric] - normalized_df[metric].min()) / (normalized_df[metric].max() - normalized_df[metric].min())
                else:
                    normalized_df[f"{metric}_norm"] = 1.0
        
        # Plot radar chart for each method
        fig = plt.figure(figsize=(15, 10))
        fig.suptitle("Normalized Performance Metrics (higher is better)", fontsize=16)
        
        # Number of metrics
        N = len(metrics)
        
        # Angle of each axis
        angles = np.linspace(0, 2*np.pi, N, endpoint=False).tolist()
        angles += angles[:1]  # Close the loop
        
        # Position for each metric label
        label_angles = np.linspace(0, 2*np.pi, N, endpoint=False)
        label_positions = []
        for angle in label_angles:
            label_positions.append((np.cos(angle), np.sin(angle)))
        
        # Plot each method as a separate radar chart
        n_methods = len(normalized_df)
        n_cols = min(3, n_methods)
        n_rows = (n_methods + n_cols - 1) // n_cols
        
        for i, (_, row) in enumerate(normalized_df.iterrows()):
            ax = fig.add_subplot(n_rows, n_cols, i+1, projection='polar')
            
            # Plot metrics
            values = [row[f"{metric}_norm"] for metric in metrics]
            values += values[:1]  # Close the loop
            
            ax.plot(angles, values, linewidth=2, linestyle='solid')
            ax.fill(angles, values, alpha=0.25)
            
            # Set labels
            ax.set_xticks(angles[:-1])
            ax.set_xticklabels(metrics)
            
            # Adjust label positions for readability
            for label, angle in zip(ax.get_xticklabels(), label_angles):
                x, y = label.get_position()
                lab = ax.get_xticklabels()[list(label_angles).index(angle)]
                lab.set_horizontalalignment('center')
                lab.set_verticalalignment('center')
                lab.set_rotation(np.degrees(angle))
            
            ax.set_ylim(0, 1)
            ax.set_title(row['Method'])
            
            # Add gridlines
            ax.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "radar_comparison.png"))
        plt.close(fig)
