import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import gym
from gym import spaces
from sklearn.cluster import DBSCAN

class LightDark10DEnv(gym.Env):
    """
    # High-Dimensional Light-Dark Navigation Environment with Perceptual Aliasing
    
    This environment extends the classic Light-Dark problem to a 10D state space with
    correlated dimensions and perceptual aliasing creating multi-modal beliefs.
    
    ## State space (10D):
    - 5D position (x₁, x₂, x₃, x₄, x₅)
    - 5D velocity (v₁, v₂, v₃, v₄, v₅)
    
    ## Key challenges:
    1. High dimensionality (10D state space)
    2. Multi-modality (symmetric light patterns create multiple hypotheses)
    3. Strong correlations between state variables
    """
    
    def __init__(self, map_size=10, light_regions=None, noise_level=0.5):
        """
        Set up the spaces, light layout, initial state and goal.

        Args:
            map_size: Side length of the map along every positional axis.
            light_regions: Optional list of (center, radius, intensity)
                tuples. When None, the default layout produced by
                _create_light_regions() is used.
            noise_level: Base scale of the observation noise; it is
                attenuated as the local light level rises.
        """
        self.map_size = map_size
        self.noise_level = noise_level

        # Ten discrete actions: a positive or a negative force along each
        # of the five positional axes.
        self.action_space = spaces.Discrete(10)

        # Observations are unbounded noisy readings of the 5D position.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(5,), dtype=np.float32)

        # A caller-supplied layout wins; otherwise build the default one.
        self.light_regions = (
            light_regions if light_regions is not None
            else self._create_light_regions()
        )

        # Random start state: 5 positions followed by 5 velocities.
        self.state = self._get_initial_state()

        # Correlation structure used when perturbing state updates.
        self.correlation_matrix = self._define_correlation_matrix()

        # The goal sits at 80% of the map size along every positional axis.
        self.goal = np.full(5, self.map_size * 0.8)
        
    def _create_light_regions(self):
        """
        Build the default light layout.

        A light region is a (center, radius, intensity) tuple where center
        is a 5-element array, radius is the extent of the region and a
        higher intensity means less observation noise. All coordinates are
        expressed as fractions of self.map_size.

        Returns:
            A list of seven (center, radius, intensity) tuples, in a fixed
            order.
        """
        size = self.map_size
        half = size / 2
        quarter = size / 4
        three_quarters = 3 * size / 4
        third = size / 3

        # One row per region: (center coordinates, radius, intensity).
        layout = (
            # Primary bright corridor along the first dimension.
            ([half, 0, 0, 0, 0], 2.0, 0.9),
            # Mirrored corridor along the second dimension (source of ambiguity).
            ([0, half, 0, 0, 0], 2.0, 0.9),
            # Three spots stepping diagonally through dimensions 3-4.
            ([0, 0, quarter, quarter, 0], 1.5, 0.8),
            ([0, 0, half, half, 0], 1.5, 0.8),
            ([0, 0, three_quarters, three_quarters, 0], 1.5, 0.8),
            # Small, fully bright spot in the fifth dimension near the goal.
            ([0, 0, 0, 0, size * 0.7], 1.0, 1.0),
            # Wide, dim region spanning the first three dimensions.
            ([third, third, third, 0, 0], 3.0, 0.4),
        )

        return [
            (np.array(coords), radius, intensity)
            for coords, radius, intensity in layout
        ]
    
    def _define_correlation_matrix(self):
        """
        Build the 10x10 correlation matrix used to perturb state updates.

        Indices 0-4 are the positional dimensions and indices 5-9 the
        matching velocities. The hand-specified couplings are:

        - position i <-> velocity i: 0.8
        - adjacent positions: 0.5
        - adjacent velocities: 0.6, except v1 <-> v2 which is 0.7
        - positions 1 <-> 3 and 2 <-> 4: 0.4
        - velocities v1 <-> v3: 0.5

        Taken literally these values do not form a valid correlation
        matrix: for example p1-v1 = 0.8 and v1-v2 = 0.7 with p1-v2 = 0
        gives a 3x3 minor with determinant 0.51 - 0.64 = -0.13, so the
        matrix has negative eigenvalues. A covariance built from it makes
        np.random.multivariate_normal warn that the covariance is not
        positive-semidefinite. The raw matrix is therefore projected onto
        the nearest valid correlation matrix (negative eigenvalues clipped,
        diagonal rescaled back to 1) before being returned.

        Returns:
            np.ndarray of shape (10, 10): symmetric, unit diagonal and
            positive definite.
        """
        raw = np.eye(10)

        def couple(i, j, rho):
            # Keep the matrix symmetric by always writing both entries.
            raw[i, j] = raw[j, i] = rho

        # Position and velocity of the same dimension.
        for i in range(5):
            couple(i, i + 5, 0.8)

        # Adjacent positional dimensions.
        for i in range(4):
            couple(i, i + 1, 0.5)

        # Adjacent velocity dimensions.
        for i in range(5, 9):
            couple(i, i + 1, 0.6)

        # Longer-range couplings between positions 1-3 and 2-4.
        couple(0, 2, 0.4)
        couple(1, 3, 0.4)

        # Velocity coupling: v1 affects v2 and v3. The (5, 6) entry
        # intentionally overrides the 0.6 written by the loop above.
        couple(5, 6, 0.7)
        couple(5, 7, 0.5)

        # eigh returns eigenvalues in ascending order, so index 0 is the
        # smallest one. Nothing to repair if it is already safely positive.
        min_eigenvalue = 1e-6
        eigenvalues, eigenvectors = np.linalg.eigh(raw)
        if eigenvalues[0] >= min_eigenvalue:
            return raw

        # Clip the spectrum and rebuild the matrix from its eigenbasis.
        clipped = np.clip(eigenvalues, min_eigenvalue, None)
        repaired = (eigenvectors * clipped) @ eigenvectors.T

        # Clipping inflates the diagonal; rescale so it is a correlation
        # matrix again. This is a congruence transform, so positive
        # definiteness is preserved.
        std = np.sqrt(np.diag(repaired))
        repaired = repaired / np.outer(std, std)

        # Remove floating-point asymmetry and pin the diagonal to exactly 1.
        repaired = (repaired + repaired.T) / 2.0
        np.fill_diagonal(repaired, 1.0)

        return repaired
        
    def _get_initial_state(self):
        """
        Draw a random 10D start state.

        Positions are sampled uniformly from [0, map_size / 5) on each of
        the five axes and velocities from a normal distribution with mean 0
        and standard deviation 0.1.

        Returns:
            np.ndarray of shape (10,): five positions followed by five
            velocities.
        """
        # Positions are drawn before velocities; keep this order so results
        # stay reproducible under a fixed global NumPy seed.
        start_position = np.random.uniform(low=0, high=self.map_size / 5, size=5)
        start_velocity = np.random.normal(loc=0, scale=0.1, size=5)

        return np.hstack((start_position, start_velocity))
    
    def _get_light_level(self, position):
        """
        Calculate the light level at a given 5D position.

        Every region in self.light_regions is a (center, radius, intensity)
        tuple. Coordinates of the center that are exactly 0 are treated as
        "don't care": the distance to the region is measured only along
        the non-zero coordinates and divided by the square root of their
        count. A center that is all zeros falls back to the plain Euclidean
        distance over all dimensions. Inside a region the contribution
        falls off quadratically from `intensity` at the center to 0 at the
        radius.

        Args:
            position: 5-element array-like (ndarray, list or tuple).

        Returns:
            The largest contribution over all regions, never lower than the
            ambient dark level of 0.05. Higher values mean more light and
            therefore less observation noise.
        """
        # Coerce to arrays so that caller-supplied lists/tuples work. On a
        # plain list, `center != 0` is a single bool rather than a
        # per-dimension mask, which would silently give a wrong distance.
        position = np.asarray(position, dtype=float)

        # Ambient light level in the dark.
        light_level = 0.05

        for center, radius, intensity in self.light_regions:
            center = np.asarray(center, dtype=float)

            # Only dimensions with a non-zero center coordinate constrain
            # the region.
            relevant_dims = center != 0
            n_relevant = np.count_nonzero(relevant_dims)

            if n_relevant > 0:
                # Scaled Euclidean distance over the relevant dimensions.
                offset = position[relevant_dims] - center[relevant_dims]
                distance = np.linalg.norm(offset) / np.sqrt(n_relevant)
            else:
                # All-zero center: use the full Euclidean distance.
                distance = np.linalg.norm(position - center)

            if distance < radius:
                # Quadratic fall-off from the center of the region.
                contribution = intensity * (1.0 - (distance / radius) ** 2)

                # Overlapping regions do not add up; the brightest wins.
                light_level = max(light_level, contribution)

        return light_level
    
    def _apply_correlation(self, state_delta):
        """
        Perturb an intended state change with correlated Gaussian noise.

        A covariance matrix is formed by scaling self.correlation_matrix
        with per-component standard deviations of |state_delta| + 0.01. One
        zero-mean sample is drawn from it and 10% of that sample is added
        to the magnitude of the intended change; the result is multiplied
        by the sign of the intended change.

        Args:
            state_delta: np.ndarray of shape (10,), the intended change.

        Returns:
            np.ndarray of shape (10,), the perturbed change.
        """
        abs_delta = np.abs(state_delta)

        # The 0.01 floor keeps the scale of a zero delta from collapsing.
        sigma = abs_delta + 0.01
        covariance = self.correlation_matrix * np.outer(sigma, sigma)

        noise = np.random.multivariate_normal(mean=np.zeros(10), cov=covariance)

        # NOTE(review): multiplying by sign(state_delta) removes the noise
        # entirely where the intended delta is exactly 0 and flips the sign
        # of the noise for negative deltas, which changes the effective
        # cross-correlations -- confirm this is intended.
        return np.sign(state_delta) * (abs_delta + 0.1 * noise)
        
    def step(self, action):
        """
        Advance the environment by one time step.

        Actions 0-9 apply a force of magnitude 0.1 along axis action // 2:
        even actions push in the positive direction, odd actions in the
        negative one. Any action >= 10 applies no force. The intended
        change (position from the current velocity, velocity from force
        minus damping) is perturbed by _apply_correlation() and positions
        are clipped to [0, map_size].

        Args:
            action: Integer action index.

        Returns:
            (observation, reward, done, info) where reward is
            -0.1 * distance_to_goal - 0.1, done is True once the agent is
            within 0.5 of the goal (no step limit is enforced here), and
            info holds "distance_to_goal" and "light_level".
        """
        time_step = 0.1
        damping_coeff = 0.1
        force_magnitude = 0.1

        current_velocity = self.state[5:]

        # Translate the discrete action into a force vector.
        applied_force = np.zeros(5)
        if action < 10:
            sign = -1 if action % 2 else 1
            applied_force[action // 2] = sign * force_magnitude

        # Intended change: positions integrate the current velocity,
        # velocities respond to the force and are damped.
        nominal_delta = np.concatenate([
            current_velocity * time_step,
            applied_force - damping_coeff * current_velocity,
        ])

        # Perturb the intended change, then keep positions inside the map.
        next_state = self.state + self._apply_correlation(nominal_delta)
        next_state[:5] = np.clip(next_state[:5], 0, self.map_size)
        self.state = next_state

        observation = self._get_observation()

        # Reward shrinks with the distance to the goal, plus a constant
        # per-step penalty.
        goal_distance = np.linalg.norm(self.state[:5] - self.goal)

        info = {
            "distance_to_goal": goal_distance,
            "light_level": self._get_light_level(self.state[:5]),
        }

        return observation, -0.1 * goal_distance - 0.1, goal_distance < 0.5, info
    
    def _get_observation(self):
        """
        Produce a noisy reading of the current 5D position.

        Gaussian noise is added with a scale of
        noise_level * (1 - light_level) + 0.01, so brighter places give
        more precise readings. Where the light level is below 0.1 there is
        in addition a 20% chance of an aliasing event, in which the
        readings of dimensions 0/1 and of dimensions 2/3 are each swapped
        independently with probability 0.5.

        Returns:
            np.ndarray of shape (5,), the noisy position reading.
        """
        true_position = self.state[:5]
        brightness = self._get_light_level(true_position)

        # More light means less noise; 0.01 is the noise floor.
        sigma = self.noise_level * (1.0 - brightness) + 0.01
        reading = true_position + np.random.normal(0, sigma, 5)

        # The aliasing coin is only flipped in very dark places (the
        # short-circuit keeps the random draw order intact).
        aliasing_event = brightness < 0.1 and np.random.random() < 0.2
        if aliasing_event:
            # Each pair of dimensions is swapped independently.
            for first, second in ((0, 1), (2, 3)):
                if np.random.random() < 0.5:
                    reading[[first, second]] = reading[[second, first]]

        return reading
    
    def reset(self):
        """
        Start a new episode.

        Replaces the current state with a freshly drawn initial state.

        Returns:
            The first noisy observation of the new state.
        """
        self.state = self._get_initial_state()
        return self._get_observation()
    
    def _projected_light_map(self, dim1, dim2, resolution=50):
        """Sample the light level on a grid in the (dim1, dim2) plane, with every other coordinate held at 0."""
        grid = np.linspace(0, self.map_size, resolution)
        light_map = np.zeros((resolution, resolution))

        for x_idx, x in enumerate(grid):
            for y_idx, y in enumerate(grid):
                # Embed the 2D grid point into the 5D position space.
                pos_5d = np.zeros(5)
                pos_5d[dim1] = x
                pos_5d[dim2] = y
                light_map[x_idx, y_idx] = self._get_light_level(pos_5d)

        return light_map

    def render(self, mode='human', belief_particles=None):
        """
        Render the environment with optional belief particles.

        Draws a 2x2 grid of 2D projections of the 5D position space
        (dimension pairs 1-2, 3-4, 1-5 and 2-4). Each panel shows the map
        boundary, the light map of that plane, the goal, the agent with its
        velocity arrow and, when given, the belief particles together with
        a density heatmap.

        Args:
            mode: Only 'human' is handled; any other value does nothing.
            belief_particles: Optional 2D array-like with one particle per
                row, indexed by state dimension along the columns. An
                empty collection is treated like None.

        Returns:
            None. The figure is displayed with plt.show().
        """
        if mode != 'human':
            return

        import warnings

        # Accept lists of particles as well as arrays; the column slicing
        # below needs an ndarray.
        if belief_particles is not None:
            belief_particles = np.asarray(belief_particles)
            if belief_particles.size == 0:
                belief_particles = None

        # Create a 2x2 grid of 2D projections.
        _, axs = plt.subplots(2, 2, figsize=(12, 10))
        axs = axs.flatten()

        # Pairs of positional dimensions shown in each panel.
        projections = [(0, 1), (2, 3), (0, 4), (1, 3)]

        for i, (dim1, dim2) in enumerate(projections):
            ax = axs[i]

            # Plot the map boundaries.
            ax.add_patch(Rectangle((0, 0), self.map_size, self.map_size,
                                   fill=False, edgecolor='black'))

            # Plot the light map of this plane. It is transposed because
            # imshow expects rows to run along the y axis.
            light_map = self._projected_light_map(dim1, dim2)
            ax.imshow(light_map.T, extent=[0, self.map_size, 0, self.map_size],
                      origin='lower', cmap='YlGnBu', alpha=0.5)

            # Plot the goal position (projection).
            ax.scatter(self.goal[dim1], self.goal[dim2], c='green', marker='*',
                       s=200, label='Goal')

            # Plot the agent's position (projection).
            ax.scatter(self.state[dim1], self.state[dim2], c='red', marker='o',
                       s=100, label='Agent')

            # Add the velocity vector; velocities live 5 slots after the
            # matching positions in the state vector.
            arrow_scale = 1.0
            ax.arrow(self.state[dim1], self.state[dim2],
                     self.state[dim1 + 5] * arrow_scale,
                     self.state[dim2 + 5] * arrow_scale,
                     head_width=0.2, head_length=0.3, fc='red', ec='red')

            # If belief particles are provided, plot their projections.
            if belief_particles is not None:
                particles_x = belief_particles[:, dim1]
                particles_y = belief_particles[:, dim2]

                ax.scatter(particles_x, particles_y, c='orange', marker='.',
                           s=10, alpha=0.5, label='Belief')

                # A density heatmap is only meaningful with enough particles.
                if len(particles_x) > 10:
                    try:
                        heatmap, xedges, yedges = np.histogram2d(
                            particles_x, particles_y, bins=20,
                            range=[[0, self.map_size], [0, self.map_size]])

                        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                        ax.imshow(heatmap.T, extent=extent, origin='lower',
                                  cmap='YlOrRd', alpha=0.3, interpolation='bilinear')
                    except Exception as exc:
                        # The heatmap is a best-effort overlay: skip it on
                        # failure, but say so, and do not swallow
                        # KeyboardInterrupt/SystemExit as a bare except would.
                        warnings.warn(
                            f"Skipping belief heatmap for dims "
                            f"{dim1 + 1}-{dim2 + 1}: {exc}",
                            RuntimeWarning,
                        )

            # Set plot limits, labels and title. The limits are set after
            # the imshow calls above.
            ax.set_xlim(0, self.map_size)
            ax.set_ylim(0, self.map_size)
            ax.set_xlabel(f'Dimension {dim1+1}')
            ax.set_ylabel(f'Dimension {dim2+1}')
            ax.set_title(f'Dims {dim1+1}-{dim2+1} Projection')

            ax.grid(True)

            # Add legend (for the first plot only).
            if i == 0:
                ax.legend()

        plt.tight_layout()
        plt.show()
