import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Circle, Wedge
import matplotlib.lines as mlines
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import os

class KidnappedRobotVisualizer:
    """
    Static visualization of the Kidnapped Robot POMDP environment.

    Draws a continuous square map with typed landmark buildings, the robot's
    true pose and heading, its sensor range and field of view, and an
    illustrative multi-modal particle belief. PNG sprites from a ``markers``
    folder next to this file are used when they can be loaded; otherwise
    plain coloured shapes are drawn instead.
    """

    # Building type -> (key into ``self.markers``, fallback colour used for the
    # rectangle drawn when that sprite is unavailable).
    _LANDMARK_STYLES = {
        'house': ('house_red', '#FF6B6B'),
        'shop': ('shop_blue', '#4ECDC4'),
        'warehouse': ('warehouse_green', '#96CEB4'),
    }
    # Fallback colour for building types missing from _LANDMARK_STYLES.
    _DEFAULT_FALLBACK_COLOR = '#DDA0DD'

    # Marker key -> PNG file name inside the ``markers`` folder.
    _MARKER_FILES = {
        'robot': 'robot.png',
        'house_red': 'house_red.png',
        'shop_blue': 'shop_blue.png',
        'warehouse_green': 'warehouse_green.png',
    }

    def __init__(self, map_size=20, sensor_range=5):
        """
        Set up the scene state and load sprite images.

        Args:
            map_size: Side length of the square map, in world units.
            sensor_range: Radius of the robot's sensor circle, in world units.
        """
        self.map_size = map_size
        self.sensor_range = sensor_range

        # True robot pose (the robot itself does not know it in the POMDP).
        self.robot_x, self.robot_y = 4.0, 4.0
        self.robot_theta = np.pi / 4  # heading in radians (45 degrees)

        self.landmark_patterns = self._create_actual_landmark_patterns()
        self.markers = self._load_markers()

    @staticmethod
    def _script_dir():
        """Return the directory that contains this source file."""
        return os.path.dirname(os.path.abspath(__file__))

    def _load_markers(self):
        """
        Load marker images from the ``markers`` folder next to this file.

        Returns:
            Dict mapping every key of ``_MARKER_FILES`` to the loaded image
            array, or to None when the file is missing or unreadable.
        """
        markers_dir = os.path.join(self._script_dir(), "markers")
        markers = {}
        for marker_name, filename in self._MARKER_FILES.items():
            filepath = os.path.join(markers_dir, filename)
            markers[marker_name] = None
            if not os.path.exists(filepath):
                print(f"Marker file not found: {filepath}")
                continue
            try:
                markers[marker_name] = mpimg.imread(filepath)
                print(f"Loaded {marker_name}: {filename}")
            except Exception as e:
                # Best effort: an unreadable sprite falls back to plain shapes.
                print(f"Could not load {filename}: {e}")
        return markers

    def _create_actual_landmark_patterns(self):
        """
        Return the landmark layout used for the figure.

        Each entry is a dict with ``'landmarks'`` (list of ``(x, y)``
        positions in map coordinates) and ``'type'`` (building type, a key of
        ``_LANDMARK_STYLES``). Coordinates are hard-coded for a 20x20 map.
        """
        patterns = [
            # Original patterns from implementation (kept for authenticity)
            {'landmarks': [(2, 2), (3, 4), (5, 3)], 'type': 'house'},
            {'landmarks': [(2, 2), (4, 3), (3, 5)], 'type': 'house'},
            {'landmarks': [(7, 2), (8, 4), (10, 3)], 'type': 'shop'},
            {'landmarks': [(7, 2), (9, 3), (8, 5)], 'type': 'shop'},
            {'landmarks': [(12, 12), (13, 14), (15, 13)], 'type': 'warehouse'},
            {'landmarks': [(12, 12), (14, 13), (13, 15)], 'type': 'warehouse'},
            {'landmarks': [(17, 17), (18, 19), (19, 18)], 'type': 'house'},

            # Additional mixed buildings for better coverage
            # Top area
            {'landmarks': [(1, 18), (3, 19), (5, 18)], 'type': 'shop'},
            {'landmarks': [(8, 17), (10, 19), (12, 18)], 'type': 'house'},
            {'landmarks': [(15, 19), (17, 18), (19, 19)], 'type': 'warehouse'},

            # Left area
            {'landmarks': [(0.5, 15), (1, 13), (2.5, 14)], 'type': 'warehouse'},
            {'landmarks': [(1, 10), (2, 8), (0.5, 6)], 'type': 'shop'},
            {'landmarks': [(2, 5), (1, 3), (3, 1)], 'type': 'house'},

            # Right area
            {'landmarks': [(18, 15), (19, 13), (17.5, 11)], 'type': 'house'},
            {'landmarks': [(19, 8), (18, 6), (19.5, 4)], 'type': 'warehouse'},
            {'landmarks': [(17, 2), (18.5, 1), (19, 0.5)], 'type': 'shop'},

            # Center area - more mixed
            {'landmarks': [(9, 15), (11, 16), (10, 14)], 'type': 'house'},
            {'landmarks': [(6, 12), (8, 13), (7, 11)], 'type': 'warehouse'},
            {'landmarks': [(14, 10), (16, 11), (15, 9)], 'type': 'shop'},
            {'landmarks': [(5, 8), (7, 9), (6, 7)], 'type': 'house'},
            {'landmarks': [(11, 6), (13, 7), (12, 5)], 'type': 'warehouse'},
            {'landmarks': [(16, 4), (18, 5), (17, 3)], 'type': 'shop'},

            # Bottom area
            {'landmarks': [(4, 1), (6, 2), (5, 0.5)], 'type': 'warehouse'},
            {'landmarks': [(9, 0.5), (11, 1), (10, 2)], 'type': 'house'},
            {'landmarks': [(14, 1), (16, 2), (15, 0.5)], 'type': 'shop'},

            # Additional scattered buildings
            {'landmarks': [(0.5, 8.5), (1.5, 9.5)], 'type': 'house'},
            {'landmarks': [(3.5, 6.5), (4.5, 7.5)], 'type': 'shop'},
            {'landmarks': [(6.5, 4.5), (7.5, 5.5)], 'type': 'warehouse'},
            {'landmarks': [(9.5, 8.5), (10.5, 9.5)], 'type': 'shop'},
            {'landmarks': [(13.5, 6.5), (14.5, 7.5)], 'type': 'house'},
            {'landmarks': [(16.5, 8.5), (17.5, 9.5)], 'type': 'warehouse'},
            {'landmarks': [(2.5, 11.5), (3.5, 12.5)], 'type': 'warehouse'},
            {'landmarks': [(8.5, 12.5), (9.5, 13.5)], 'type': 'shop'},
            {'landmarks': [(15.5, 14.5), (16.5, 15.5)], 'type': 'house'},

            # Corner buildings
            {'landmarks': [(0.5, 0.5)], 'type': 'house'},
            {'landmarks': [(19.5, 0.5)], 'type': 'shop'},
            {'landmarks': [(0.5, 19.5)], 'type': 'warehouse'},
            {'landmarks': [(19.5, 19.5)], 'type': 'house'},

            # Mid-edge buildings
            {'landmarks': [(10, 0.2)], 'type': 'shop'},
            {'landmarks': [(19.8, 10)], 'type': 'warehouse'},
            {'landmarks': [(10, 19.8)], 'type': 'house'},
            {'landmarks': [(0.2, 10)], 'type': 'shop'},
        ]
        return patterns

    def create_domain_visualization(self, save_path=None, show=True):
        """
        Render the full scene, save it as a PNG, and optionally display it.

        Args:
            save_path: Output file path. Defaults to
                ``kidnapped_robot_domain.png`` next to this source file.
            show: If True (default), call ``plt.show()`` after saving.

        Returns:
            The ``(fig, ax)`` pair the scene was drawn on.
        """
        fig, ax = plt.subplots(1, 1, figsize=(16, 16))

        # Draw the scene layers; relative stacking is set by per-layer zorders.
        self._create_clean_background(ax)
        self._plot_landmarks_with_sprites(ax)
        self._plot_robot_with_sprite(ax)
        self._plot_clean_sensor_coverage(ax)
        self._plot_belief_particles_with_sprites(ax)

        # Fix the view to the map. imshow calls above may have changed the
        # autoscaled limits. Title and legend are intentionally omitted.
        ax.set_xlim(0, self.map_size)
        ax.set_ylim(0, self.map_size)
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

        fig.tight_layout()

        if save_path is None:
            save_path = os.path.join(self._script_dir(), 'kidnapped_robot_domain.png')
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
        if show:
            plt.show()
        return fig, ax

    def _create_clean_background(self, ax):
        """Fill the map area with a flat background colour and a light border (no grid)."""
        background = patches.Rectangle((0, 0), self.map_size, self.map_size,
                                       facecolor='#F8F8FF', edgecolor='none')
        ax.add_patch(background)

        border = patches.Rectangle((0, 0), self.map_size, self.map_size,
                                   fill=False, edgecolor='#D3D3D3', linewidth=2)
        ax.add_patch(border)

    def _landmark_style(self, building_type):
        """
        Return ``(sprite, fallback_color)`` for a building type.

        ``sprite`` is the loaded image for that type, or None when it is
        unavailable or the type is unknown. ``fallback_color`` is the
        type-specific colour to use for the rectangle drawn when ``sprite``
        is None (``_DEFAULT_FALLBACK_COLOR`` for unknown types).
        """
        marker_key, fallback_color = self._LANDMARK_STYLES.get(
            building_type, (None, self._DEFAULT_FALLBACK_COLOR))
        sprite = self.markers.get(marker_key) if marker_key is not None else None
        return sprite, fallback_color

    def _plot_landmarks_with_sprites(self, ax):
        """Draw every landmark as its type's sprite, or a coloured square if the sprite is missing."""
        for pattern_data in self.landmark_patterns:
            sprite, fallback_color = self._landmark_style(pattern_data['type'])
            for x, y in pattern_data['landmarks']:
                if sprite is not None:
                    ax.imshow(sprite, extent=[x-0.4, x+0.4, y-0.4, y+0.4], zorder=10)
                else:
                    # Same zorder as the sprite so layering matches either way.
                    rect = patches.Rectangle((x-0.3, y-0.3), 0.6, 0.6,
                                             facecolor=fallback_color,
                                             edgecolor='black', linewidth=2,
                                             zorder=10)
                    ax.add_patch(rect)

    def _plot_robot_with_sprite(self, ax):
        """Draw the robot at its true pose (sprite or fallback) plus a heading arrow."""
        robot_x, robot_y = self.robot_x, self.robot_y

        if self.markers.get('robot') is not None:
            ax.imshow(self.markers['robot'],
                      extent=[robot_x-0.4, robot_x+0.4, robot_y-0.4, robot_y+0.4],
                      zorder=15)
        else:
            # Fallback body at the sprite's zorder, so the sensor wedge and
            # particles do not paint over it.
            robot_rect = patches.Rectangle((robot_x-0.3, robot_y-0.3), 0.6, 0.6,
                                           facecolor='#FF4500', edgecolor='#8B0000',
                                           linewidth=3, zorder=15)
            ax.add_patch(robot_rect)

            # NOTE(review): emoji rendering depends on the installed fonts.
            ax.text(robot_x, robot_y, '🤖', fontsize=20,
                    ha='center', va='center', zorder=16)

        # Heading arrow, drawn regardless of whether the sprite loaded.
        ax.arrow(robot_x, robot_y,
                 0.6 * np.cos(self.robot_theta), 0.6 * np.sin(self.robot_theta),
                 head_width=0.12, head_length=0.12, fc='red', ec='darkred',
                 linewidth=2, zorder=20)

    def _plot_clean_sensor_coverage(self, ax):
        """Draw the sensor-range circle and a 90-degree field-of-view wedge centred on the heading."""
        robot_center = (self.robot_x, self.robot_y)

        sensor_circle = Circle(robot_center, self.sensor_range,
                               fill=False, linestyle='--', color='lime',
                               linewidth=3, alpha=0.8, zorder=5)
        ax.add_patch(sensor_circle)

        fov_angle = np.pi / 2
        fov_start = self.robot_theta - fov_angle / 2
        fov_end = self.robot_theta + fov_angle / 2

        # Wedge expects its angles in degrees.
        fov_wedge = Wedge(robot_center, self.sensor_range,
                          np.degrees(fov_start), np.degrees(fov_end),
                          alpha=0.2, color='lime', zorder=4)
        ax.add_patch(fov_wedge)

    def _generate_belief_particles(self, seed=42):
        """
        Sample the illustrative multi-modal belief over the robot's position.

        Draws from a private ``np.random.RandomState(seed)``, so the global
        NumPy RNG is left untouched. The samples equal those obtained by
        calling ``np.random.seed(seed)`` and drawing from the global generator.

        Returns:
            List of ``(color, particles)`` pairs, one per belief mode, where
            ``particles`` is a ``(count, 2)`` array of x/y positions clipped
            to ``[0, map_size]``.
        """
        rng = np.random.RandomState(seed)
        # (center, colour, particle count). The centres sit near look-alike
        # landmark patterns. The counts differ to suggest unequal probabilities.
        modes = [
            ((3.0, 3.5), 'orange', 60),    # near red (house) pattern
            ((8.0, 3.5), 'magenta', 40),   # near blue (shop) pattern
            ((13.0, 13.0), 'cyan', 30),    # near green (warehouse) pattern
        ]
        groups = []
        for center, color, count in modes:
            particles = rng.normal(center, [1.0, 1.0], (count, 2))
            np.clip(particles, 0, self.map_size, out=particles)
            groups.append((color, particles))
        return groups

    def _plot_belief_particles_with_sprites(self, ax):
        """Plot belief particles as small semi-transparent circles, one colour per mode."""
        for color, particles in self._generate_belief_particles():
            ax.scatter(particles[:, 0], particles[:, 1],
                       c=color, s=25, alpha=0.6,
                       edgecolors='black', linewidth=0.5, zorder=8)

# Script entry point: build a visualizer with default settings, then render,
# save and show the domain figure.
if __name__ == "__main__":
    KidnappedRobotVisualizer().create_domain_visualization()
