import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle, Circle
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import os

class LightDark10DVisualizer:
    """
    Visualization tool for the Light Dark 10D POMDP environment.

    Renders a single 2x2 figure of 2D slices through the 5D position space:
    - light-intensity maps for four representative dimension pairs
    - the goal location and the starting region
    - an illustrative multi-modal belief (perceptual aliasing) in the first slice

    Each slice varies two position dimensions over ``[0, map_size]`` and holds
    the remaining three position coordinates at 0.
    """

    # Light level reported wherever no light region contributes.
    _AMBIENT_LIGHT = 0.05
    # Map size that the hard-coded belief-particle coordinates are expressed for
    # (the constructor default); they are rescaled for other map sizes.
    _REFERENCE_MAP_SIZE = 10.0

    def __init__(self, map_size=10, noise_level=0.5):
        """
        Args:
            map_size: extent of every position dimension; slices span [0, map_size].
            noise_level: stored on the instance; not used by the plotting code
                in this class.
        """
        self.map_size = map_size
        self.noise_level = noise_level

        # List of (center, radius, intensity, label) tuples.
        self.light_regions = self._create_light_regions()

        # Goal position: the same coordinate (0.8 * map_size) in all 5 dimensions.
        self.goal = np.array([self.map_size * 0.8] * 5)

        self.setup_color_schemes()

    def _create_light_regions(self):
        """Build the light regions as (center, radius, intensity, label) tuples.

        A center coordinate of exactly 0 marks that dimension as irrelevant to
        the region (see ``_get_light_level``).
        """
        light_regions = []

        # Primary light corridor along first dimension
        light_regions.append((
            np.array([self.map_size/2, 0, 0, 0, 0]),  # center
            2.0,  # radius
            0.9,  # intensity
            "Primary Corridor"
        ))

        # Mirror corridor in another dimension (creates ambiguity)
        light_regions.append((
            np.array([0, self.map_size/2, 0, 0, 0]),  # center
            2.0,  # radius
            0.9,  # intensity
            "Mirror Corridor"
        ))

        # "Z-path": three waypoints placed along the diagonal of dimensions 3-4
        light_regions.append((
            np.array([0, 0, self.map_size/4, self.map_size/4, 0]),
            1.5, 0.8, "Z-Path Start"
        ))

        light_regions.append((
            np.array([0, 0, self.map_size/2, self.map_size/2, 0]),
            1.5, 0.8, "Z-Path Middle"
        ))

        light_regions.append((
            np.array([0, 0, 3*self.map_size/4, 3*self.map_size/4, 0]),
            1.5, 0.8, "Z-Path End"
        ))

        # Small bright spot near goal in 5th dimension
        light_regions.append((
            np.array([0, 0, 0, 0, self.map_size*0.7]),
            1.0, 1.0, "Goal Beacon"
        ))

        # Additional dim light in ambiguous region
        light_regions.append((
            np.array([self.map_size/3, self.map_size/3, self.map_size/3, 0, 0]),
            3.0, 0.4, "Ambiguous Region"
        ))

        return light_regions

    def setup_color_schemes(self):
        """Create the light-intensity colormap.

        Note: also resets the global matplotlib style and seaborn palette.
        """
        # Light intensity colormap - from dark to bright
        self.light_cmap = LinearSegmentedColormap.from_list(
            'light_map',
            ['#000814', '#001D3D', '#003566', '#0077B6', '#00B4D8', '#90E0EF', '#CAF0F8', '#FFFFFF'],
            N=256
        )

        plt.style.use('default')
        sns.set_palette("husl")

    def _get_light_level(self, position):
        """Return the light level at one position or at a batch of positions.

        For each light region, only coordinates where the region's center is
        non-zero are compared (a center coordinate of exactly 0 means "any
        value"). The Euclidean distance over those coordinates is divided by
        sqrt(number of such coordinates). Inside the region's radius the
        contribution is ``intensity * (1 - (d / radius)**2)``. The result is
        the maximum contribution over all regions, never below
        ``_AMBIENT_LIGHT``.

        Args:
            position: array-like whose last axis holds the 5 position
                coordinates; shape ``(5,)`` or ``(..., 5)``.

        Returns:
            A float for a single position, otherwise an array of shape
            ``position.shape[:-1]``.
        """
        positions = np.asarray(position, dtype=float)
        light = np.full(positions.shape[:-1], self._AMBIENT_LIGHT)

        for center, radius, intensity, _ in self.light_regions:
            relevant = center != 0
            n_relevant = np.count_nonzero(relevant)
            if n_relevant:
                offset = positions[..., relevant] - center[relevant]
                distance = np.linalg.norm(offset, axis=-1) / np.sqrt(n_relevant)
            else:
                # An all-zero center is compared on every coordinate.
                distance = np.linalg.norm(positions - center, axis=-1)

            contribution = intensity * (1.0 - (distance / radius) ** 2)
            light = np.where(distance < radius, np.maximum(light, contribution), light)

        return float(light) if positions.ndim == 1 else light

    def _compute_light_map(self, dim1, dim2, x_grid, y_grid):
        """Return light levels on the (dim1, dim2) plane, other coordinates at 0.

        The result has shape ``(len(y_grid), len(x_grid))``. Entry ``[j, i]`` is
        the light level at ``x_grid[i]`` along ``dim1`` and ``y_grid[j]`` along
        ``dim2``, which is the row/column layout imshow/contour expect with
        ``origin='lower'``.
        """
        xx, yy = np.meshgrid(x_grid, y_grid)  # xx[j, i] = x_grid[i], yy[j, i] = y_grid[j]
        positions = np.zeros(xx.shape + (5,))
        positions[..., dim1] = xx
        positions[..., dim2] = yy
        return self._get_light_level(positions)

    @staticmethod
    def _projected_footprint_radius(center, radius, dim1, dim2):
        """Radius of the lit disk a region leaves on the (dim1, dim2) plane.

        The plane holds every other coordinate at 0. A region lights point
        (dx, dy) away from its projected center when
        ``sqrt(dx**2 + dy**2 + off**2) / sqrt(k) < radius``. Here k is the
        number of non-zero center coordinates and ``off**2`` is the sum of
        squares of the non-zero center coordinates outside the plane. That
        makes the lit disk's radius ``sqrt(k * radius**2 - off**2)``.

        Returns None when the region is not a disk in this plane: either
        projected coordinate of the center is 0, or the region does not reach
        the plane at all.
        """
        relevant = center != 0
        if not (relevant[dim1] and relevant[dim2]):
            return None
        in_plane = np.zeros_like(relevant)
        in_plane[[dim1, dim2]] = True
        off_plane_sq = np.sum(center[relevant & ~in_plane] ** 2)
        footprint_sq = np.count_nonzero(relevant) * radius ** 2 - off_plane_sq
        if footprint_sq <= 0:
            return None
        return float(np.sqrt(footprint_sq))

    def create_domain_visualization(self, save_path=None, show=True):
        """Create, save and optionally show the 2x2 domain figure.

        Args:
            save_path: output image path. By default this is
                ``lightdark_10d_domain.png`` in this script's directory, or in
                the current working directory when ``__file__`` is undefined
                (e.g. in a notebook).
            show: call ``plt.show()`` after saving.

        Returns:
            The matplotlib Figure.
        """
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))

        # Representative projections, in row-major subplot order.
        projections = [
            (0, 1, "Dimensions 1-2: Primary & Mirror Corridors\n(Perceptual Aliasing)"),
            (2, 3, "Dimensions 3-4: Z-Shaped Navigation Path\n(Sequential Waypoints)"),
            (0, 4, "Dimensions 1-5: Primary Corridor & Goal Beacon\n(Start to Goal)"),
            (1, 3, "Dimensions 2-4: Cross-Correlation Example\n(Complex Dependencies)"),
        ]
        for (dim1, dim2, title), ax in zip(projections, axes.flat):
            self._plot_representative_projection(ax, dim1, dim2, title)

        fig.tight_layout()

        if save_path is None:
            save_path = os.path.join(self._default_output_dir(), 'lightdark_10d_domain.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")
        if show:
            plt.show()
        return fig

    @staticmethod
    def _default_output_dir():
        """Directory of this script, or the cwd when ``__file__`` is undefined."""
        try:
            return os.path.dirname(os.path.abspath(__file__))
        except NameError:
            return os.getcwd()

    def _plot_representative_projection(self, ax, dim1, dim2, title, resolution=80):
        """Draw one 2D slice of the 5D position space onto ``ax``.

        Dimension ``dim1`` runs along x and ``dim2`` along y, both over
        ``[0, map_size]``; all other coordinates are held at 0.

        Args:
            ax: matplotlib Axes to draw into.
            dim1, dim2: zero-based position dimensions for the x and y axes.
            title: subplot title.
            resolution: grid samples per axis for the light map.
        """
        x_grid = np.linspace(0, self.map_size, resolution)
        y_grid = np.linspace(0, self.map_size, resolution)
        light_map = self._compute_light_map(dim1, dim2, x_grid, y_grid)

        # Light-intensity background plus iso-light contour lines.
        im = ax.imshow(light_map, extent=[0, self.map_size, 0, self.map_size],
                       origin='lower', cmap=self.light_cmap, alpha=0.9)
        X, Y = np.meshgrid(x_grid, y_grid)
        ax.contour(X, Y, light_map, levels=[0.1, 0.3, 0.5, 0.7, 0.9],
                   colors='white', alpha=0.8, linewidths=1.5)

        # Goal position with prominent marker
        ax.scatter(self.goal[dim1], self.goal[dim2], s=400, c='lime', marker='*',
                   edgecolors='darkgreen', linewidth=3, label='Goal', zorder=10)

        # Starting region (dark area)
        ax.add_patch(Circle((self.map_size/10, self.map_size/10), self.map_size/10,
                            fill=False, edgecolor='red', linewidth=3, linestyle='--',
                            label='Start Region', alpha=0.8))

        is_aliasing_plane = dim1 == 0 and dim2 == 1
        if is_aliasing_plane:
            self._plot_belief_particles(ax)

        self._outline_light_regions(ax, dim1, dim2)

        # Formatting
        ax.set_xlim(0, self.map_size)
        ax.set_ylim(0, self.map_size)
        ax.set_xlabel(f'Position Dimension {dim1+1}', fontsize=12, fontweight='bold')
        ax.set_ylabel(f'Position Dimension {dim2+1}', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=11, fontweight='bold', pad=10)
        ax.grid(True, alpha=0.3)

        # Legend only on the first subplot (the one carrying the belief particles).
        if is_aliasing_plane:
            ax.legend(loc='upper right', fontsize=10, fancybox=True, framealpha=0.9)

        cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Light Intensity\n(Lower = More Noise)', fontsize=10, fontweight='bold')
        cbar.ax.tick_params(labelsize=9)

    def _plot_belief_particles(self, ax):
        """Scatter an illustrative three-mode belief and the true agent position."""
        # Local legacy generator: same draws as the former np.random.seed(42) +
        # np.random.normal calls, without resetting the caller's global RNG.
        rng = np.random.RandomState(42)
        scale = self.map_size / self._REFERENCE_MAP_SIZE
        # Modes: near start, plus two mirror images caused by perceptual aliasing.
        mode_centers = [(1.5, 1.5), (1.5, 5.0), (5.0, 1.5)]
        for index, (cx, cy) in enumerate(mode_centers):
            particles = rng.normal([cx * scale, cy * scale], 0.3 * scale, (50, 2))
            particles = np.clip(particles, 0, self.map_size)
            ax.scatter(particles[:, 0], particles[:, 1], s=8, c='orange', alpha=0.6,
                       label='Belief Particles' if index == 0 else '_nolegend_')

        # True agent position
        ax.scatter(2.5 * scale, 2.5 * scale, s=150, c='red', marker='o',
                   edgecolors='darkred', linewidth=2, label='True Agent', zorder=9)

    def _outline_light_regions(self, ax, dim1, dim2):
        """Outline each light region that forms a disk in this plane; label bright ones."""
        for center, radius, intensity, label in self.light_regions:
            footprint = self._projected_footprint_radius(center, radius, dim1, dim2)
            if footprint is None:
                continue
            ax.add_patch(Circle((center[dim1], center[dim2]), footprint,
                                fill=False, edgecolor='yellow', linewidth=2, alpha=0.8))
            if intensity > 0.7:
                ax.annotate(label, (center[dim1], center[dim2]),
                            fontsize=9, ha='center', va='center', color='white',
                            bbox=dict(boxstyle="round,pad=0.2", facecolor='black', alpha=0.7))

# Script entry point: build a visualizer with its default map size and noise
# level, then render, save and display the domain figure.
if __name__ == "__main__":
    LightDark10DVisualizer().create_domain_visualization()
