import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import gym
from gym import spaces
from sklearn.cluster import DBSCAN

class KidnappedRobotEnv(gym.Env):
    """
    # Kidnapped Robot Problem with Perceptual Aliasing

    This environment simulates a robot that has been placed in an unknown location
    within a known map. The map contains perceptually similar regions that create
    ambiguity in the robot's observations.

    ## State space (20D), in index order:
    - [0:2]   2D position (x, y)
    - [2]     Orientation (θ), kept in [0, 2π)
    - [3:5]   Robot velocity and steering
    - [5:10]  Sensor calibration parameters (5D)
    - [10:20] Environmental feature descriptors (10D, unit norm)

    ## Key challenges:
    1. High dimensionality (20D state space)
    2. Multi-modality (ambiguous landmarks create multiple hypotheses)
    3. Strong correlations between state variables
    """

    def __init__(self, map_size=20, n_landmarks=15, sensor_range=5, noise_level=0.1):
        """Build the map, the initial state and the state correlation matrix.

        Args:
            map_size: Side length of the square map; position is clipped to
                [0, map_size] on both axes.
            n_landmarks: Number of landmark slots in an observation (two values
                per slot). Random landmarks are added only when this exceeds
                the number of hard-coded pattern landmarks.
            sensor_range: Maximum distance at which a landmark is observed.
            noise_level: Base standard deviation of the observation noise.
        """
        super().__init__()

        # Environment parameters
        self.map_size = map_size
        self.n_landmarks = n_landmarks
        self.sensor_range = sensor_range
        self.noise_level = noise_level

        # Define action and observation spaces
        self.action_space = spaces.Discrete(4)  # Forward, Left, Right, Stay

        # Observation: (normalized distance, feature similarity) per slot.
        # NOTE(review): the declared [0, 1] bounds are kept for compatibility,
        # but the similarity term is a noisy dot product of unit vectors and
        # can be negative, and noisy distances can fall slightly outside
        # [0, 1] — confirm whether the bounds should be widened.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(self.n_landmarks*2,), dtype=np.float32)

        # Create map with repeating structures for perceptual aliasing
        self.map = self._create_map()

        # Initialize the state with random values
        self.state = self._get_initial_state()

        # Define the correlation matrix for state variables
        self.correlation_matrix = self._define_correlation_matrix()

    def _create_map(self):
        """Create a map with landmarks showing perceptual aliasing patterns.

        Returns:
            dict mapping ``(x, y)`` landmark coordinates to a unit-norm 10D
            feature vector. Odd-indexed patterns reuse the raw descriptors of
            the preceding pattern with a small perturbation, so twin patterns
            look alike. Coordinates shared by two patterns keep the descriptor
            of the later one because the map is keyed by position.
        """
        feature_dim = 10

        # Repeating patterns of landmarks; each odd entry is the "twin" of the
        # even entry right before it.
        patterns = [
            [(2, 2), (3, 4), (5, 3)],  # Pattern 1
            [(2, 2), (4, 3), (3, 5)],  # Pattern 2 (similar to 1)
            [(7, 2), (8, 4), (10, 3)],  # Pattern 3
            [(7, 2), (9, 3), (8, 5)],   # Pattern 4 (similar to 3)
            [(12, 12), (13, 14), (15, 13)],  # Pattern 5
            [(12, 12), (14, 13), (13, 15)],  # Pattern 6 (similar to 5)
            [(17, 17), (18, 19), (19, 18)],  # Pattern 7
        ]

        map_data = {}
        base_raw = []  # un-normalized descriptors of the latest even pattern

        for pattern_idx, pattern in enumerate(patterns):
            is_twin = pattern_idx % 2 == 1
            current_raw = []
            for k, position in enumerate(pattern):
                if is_twin and k < len(base_raw):
                    # Slightly perturbed copy of the matching landmark in the
                    # previous pattern: this is what creates the aliasing.
                    raw = base_raw[k] + np.random.normal(0, 0.1, feature_dim)
                else:
                    raw = np.random.normal(0, 1, feature_dim)
                current_raw.append(raw)
                map_data[position] = raw / np.linalg.norm(raw)
            if not is_twin:
                base_raw = current_raw

        # Additional random landmarks, only if the budget exceeds the patterns.
        n_pattern_landmarks = sum(len(p) for p in patterns)
        num_additional = max(0, self.n_landmarks - n_pattern_landmarks)
        for _ in range(num_additional):
            x = np.random.uniform(0, self.map_size)
            y = np.random.uniform(0, self.map_size)
            raw = np.random.normal(0, 1, feature_dim)
            map_data[(x, y)] = raw / np.linalg.norm(raw)

        return map_data

    @staticmethod
    def _project_to_valid_correlation(matrix, min_eig=1e-8):
        """Return ``matrix`` adjusted to be a valid correlation matrix.

        If the smallest eigenvalue is at least ``min_eig`` the input is
        returned unchanged. Otherwise eigenvalues are clipped to ``min_eig``
        and the result is rescaled so that the diagonal is exactly 1.
        """
        eigvals, eigvecs = np.linalg.eigh(matrix)
        if eigvals.min() >= min_eig:
            return matrix

        clipped = np.clip(eigvals, min_eig, None)
        repaired = (eigvecs * clipped) @ eigvecs.T
        # Congruence scaling restores the unit diagonal and keeps the matrix
        # positive semi-definite.
        scale = np.sqrt(np.diag(repaired))
        repaired = repaired / np.outer(scale, scale)
        # Remove floating-point asymmetry.
        return (repaired + repaired.T) / 2

    def _define_correlation_matrix(self):
        """Define the correlation matrix between state variables.

        Returns:
            A symmetric 20x20 matrix with unit diagonal. The hand-specified
            target correlations below are mutually inconsistent (the raw matrix
            has a negative eigenvalue), so the result is passed through
            ``_project_to_valid_correlation``; the returned entries therefore
            deviate slightly from the literal targets.
        """
        # Initialize with identity matrix (no correlation)
        corr_matrix = np.eye(20)

        # Position (x, y) and orientation are correlated
        corr_matrix[0, 2] = corr_matrix[2, 0] = 0.6  # x and orientation
        corr_matrix[1, 2] = corr_matrix[2, 1] = 0.6  # y and orientation

        # Position and motion parameters are correlated
        corr_matrix[0, 3] = corr_matrix[3, 0] = 0.8  # x and velocity (state[3])
        corr_matrix[1, 4] = corr_matrix[4, 1] = 0.8  # y and steering (state[4])

        # Orientation and steering are correlated
        corr_matrix[2, 4] = corr_matrix[4, 2] = 0.7  # orientation and steering

        # Sensor calibration parameters have internal correlations
        for i in range(5, 10):
            for j in range(i + 1, 10):
                corr_matrix[i, j] = corr_matrix[j, i] = 0.4

        # Sensor calibration affects feature descriptors
        for i in range(5, 10):
            for j in range(10, 20):
                corr_matrix[i, j] = corr_matrix[j, i] = 0.3

        # Feature descriptors have internal correlations based on similarities
        for i in range(10, 20):
            for j in range(i + 1, 20):
                corr_matrix[i, j] = corr_matrix[j, i] = 0.2

        # multivariate_normal needs a positive semi-definite covariance.
        return self._project_to_valid_correlation(corr_matrix)

    def _get_initial_state(self):
        """Initialize the robot's state with random values.

        Returns:
            A 20-element array laid out as
            ``[x, y, orientation, velocity, steering, calibration(5), features(10)]``.
        """
        # Random position within the map
        pos_x = np.random.uniform(0, self.map_size)
        pos_y = np.random.uniform(0, self.map_size)

        # Random orientation
        orientation = np.random.uniform(0, 2*np.pi)

        # Random velocity and steering
        velocity = np.random.normal(0, 0.5)
        steering = np.random.normal(0, 0.1)

        # Random sensor calibration parameters
        sensor_calibration = np.random.normal(0, 0.1, 5)

        # Random environmental feature descriptors, normalized to unit length
        feature_descriptors = np.random.normal(0, 1, 10)
        feature_descriptors = feature_descriptors / np.linalg.norm(feature_descriptors)

        # Combine all state components
        state = np.concatenate([
            [pos_x, pos_y, orientation, velocity, steering],
            sensor_calibration,
            feature_descriptors
        ])

        return state

    def _apply_correlation(self, state_delta):
        """Apply correlation structure to state updates.

        Args:
            state_delta: 20-element array with the intended change of each
                state variable.

        Returns:
            A 20-element array: the intended change plus correlated noise.
            Components whose intended change is exactly zero stay zero,
            because the noise is multiplied by ``sign(state_delta)``.
        """
        # Convert correlation matrix to covariance matrix using state_delta as scale
        scales = np.abs(state_delta) + 0.01  # Add small constant to avoid zero scale
        cov_matrix = np.outer(scales, scales) * self.correlation_matrix

        # Generate correlated noise
        correlated_noise = np.random.multivariate_normal(
            mean=np.zeros(20), cov=cov_matrix)

        # Keep the sign of the intended change and jitter its magnitude.
        # NOTE(review): multiplying by the sign also flips the noise for
        # negative deltas and removes it for zero deltas — confirm this is the
        # intended model.
        direction = np.sign(state_delta)
        magnitude = np.abs(state_delta)

        return direction * (magnitude + 0.1 * correlated_noise)

    def step(self, action):
        """Take an action and update the state.

        Args:
            action: 0 = move forward, 1 = turn left, 2 = turn right; any other
                value (3 = stay) applies no motion.

        Returns:
            Tuple ``(observation, reward, done, info)``; ``reward`` is always
            -1, ``done`` is always False and ``info`` is an empty dict.
        """
        orientation = self.state[2]
        velocity = self.state[3]

        # Intended change for each state variable
        state_delta = np.zeros(20)

        if action == 0:  # Move forward
            # Update position based on orientation and velocity
            state_delta[0] = velocity * np.cos(orientation) * 0.1  # x update
            state_delta[1] = velocity * np.sin(orientation) * 0.1  # y update
            # Small changes to velocity
            state_delta[3] = np.random.normal(0, 0.05)

        elif action == 1:  # Turn left (counter-clockwise)
            state_delta[2] = 0.1
            state_delta[4] = np.random.normal(0.05, 0.02)

        elif action == 2:  # Turn right (clockwise)
            state_delta[2] = -0.1
            state_delta[4] = np.random.normal(-0.05, 0.02)

        # For action 3 (stay), no motion is applied

        # Small random changes to sensor calibration
        state_delta[5:10] = np.random.normal(0, 0.01, 5)

        # Apply correlation to create a more realistic state update
        self.state = self.state + self._apply_correlation(state_delta)

        # Ensure the robot stays within the map boundaries
        self.state[0] = np.clip(self.state[0], 0, self.map_size)
        self.state[1] = np.clip(self.state[1], 0, self.map_size)

        # Normalize orientation to [0, 2π)
        self.state[2] = self.state[2] % (2*np.pi)

        # Normalize feature descriptors
        feature_norm = np.linalg.norm(self.state[10:20])
        if feature_norm > 0:
            self.state[10:20] = self.state[10:20] / feature_norm

        observation = self._get_observation()

        # Simple reward: -1 for each step to encourage efficient localization
        reward = -1
        done = False
        info = {}

        return observation, reward, done, info

    def _get_observation(self):
        """Generate noisy observations based on the current state.

        Returns:
            float32 array of length ``2 * n_landmarks``. Each visible landmark
            (within ``sensor_range`` and inside the 90-degree field of view)
            fills the next free slot with
            ``(noisy distance / sensor_range, noisy feature similarity)``, in
            map iteration order; unused slots stay 0.
        """
        pos_x, pos_y = self.state[0], self.state[1]
        orientation = self.state[2]
        sensor_calibration = self.state[5:10]
        internal_features = self.state[10:20]

        # Match the dtype declared by observation_space.
        observation = np.zeros(self.n_landmarks * 2, dtype=np.float32)

        half_fov = np.pi / 4  # 90-degree field of view, centered on heading

        # Noise depends on sensor calibration; clamp at 0 because a strongly
        # negative calibration sum would otherwise give an invalid std-dev.
        noise_scale = max(
            self.noise_level * (1 + 0.2 * np.sum(sensor_calibration)), 0.0)

        slot = 0
        for (lm_x, lm_y), feature in self.map.items():
            if slot >= self.n_landmarks:
                break  # every observation slot is already filled

            dx = lm_x - pos_x
            dy = lm_y - pos_y
            distance = np.hypot(dx, dy)
            if distance > self.sensor_range:
                continue

            # Signed bearing relative to the heading, wrapped to [-π, π) so the
            # test stays correct when the field of view straddles angle 0.
            bearing_offset = (
                (np.arctan2(dy, dx) - orientation + np.pi) % (2*np.pi) - np.pi)
            if abs(bearing_offset) > half_fov:
                continue

            noisy_distance = distance + np.random.normal(0, noise_scale)

            # Similarity between the landmark descriptor and the robot's
            # internal feature representation, plus noise.
            feature_similarity = np.dot(feature, internal_features)
            noisy_similarity = feature_similarity + np.random.normal(0, noise_scale)

            observation[slot*2] = noisy_distance / self.sensor_range
            observation[slot*2 + 1] = noisy_similarity
            slot += 1

        return observation

    def reset(self):
        """Reset the environment and return the initial observation.

        The map and correlation matrix are kept; only the state is re-sampled.
        """
        self.state = self._get_initial_state()
        return self._get_observation()

    def render(self, mode='human', belief_particles=None):
        """Render the environment with optional belief particles.

        Args:
            mode: Only ``'human'`` draws anything; other values do nothing.
            belief_particles: Optional iterable of particles whose first two
                entries are x and y. Drawn as points, plus a density heatmap
                when there are more than 10 particles.
        """
        if mode != 'human':
            return

        plt.figure(figsize=(10, 10))

        # Plot the map landmarks in a single call (one legend entry)
        landmark_x = [position[0] for position in self.map]
        landmark_y = [position[1] for position in self.map]
        plt.scatter(landmark_x, landmark_y, c='blue', marker='^', s=100,
                    label='Landmark')

        # Plot the robot
        robot_x, robot_y = self.state[0], self.state[1]
        robot_theta = self.state[2]

        plt.scatter(robot_x, robot_y, c='red', marker='o', s=200, label='Robot')

        # Plot robot orientation
        arrow_length = 0.5
        dx = arrow_length * np.cos(robot_theta)
        dy = arrow_length * np.sin(robot_theta)
        plt.arrow(robot_x, robot_y, dx, dy, head_width=0.3, head_length=0.3,
                  fc='red', ec='red')

        # Plot sensor range
        circle = plt.Circle((robot_x, robot_y), self.sensor_range, fill=False,
                            linestyle='--', color='green')
        plt.gca().add_patch(circle)

        # Draw the two edges of the 90-degree field of view
        fov_angle = np.pi / 2
        line_length = self.sensor_range
        for edge in (robot_theta - fov_angle / 2, robot_theta + fov_angle / 2):
            plt.plot([robot_x, robot_x + line_length * np.cos(edge)],
                     [robot_y, robot_y + line_length * np.sin(edge)],
                     'g--')

        # If belief particles are provided, plot them
        if belief_particles is not None:
            particles_x = [p[0] for p in belief_particles]
            particles_y = [p[1] for p in belief_particles]

            plt.scatter(particles_x, particles_y, c='orange', marker='.',
                        s=10, alpha=0.5, label='Belief')

            # Density heatmap is a best-effort overlay: report failures
            # instead of silently swallowing every exception.
            if len(particles_x) > 10:
                try:
                    heatmap, xedges, yedges = np.histogram2d(
                        particles_x, particles_y, bins=20,
                        range=[[0, self.map_size], [0, self.map_size]])

                    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
                    plt.imshow(heatmap.T, extent=extent, origin='lower',
                               cmap='YlOrRd', alpha=0.3, interpolation='bilinear')
                except Exception as exc:
                    import warnings
                    warnings.warn(
                        f"Skipping belief heatmap: {exc}", stacklevel=2)

        # Set plot limits and labels
        plt.xlim(0, self.map_size)
        plt.ylim(0, self.map_size)
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.title('Kidnapped Robot Problem')

        # Show legend without duplicate entries
        handles, labels = plt.gca().get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(by_label.values(), by_label.keys())

        plt.grid(True)
        plt.show()
