import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import gym
from gym import spaces

class MultiTargetTracking20DEnv(gym.Env):
    """
    # Multi-Target Tracking 20D Environment with Partial Observations
    
    This environment challenges belief approximation methods with high dimensionality,
    multi-modality, and complex correlation structures in POMDPs.
    
    ## State space (20D):
    - Agent position (x, y) and velocity (vx, vy) [4D]
    - 4 targets with position (x, y) and velocity (vx, vy) each [16D]
    
    ## Key challenges:
    1. High dimensionality (20D state space)
    2. Multi-modality (limited sensing, occlusions, identity confusion)
    3. Complex correlations (physical constraints, flow patterns, temporal dependencies)
    """
    
    def __init__(self, map_size=10, n_targets=4, noise_level=0.5):
        """Set up the world geometry, the gym spaces and the starting state.

        Args:
            map_size: Side length of the square world.
            n_targets: Number of tracked targets. Each one adds four state
                dimensions, so the default of 4 gives the 20D state.
            noise_level: Stored for use when target observations are generated.
        """
        self.map_size = map_size
        self.n_targets = n_targets
        self.noise_level = noise_level

        # State layout: agent (x, y, vx, vy), then (x, y, vx, vy) per target.
        self.state_dim = 4 + 4*n_targets

        # Four discrete actions: +x, -x, +y, -y acceleration of the agent.
        self.action_space = spaces.Discrete(4)

        # Observations carry positions only: the agent's (x, y) followed by
        # one (x, y) pair per target.
        obs_shape = (2 + 2*n_targets,)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32)

        # Static structure of the world: sensing zones, flow fields and the
        # correlation matrix between state dimensions.
        self.visibility_zones = self._create_visibility_zones()
        self.flow_patterns = self._create_flow_patterns()
        self.correlation_matrix = self._define_correlation_matrix()

        # Goal sits at 80% of the map along both axes.
        goal_coord = self.map_size * 0.8
        self.goal = np.array([goal_coord, goal_coord])

        # Sample the first state last; it is the only step that uses the RNG.
        self.state = self._get_initial_state()
    
    def _create_visibility_zones(self):
        """Build the four circular sensing zones, one centred in each map quadrant.

        Returns:
            A list of dicts, each with ``center``, ``radius`` and ``visibility``.
            The third zone additionally carries ``confusion=True`` and the
            fourth ``occlusion=True``.
        """
        near = self.map_size/4
        far = 3*self.map_size/4
        radius = self.map_size/5

        def make_zone(cx, cy, visibility, **flags):
            # Key order mirrors the layout used throughout the class.
            return {
                'center': np.array([cx, cy]),
                'radius': radius,
                'visibility': visibility,
                **flags,
            }

        return [
            make_zone(near, near, 1.0),                 # full visibility
            make_zone(far, near, 0.3),                  # poor visibility
            make_zone(near, far, 0.7, confusion=True),  # identities may swap
            make_zone(far, far, 0.0, occlusion=True),   # nothing is seen
        ]
    
    def _create_flow_patterns(self):
        """Create flow patterns that induce correlations in target movements"""
        # Create vector fields representing typical movement patterns
        flows = []
        
        # Pattern 1: Circular flow
        flow1 = np.zeros((self.map_size, self.map_size, 2))
        center = np.array([self.map_size/2, self.map_size/2])
        for x in range(self.map_size):
            for y in range(self.map_size):
                pos = np.array([x, y])
                rel_pos = pos - center
                # Perpendicular vector for circular motion
                if np.linalg.norm(rel_pos) > 0.1:
                    flow1[x, y, 0] = -rel_pos[1] / np.linalg.norm(rel_pos)
                    flow1[x, y, 1] = rel_pos[0] / np.linalg.norm(rel_pos)
                else:
                    flow1[x, y, 0] = 0
                    flow1[x, y, 1] = 0
        flows.append(flow1)
        
        # Pattern 2: Inward flow
        flow2 = np.zeros((self.map_size, self.map_size, 2))
        for x in range(self.map_size):
            for y in range(self.map_size):
                pos = np.array([x, y])
                rel_pos = center - pos
                norm = np.linalg.norm(rel_pos)
                if norm > 0.1:
                    flow2[x, y, 0] = rel_pos[0] / norm
                    flow2[x, y, 1] = rel_pos[1] / norm
                else:
                    flow2[x, y, 0] = 0
                    flow2[x, y, 1] = 0
        flows.append(flow2)
        
        # Pattern 3: Corridor flow
        flow3 = np.zeros((self.map_size, self.map_size, 2))
        corridor_y = self.map_size // 2
        corridor_width = self.map_size // 5
        for x in range(self.map_size):
            for y in range(self.map_size):
                if abs(y - corridor_y) < corridor_width:
                    # Strong rightward flow in corridor
                    flow3[x, y, 0] = 1.0
                    flow3[x, y, 1] = 0.0
                else:
                    # Weak random flow elsewhere
                    flow3[x, y, 0] = 0.2
                    flow3[x, y, 1] = 0.0
        flows.append(flow3)
        
        return flows
    
    def _define_correlation_matrix(self):
        """Define the correlation matrix between state variables"""
        # Initialize with identity matrix (no correlation)
        dim = self.state_dim
        corr_matrix = np.eye(dim)
        
        # Agent position-velocity correlation
        corr_matrix[0, 2] = corr_matrix[2, 0] = 0.8  # x and vx
        corr_matrix[1, 3] = corr_matrix[3, 1] = 0.8  # y and vy
        
        # Target position-velocity correlations
        for i in range(self.n_targets):
            # Position indices: 4 + 4i, 4 + 4i + 1
            # Velocity indices: 4 + 4i + 2, 4 + 4i + 3
            pos_x_idx = 4 + 4*i
            pos_y_idx = 4 + 4*i + 1
            vel_x_idx = 4 + 4*i + 2
            vel_y_idx = 4 + 4*i + 3
            
            # Correlation between position and velocity
            corr_matrix[pos_x_idx, vel_x_idx] = corr_matrix[vel_x_idx, pos_x_idx] = 0.8
            corr_matrix[pos_y_idx, vel_y_idx] = corr_matrix[vel_y_idx, pos_y_idx] = 0.8
        
        # Correlations between targets (e.g., flocking behavior)
        for i in range(self.n_targets):
            for j in range(i+1, self.n_targets):
                # Correlation between positions of different targets
                i_pos_x = 4 + 4*i
                i_pos_y = 4 + 4*i + 1
                j_pos_x = 4 + 4*j
                j_pos_y = 4 + 4*j + 1
                
                # Position correlation (flocking)
                corr_matrix[i_pos_x, j_pos_x] = corr_matrix[j_pos_x, i_pos_x] = 0.4
                corr_matrix[i_pos_y, j_pos_y] = corr_matrix[j_pos_y, i_pos_y] = 0.4
                
                # Velocity correlation (similar movement patterns)
                i_vel_x = 4 + 4*i + 2
                i_vel_y = 4 + 4*i + 3
                j_vel_x = 4 + 4*j + 2
                j_vel_y = 4 + 4*j + 3
                
                corr_matrix[i_vel_x, j_vel_x] = corr_matrix[j_vel_x, i_vel_x] = 0.6
                corr_matrix[i_vel_y, j_vel_y] = corr_matrix[j_vel_y, i_vel_y] = 0.6
                
                # Cross-correlations
                corr_matrix[i_pos_x, j_vel_x] = corr_matrix[j_vel_x, i_pos_x] = 0.3
                corr_matrix[i_pos_y, j_vel_y] = corr_matrix[j_vel_y, i_pos_y] = 0.3
        
        return corr_matrix
    
    def _get_initial_state(self):
        """Initialize the state"""
        state = np.zeros(self.state_dim)
        
        # Agent position (random in first quadrant)
        state[0:2] = np.random.uniform(0, self.map_size/3, 2)
        
        # Agent velocity (small random values)
        state[2:4] = np.random.normal(0, 0.1, 2)
        
        # Target positions (distributed around the map)
        for i in range(self.n_targets):
            # Position indices
            pos_idx = 4 + 4*i
            
            # Randomly position targets
            if i % 4 == 0:  # Top-right
                state[pos_idx:pos_idx+2] = np.random.uniform(
                    [self.map_size/2, self.map_size/2], 
                    [self.map_size, self.map_size], 2)
            elif i % 4 == 1:  # Top-left
                state[pos_idx:pos_idx+2] = np.random.uniform(
                    [0, self.map_size/2], 
                    [self.map_size/2, self.map_size], 2)
            elif i % 4 == 2:  # Bottom-left
                state[pos_idx:pos_idx+2] = np.random.uniform(
                    [0, 0], 
                    [self.map_size/2, self.map_size/2], 2)
            else:  # Bottom-right
                state[pos_idx:pos_idx+2] = np.random.uniform(
                    [self.map_size/2, 0], 
                    [self.map_size, self.map_size/2], 2)
            
            # Target velocities (small random values)
            vel_idx = pos_idx + 2
            state[vel_idx:vel_idx+2] = np.random.normal(0, 0.2, 2)
        
        return state
    
    def _apply_correlation(self, state_delta):
        """
        Fixed version of _apply_correlation that ensures the covariance matrix
        is symmetric positive-definite.
        
        Args:
            state_delta: State delta for correlation
            
        Returns:
            Correlated delta
        """
        # Convert correlation matrix to covariance matrix using state_delta as scale
        scales = np.abs(state_delta) + 0.01  # Add small constant to avoid zero scale
        cov_matrix = np.outer(scales, scales) * self.correlation_matrix
        
        # Step 1: Ensure symmetry by averaging with transpose
        cov_matrix = (cov_matrix + cov_matrix.T) / 2
        
        # Step 2: Ensure positive-definiteness with multiple approaches
        
        # First approach: Add small value to diagonal (faster)
        cov_matrix_adjusted = cov_matrix + np.eye(len(cov_matrix)) * 1e-4
        
        try:
            # Try with simple adjustment
            correlated_noise = np.random.multivariate_normal(
                mean=np.zeros(self.state_dim), cov=cov_matrix_adjusted)
        except np.linalg.LinAlgError:
            try:
                # If that fails, try fixing eigenvalues (more robust)
                eigvals, eigvecs = np.linalg.eigh(cov_matrix)
                min_eig = np.min(eigvals)
                
                if min_eig < 0:
                    # Fix negative eigenvalues
                    eigvals = np.maximum(eigvals, 1e-6)
                    fixed_cov = eigvecs @ np.diag(eigvals) @ eigvecs.T
                    
                    # Ensure symmetry again
                    fixed_cov = (fixed_cov + fixed_cov.T) / 2
                    
                    correlated_noise = np.random.multivariate_normal(
                        mean=np.zeros(self.state_dim), cov=fixed_cov)
                else:
                    # This shouldn't happen if the above failed
                    raise
            except:
                # Last resort: Use diagonal covariance
                print("Warning: Using diagonal covariance as fallback")
                diag_values = np.abs(np.diag(cov_matrix))
                diag_values[diag_values < 1e-6] = 1e-6
                correlated_noise = np.random.multivariate_normal(
                    mean=np.zeros(self.state_dim), cov=np.diag(diag_values))
        
        # Scale the noise based on the intended state_delta
        direction = np.sign(state_delta)
        magnitude = np.abs(state_delta)
        
        # Combine direction and magnitude with correlation structure
        correlated_delta = direction * (magnitude + 0.1 * correlated_noise)
        
        return correlated_delta
    
    def _get_flow_influence(self, position):
        """Get flow field influence at a given position"""
        # Discretize position
        pos_int = np.clip(position.astype(int), 0, self.map_size-1)
        
        # Randomly select a flow pattern with weights
        flow_weights = [0.5, 0.3, 0.2]  # Probabilities for each flow
        flow_idx = np.random.choice(len(self.flow_patterns), p=flow_weights)
        
        # Get flow vector at position
        flow = self.flow_patterns[flow_idx][pos_int[0], pos_int[1]]
        
        # Scale flow strength
        flow_strength = 0.03
        return flow * flow_strength
    
    def _get_visibility(self, agent_pos, target_pos):
        """Determine visibility and noise level based on zones"""
        # Default visibility
        visibility = 0.8
        confusion = False
        occlusion = False
        
        # Check each zone
        for zone in self.visibility_zones:
            # Distance to zone center
            center_dist = np.linalg.norm(target_pos - zone['center'])
            
            # If in zone radius
            if center_dist < zone['radius']:
                visibility = zone['visibility']
                confusion = zone.get('confusion', False)
                occlusion = zone.get('occlusion', False)
                break
        
        # Distance-based visibility falloff
        dist_to_agent = np.linalg.norm(target_pos - agent_pos)
        max_visibility_dist = self.map_size / 2
        if dist_to_agent > max_visibility_dist:
            visibility_falloff = max(0, 1 - (dist_to_agent - max_visibility_dist) / max_visibility_dist)
            visibility *= visibility_falloff
        
        return visibility, confusion, occlusion
    
    def step(self, action):
        """Take an action and update the state"""
        # Extract current state components
        agent_pos = self.state[0:2]
        agent_vel = self.state[2:4]
        
        # Calculate agent acceleration from action
        agent_acc = np.zeros(2)
        if action == 0:  # +x
            agent_acc[0] = 0.1
        elif action == 1:  # -x
            agent_acc[0] = -0.1
        elif action == 2:  # +y
            agent_acc[1] = 0.1
        elif action == 3:  # -y
            agent_acc[1] = -0.1
        
        # Update agent velocity with damping
        damping = 0.1
        agent_vel = agent_vel * (1 - damping) + agent_acc
        
        # Update agent position
        dt = 0.1
        agent_pos = agent_pos + agent_vel * dt
        
        # Ensure agent stays within bounds
        agent_pos = np.clip(agent_pos, 0, self.map_size)
        
        # Update state with agent information
        self.state[0:2] = agent_pos
        self.state[2:4] = agent_vel
        
        # Update each target
        for i in range(self.n_targets):
            # Extract target state
            pos_idx = 4 + 4*i
            vel_idx = pos_idx + 2
            
            target_pos = self.state[pos_idx:pos_idx+2]
            target_vel = self.state[vel_idx:vel_idx+2]
            
            # Apply flow field influence
            flow = self._get_flow_influence(target_pos)
            
            # Apply random acceleration
            target_acc = np.random.normal(0, 0.05, 2)
            
            # Update target velocity with damping and flow
            target_vel = target_vel * (1 - damping) + target_acc + flow
            
            # Update target position
            target_pos = target_pos + target_vel * dt
            
            # Ensure target stays within bounds
            target_pos = np.clip(target_pos, 0, self.map_size)
            
            # Collision avoidance with other targets
            for j in range(self.n_targets):
                if i != j:
                    other_pos_idx = 4 + 4*j
                    other_pos = self.state[other_pos_idx:other_pos_idx+2]
                    
                    # Calculate distance and direction
                    dist = np.linalg.norm(target_pos - other_pos)
                    if dist < 1.0:  # Collision avoidance radius
                        # Repulsive force
                        direction = (target_pos - other_pos) / (dist + 0.1)
                        repulsive_force = direction * (1.0 - dist) * 0.1
                        target_vel += repulsive_force
            
            # Update state with target information
            self.state[pos_idx:pos_idx+2] = target_pos
            self.state[vel_idx:vel_idx+2] = target_vel
        
        # Apply correlations to create realistic state evolution
        state_delta = self._apply_correlation(np.random.normal(0, 0.01, self.state_dim))
        self.state += state_delta
        
        # Ensure state stays within bounds
        for i in range(self.n_targets + 1):  # Agent + targets
            pos_idx = 4*i
            if pos_idx >= len(self.state):
                break
            self.state[pos_idx:pos_idx+2] = np.clip(self.state[pos_idx:pos_idx+2], 0, self.map_size)
        
        # Generate observation
        observation = self._get_observation()
        
        # Calculate reward: negative distance to goal plus small step penalty
        distance_to_goal = np.linalg.norm(agent_pos - self.goal)
        reward = -0.1 * distance_to_goal - 0.05
        
        # Collision penalty with targets
        for i in range(self.n_targets):
            pos_idx = 4 + 4*i
            target_pos = self.state[pos_idx:pos_idx+2]
            if np.linalg.norm(agent_pos - target_pos) < 0.5:
                reward -= 1.0
        
        # Check if the episode is done (reached goal or max steps)
        done = distance_to_goal < 0.5
        
        # Additional info
        info = {
            "distance_to_goal": distance_to_goal,
            "agent_position": agent_pos,
            "target_positions": [self.state[4+4*i:4+4*i+2] for i in range(self.n_targets)]
        }
        
        return observation, reward, done, info
    
    def _get_observation(self):
        """Generate partial, noisy observations of the state"""
        # True state components
        agent_pos = self.state[0:2]
        
        # Create observation vector (smaller than state)
        obs_dim = 2 + 2*self.n_targets  # position + targets
        observation = np.zeros(obs_dim)
        
        # Agent position is always observed (with small noise)
        observation[0:2] = agent_pos + np.random.normal(0, 0.01, 2)
        
        # Target observations
        for i in range(self.n_targets):
            # Extract target position
            pos_idx = 4 + 4*i
            target_pos = self.state[pos_idx:pos_idx+2]
            
            # Determine visibility properties for this target
            visibility, confusion, occlusion = self._get_visibility(agent_pos, target_pos)
            
            # Observation indices for this target
            obs_pos_idx = 2 + 2*i
            
            # Apply visibility effects
            if occlusion or np.random.random() > visibility:
                # Target not visible - fill with NaN or large noise
                observation[obs_pos_idx:obs_pos_idx+2] = np.zeros(2)
            else:
                # Apply confusion (identity swap) if in confusion zone
                if confusion and np.random.random() < 0.3:
                    # Swap with another random target
                    swap_idx = np.random.choice([j for j in range(self.n_targets) if j != i])
                    swap_pos_idx = 4 + 4*swap_idx
                    swapped_pos = self.state[swap_pos_idx:swap_pos_idx+2]
                    observation[obs_pos_idx:obs_pos_idx+2] = swapped_pos
                else:
                    # Normal observation with noise based on visibility
                    noise_scale = self.noise_level * (1.0 - visibility) + 0.01
                    observation[obs_pos_idx:obs_pos_idx+2] = target_pos + np.random.normal(0, noise_scale, 2)
        
        return observation
    
    def reset(self):
        """Reset the environment and return initial observation"""
        self.state = self._get_initial_state()
        observation = self._get_observation()
        return observation
    
    def render(self, mode='human', belief_particles=None):
        """Show a 2x2 matplotlib dashboard of the current state.

        Panels:
            1. Agent, goal, targets (with velocity arrows) and, if given, the
               belief over the agent position.
            2. Targets and, if given, the belief over the first target.
            3. The first flow field as a quiver plot, with the targets.
            4. The state correlation matrix.

        Args:
            mode: Only ``'human'`` draws anything; any other value returns
                without doing anything.
            belief_particles: Optional 2-D array whose columns follow the state
                layout (columns 0:2 are read as agent position, 4:6 as the
                first target's position). Presumably one particle per row.
        """
        if mode != 'human':
            return

        size = self.map_size

        def draw_arena(ax):
            # Map outline plus the visibility zones (occlusion drawn more opaque).
            ax.add_patch(Rectangle((0, 0), size, size,
                                   fill=False, edgecolor='black', linewidth=2))
            for zone in self.visibility_zones:
                color = 'green' if zone.get('visibility', 1.0) > 0.7 else 'red'
                alpha = 0.5 if zone.get('occlusion', False) else 0.2
                ax.add_patch(Circle(zone['center'], zone['radius'],
                                    color=color, alpha=alpha))

        def draw_targets(ax, arrows=False):
            # Only the first target gets a legend label to avoid duplicates.
            for i in range(self.n_targets):
                pos_idx = 4 + 4*i
                target_pos = self.state[pos_idx:pos_idx+2]
                target_vel = self.state[pos_idx+2:pos_idx+4]
                ax.scatter(target_pos[0], target_pos[1], c='red', marker='d',
                           s=50, label=f'Target {i+1}' if i == 0 else "")
                if arrows:
                    ax.arrow(target_pos[0], target_pos[1],
                             target_vel[0], target_vel[1],
                             head_width=0.2, head_length=0.3, fc='red', ec='red')

        def draw_belief(ax, points, color, cmap, label):
            # Raw particles plus a 20x20 density heatmap underneath.
            ax.scatter(points[:, 0], points[:, 1],
                       c=color, marker='.', s=10, alpha=0.3, label=label)
            heatmap, xedges, yedges = np.histogram2d(
                points[:, 0], points[:, 1],
                bins=20, range=[[0, size], [0, size]])
            extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
            # histogram2d puts x on the first axis; imshow expects rows = y.
            ax.imshow(heatmap.T, extent=extent, origin='lower',
                      cmap=cmap, alpha=0.3)

        def finish(ax, title, legend=True):
            ax.set_xlim(0, size)
            ax.set_ylim(0, size)
            ax.set_title(title)
            if legend:
                ax.legend()
            ax.grid(True)

        if belief_particles is not None:
            belief_particles = np.asarray(belief_particles)

        _, axes = plt.subplots(2, 2, figsize=(15, 12))
        main_ax, belief_ax, flow_ax, corr_ax = axes.flatten()

        # --- Panel 1: agent, goal and targets ---------------------------------
        draw_arena(main_ax)
        main_ax.scatter(self.goal[0], self.goal[1], c='green', marker='*',
                        s=200, label='Goal')

        agent_pos = self.state[0:2]
        agent_vel = self.state[2:4]
        main_ax.scatter(agent_pos[0], agent_pos[1], c='blue', marker='o',
                        s=100, label='Agent')
        main_ax.arrow(agent_pos[0], agent_pos[1], agent_vel[0], agent_vel[1],
                      head_width=0.3, head_length=0.5, fc='blue', ec='blue')

        draw_targets(main_ax, arrows=True)

        if belief_particles is not None:
            draw_belief(main_ax, belief_particles[:, 0:2],
                        'cyan', 'Blues', 'Agent Belief')

        finish(main_ax, 'Agent and Target Positions')

        # --- Panel 2: belief over one target ----------------------------------
        draw_arena(belief_ax)
        draw_targets(belief_ax)

        if belief_particles is not None:
            target_idx = 0
            pos_idx = 4 + 4*target_idx
            draw_belief(belief_ax, belief_particles[:, pos_idx:pos_idx+2],
                        'orange', 'Oranges', f'Target {target_idx+1} Belief')

        finish(belief_ax, 'Target Beliefs')

        # --- Panel 3: flow field ----------------------------------------------
        flow = self.flow_patterns[0]
        n_arrows = 20
        ticks = np.linspace(0, size, n_arrows)

        # 'ij' indexing makes the first array axis x, matching how the flow
        # grid is indexed ([x, y]). With the default 'xy' meshgrid the sampled
        # field would be drawn transposed.
        X, Y = np.meshgrid(ticks, ticks, indexing='ij')

        # Nearest lower grid cell for each arrow, clamped to the grid.
        cells = np.minimum(
            (np.arange(n_arrows) * size / n_arrows).astype(int), size - 1)
        sampled = flow[np.ix_(cells, cells)]
        U = sampled[:, :, 0]
        V = sampled[:, :, 1]

        flow_ax.quiver(X, Y, U, V, scale=25)
        draw_targets(flow_ax)
        finish(flow_ax, 'Flow Field and Targets', legend=False)

        # --- Panel 4: correlation matrix --------------------------------------
        im = corr_ax.imshow(self.correlation_matrix, cmap='coolwarm',
                            vmin=-1, vmax=1)
        plt.colorbar(im, ax=corr_ax)

        corr_ax.set_title('State Correlation Matrix')
        corr_ax.set_xlabel('State Dimension')
        corr_ax.set_ylabel('State Dimension')

        # One minor tick per state dimension, one major tick per 4-D body
        # (agent or target); sized from state_dim rather than a fixed 20.
        minor_ticks = np.arange(0, self.state_dim, 1)
        major_ticks = np.arange(0, self.state_dim, 4)

        corr_ax.set_xticks(major_ticks)
        corr_ax.set_yticks(major_ticks)
        corr_ax.set_xticks(minor_ticks, minor=True)
        corr_ax.set_yticks(minor_ticks, minor=True)

        corr_ax.grid(which='minor', color='w', linestyle='-', linewidth=0.2)
        corr_ax.grid(which='major', color='w', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        plt.show()
