import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle, FancyBboxPatch, Wedge
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.patches as patches
from cairosvg import svg2png
from PIL import Image
import io
import os

def load_svg_as_image(svg_path, size=60):
    """Rasterize an SVG file into a PIL image that matplotlib can display.

    Args:
        svg_path: Path of the SVG file. It is joined onto the directory that
            contains this script, so relative paths are resolved from there.
        size: Width and height, in pixels, requested for the rendered PNG.

    Returns:
        The rendered PIL image, or ``None`` if the file does not exist or the
        conversion fails. A message is printed in either failure case so the
        caller can silently fall back to a placeholder.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    full_path = os.path.join(script_dir, svg_path)

    if not os.path.exists(full_path):
        print(f"Warning: {full_path} not found. Using placeholder.")
        return None

    try:
        # Open the SVG in a context manager so the handle is closed as soon
        # as the conversion finishes (or fails) instead of being leaked.
        with open(full_path, 'rb') as svg_file:
            png_data = svg2png(
                file_obj=svg_file,
                output_width=size,
                output_height=size,
                background_color='transparent'
            )
        # Wrap the PNG bytes in an in-memory buffer for PIL.
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        # Best-effort loader: any conversion/decoding problem is reported and
        # turned into None so drawing can continue with a fallback marker.
        print(f"Error loading SVG: {e}")
        return None

def draw_agent(ax, x, y, svg_filename, zoom=0.6):
    """Place an agent icon at (x, y) on ``ax``.

    The icon is loaded from ``svg_filename`` via ``load_svg_as_image``. When
    no image is available, a plain scatter marker is drawn instead: a blue
    circle if the filename contains 'drone', otherwise a red square.
    """
    icon = load_svg_as_image(svg_filename)

    if icon is None:
        # No icon could be loaded: draw the placeholder marker and stop.
        is_drone = 'drone' in svg_filename
        ax.scatter(
            x, y,
            c='blue' if is_drone else 'red',
            marker='o' if is_drone else 's',
            s=300 if is_drone else 250,
            edgecolors='black',
            linewidth=2,
        )
        return

    # Wrap the rasterized icon so it can be anchored at a data coordinate.
    offset_img = OffsetImage(icon, zoom=zoom)
    offset_img.image.axes = ax

    # Centre the frameless box on (x, y).
    ax.add_artist(AnnotationBbox(
        offset_img,
        (x, y),
        frameon=False,
        box_alignment=(0.5, 0.5),
    ))

def _draw_visibility_zones(ax, zones):
    """Draw each zone as a dashed circle with its label and description."""
    for zone in zones:
        cx, cy = zone['center']
        ax.add_patch(Circle(
            zone['center'], zone['radius'],
            facecolor=zone['color'],
            edgecolor=zone['edge_color'],
            linewidth=2.5,
            linestyle='--',
        ))
        # Label slightly above the centre, description slightly below it.
        ax.text(cx, cy + 0.2, zone['label'],
                ha='center', va='center', fontsize=10, fontweight='bold')
        ax.text(cx, cy - 0.3, zone['description'],
                ha='center', va='center', fontsize=8, style='italic')


def _draw_targets(ax, positions, velocities, arrow_length=0.8):
    """Draw a car icon and a fixed-length heading arrow for every target."""
    car_icon = os.path.join('markers', 'car.svg')
    for (px, py), (vx, vy) in zip(positions, velocities):
        draw_agent(ax, px, py, car_icon, zoom=0.6)

        # Normalize the velocity so every arrow has the same short length.
        speed = np.hypot(vx, vy)
        if speed == 0:
            # A stationary target has no heading; skip the arrow rather than
            # dividing by zero.
            continue
        ax.arrow(px, py,
                 vx / speed * arrow_length, vy / speed * arrow_length,
                 head_width=0.15, head_length=0.1,
                 fc='darkred', ec='darkred',
                 linestyle='--', linewidth=1.5, alpha=0.8)


def create_mtt_domain_visualization():
    """Create a representative image of the Multiple Target Tracking domain.

    Draws four circular visibility zones, a goal star, the drone with its
    field-of-view wedge pointed at the goal, and four cars with heading
    arrows. The figure is saved as 'multi_target_tracking_domain.png' in the
    directory containing this script and then displayed with ``plt.show()``.
    """
    # Environment parameters
    map_size = 10

    # Fixed figure geometry
    dpi = 100
    fig_width = 8
    fig_height = 8

    # Pass the DPI to the figure itself instead of writing it into
    # plt.rcParams, so calling this function does not change the DPI of
    # unrelated figures created elsewhere in the process.
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)

    # Axes with a fixed position inside the figure
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    ax.set_xlim(-0.5, map_size + 0.5)
    ax.set_ylim(-0.5, map_size + 0.5)
    ax.set_aspect('equal')

    # Background covering the whole visible area
    ax.add_patch(patches.Rectangle(
        (-0.5, -0.5), map_size + 1, map_size + 1,
        facecolor='#f5f5f5'
    ))

    # Visibility zones with different characteristics
    zones = [
        {
            'center': [2.5, 2.5],
            'radius': 2.5,  # Larger zone
            'color': (0.2, 0.8, 0.2, 0.3),  # Green
            'edge_color': 'darkgreen',
            'label': 'Clear View Zone',
            'description': 'Targets fully visible'
        },
        {
            'center': [7.5, 2.5],
            'radius': 1.8,  # Smaller zone
            'color': (1.0, 0.6, 0.2, 0.4),  # Orange
            'edge_color': 'darkorange',
            'label': 'Limited Sensing',
            'description': '30% visibility'
        },
        {
            'center': [2.5, 7.5],
            'radius': 2.2,  # Medium zone
            'color': (1.0, 0.9, 0.2, 0.3),  # Yellow
            'edge_color': 'goldenrod',
            'label': 'Confusion Zone',
            'description': 'Targets may swap IDs'
        },
        {
            'center': [7.5, 7.5],
            'radius': 2.0,  # Medium zone
            'color': (0.9, 0.2, 0.2, 0.5),  # Red
            'edge_color': 'darkred',
            'label': 'Blind Spot',
            'description': 'No visibility'
        }
    ]
    _draw_visibility_zones(ax, zones)

    # Goal position (placed outside the zone circles)
    goal_pos = [9.0, 9.0]
    ax.scatter(goal_pos[0], goal_pos[1], marker='*', s=800,
               c='gold', edgecolors='black', linewidth=2.5, zorder=5)

    # Agent (drone), offset from the zone centre so it does not cover the text
    agent_pos = [1.2, 1.2]
    draw_agent(ax, agent_pos[0], agent_pos[1],
               os.path.join('markers', 'drone.svg'), zoom=0.8)

    # Drone field of view: a wedge centred on the direction towards the goal
    fov_angle = 60  # degrees
    fov_range = 3.5  # range of vision
    base_angle = np.degrees(np.arctan2(goal_pos[1] - agent_pos[1],
                                       goal_pos[0] - agent_pos[0]))
    ax.add_patch(Wedge(agent_pos, fov_range,
                       base_angle - fov_angle / 2,
                       base_angle + fov_angle / 2,
                       facecolor='cyan', alpha=0.2,
                       edgecolor='darkblue', linewidth=2))

    # Targets (cars), one per zone, kept away from the zone centre text
    target_positions = [
        [6.2, 8.8],  # In blind spot
        [1.0, 6.5],  # In confusion zone
        [4.0, 1.5],  # In clear zone, away from the drone
        [8.5, 3.5]   # In limited sensing
    ]

    # Movement direction of each car (same order as target_positions)
    target_velocities = [
        [-0.4, -0.3],  # Moving left and down
        [0.5, -0.2],   # Moving right and slightly down
        [0.3, 0.4],    # Moving right and up
        [-0.5, 0.3]    # Moving left and up
    ]
    _draw_targets(ax, target_positions, target_velocities)

    # Remove ticks but keep a solid border around the map
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_edgecolor('black')
        spine.set_linewidth(2)

    # Save the figure next to this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.join(script_dir, 'multi_target_tracking_domain.png')

    # Re-assert the size right before saving so the output dimensions are
    # fig_width x fig_height inches regardless of any earlier resize.
    fig.set_size_inches(fig_width, fig_height)
    fig.savefig(output_path, dpi=dpi, bbox_inches=None, transparent=False)
    print(f"Figure saved to: {output_path}")

    plt.show()

# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    create_mtt_domain_visualization()