#!/usr/bin/env python3
"""
Annotation Preview Script
Creates a 3x4 grid visualization of random images with their bounding box annotations.
"""

import os
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np

# Configuration: dataset root plus the images/labels folders of each split.
DATA_DIR = "/home/ubuntu/roadsight/data"
TRAIN_IMAGES_DIR, TRAIN_LABELS_DIR = (
    os.path.join(DATA_DIR, sub) for sub in ("train/images", "train/labels")
)
VAL_IMAGES_DIR, VAL_LABELS_DIR = (
    os.path.join(DATA_DIR, sub) for sub in ("val/images", "val/labels")
)
TEST_IMAGES_DIR, TEST_LABELS_DIR = (
    os.path.join(DATA_DIR, sub) for sub in ("test/images", "test/labels")
)

# Class names and box colors, indexed by the class id found in label files.
CLASS_NAMES = ["roundabout", "intersection"]
CLASS_COLORS = ['red', 'blue']

def load_image_and_annotations(image_name, dataset_type='train'):
    """Load an image and its YOLO-format annotations from one dataset split.

    Args:
        image_name: File name of the image inside the split's images directory.
            The label file is looked up in the split's labels directory under
            the same base name with a ``.txt`` extension.
        dataset_type: One of ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        ``(image, annotations)``. ``image`` is a fully loaded PIL image that no
        longer holds the file open. ``annotations`` is a list of dicts with keys
        ``class_id``, ``class_name``, ``bbox`` (``(x_min, y_min, width, height)``
        in pixels) and ``color``. A missing label file yields an empty list.
        If the image file does not exist, ``(None, None)`` is returned.

    Raises:
        ValueError: If ``dataset_type`` is unknown, or a label line has fewer
            than five fields or a class id outside ``CLASS_NAMES``.
    """
    split_dirs = {
        'train': (TRAIN_IMAGES_DIR, TRAIN_LABELS_DIR),
        'val': (VAL_IMAGES_DIR, VAL_LABELS_DIR),
        'test': (TEST_IMAGES_DIR, TEST_LABELS_DIR),
    }
    if dataset_type not in split_dirs:
        raise ValueError(f"Invalid dataset_type: {dataset_type}")
    images_dir, labels_dir = split_dirs[dataset_type]

    image_path = os.path.join(images_dir, image_name)
    # Swap only the final extension. str.replace('.jpg', ...) would also rewrite
    # '.jpg' appearing elsewhere in the name, and ignore non-.jpg images.
    label_path = os.path.join(labels_dir, os.path.splitext(image_name)[0] + '.txt')

    if not os.path.exists(image_path):
        return None, None

    # Image.open is lazy and keeps the file handle open until pixels are read;
    # copy() forces the load so the handle is closed when the `with` exits.
    with Image.open(image_path) as opened:
        image = opened.copy()
    img_width, img_height = image.size

    annotations = []
    if os.path.exists(label_path):
        with open(label_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # blank line
                if len(parts) < 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected "
                        f"'class x_center y_center width height', got {line.strip()!r}"
                    )
                class_id = int(float(parts[0]))  # Handle both int and float formats
                # Guard explicitly: a negative id would otherwise silently
                # index CLASS_NAMES from the end.
                if not 0 <= class_id < len(CLASS_NAMES):
                    raise ValueError(f"{label_path}:{line_no}: unknown class id {class_id}")
                x_center, y_center, width, height = (float(v) for v in parts[1:5])

                # Convert normalized YOLO values to absolute pixel sizes.
                width_abs = width * img_width
                height_abs = height * img_height
                # Box center -> top-left corner.
                x_min = x_center * img_width - width_abs / 2
                y_min = y_center * img_height - height_abs / 2

                annotations.append({
                    'class_id': class_id,
                    'class_name': CLASS_NAMES[class_id],
                    'bbox': (x_min, y_min, width_abs, height_abs),
                    'color': CLASS_COLORS[class_id],
                })

    return image, annotations

def plot_image_with_annotations(ax, image, annotations):
    """Show ``image`` on ``ax`` with axes hidden, then overlay every annotation.

    Each annotation dict provides ``bbox`` as ``(x_min, y_min, width, height)``
    in data (pixel) coordinates, a matplotlib ``color`` used for the box edge
    and the tag text, and a ``class_id``. The id is written in a small white
    rounded tag placed at ``(x_min, y_min - 5)``.
    """
    ax.imshow(image)
    ax.axis('off')

    for entry in annotations:
        left, top, box_w, box_h = entry['bbox']
        edge = entry['color']

        outline = patches.Rectangle(
            (left, top), box_w, box_h,
            linewidth=2, edgecolor=edge, facecolor='none', alpha=0.8,
        )
        ax.add_patch(outline)

        tag_box = {'boxstyle': "round,pad=0.2", 'facecolor': 'white', 'alpha': 0.8}
        ax.text(
            left, top - 5, str(entry['class_id']),
            color=edge, fontsize=10, fontweight='bold', bbox=tag_box,
        )

def _select_grid_cells(available, split_names, num_rows, num_cols):
    """Return one ``(image_name, split)`` pair per grid cell, in row-major order.

    Row ``r`` samples up to ``num_cols`` distinct images from ``split_names[r]``.
    Unfilled cells hold ``(None, split)``. Rows past the last split hold
    ``(None, None)``.
    """
    cells = []
    for row in range(num_rows):
        split = split_names[row] if row < len(split_names) else None
        pool = available[split] if split is not None else []
        picked = random.sample(pool, min(num_cols, len(pool)))
        picked += [None] * (num_cols - len(picked))
        cells.extend((name, split) for name in picked)
    return cells


def _save_grid_jpeg(fig, save_path, jpg_quality, max_size_kb):
    """Write ``fig`` as a JPEG, report its size and return the written path.

    A trailing ``.png`` extension on ``save_path`` is replaced with ``.jpg``
    because the data is always JPEG; any other path is used unchanged.
    """
    root, ext = os.path.splitext(save_path)
    # Only the final extension is swapped; str.replace would also rewrite
    # '.png' inside directory names.
    jpg_path = root + '.jpg' if ext.lower() == '.png' else save_path
    fig.savefig(jpg_path, dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none', pad_inches=0,
                format='jpg', pil_kwargs={'quality': jpg_quality, 'optimize': True})
    print(f"Grid saved to: {jpg_path}")

    file_size = os.path.getsize(jpg_path) / 1024  # KB
    print(f"File size: {file_size:.1f} KB (quality: {jpg_quality})")
    if file_size > max_size_kb:
        print(f"Warning: File size ({file_size:.1f} KB) is larger than {max_size_kb} KB")
        print(f"Consider reducing jpg_quality parameter (current: {jpg_quality})")
    else:
        print(f"File size is within {max_size_kb} KB limit")
    return jpg_path


def _print_grid_statistics(split_names, images_shown, class_counts):
    """Print how many images each split put in the grid and their per-class box counts."""
    print("\nGrid Statistics:")
    for row, split in enumerate(split_names, start=1):
        print(f"- Row {row} ({split.capitalize()}): {images_shown[split]} images")
    print("- Annotations by dataset and class:")
    for split in split_names:
        print(f"  {split.upper()}:")
        for class_name, count in class_counts[split].items():
            print(f"    * {class_name}: {count}")


def create_annotation_grid(num_rows=3, num_cols=3, save_path=None, jpg_quality=85,
                           max_size_kb=200):
    """Create a grid of random annotated images: train (row 1), val (row 2), test (row 3).

    Each of the first three rows samples up to ``num_cols`` distinct ``.jpg``
    images from its split. Cells without an image, including every cell in
    rows beyond the third, show "No image available".

    Args:
        num_rows: Number of grid rows. Rows 1-3 map to train/val/test.
        num_cols: Number of grid columns, i.e. images sampled per split.
        save_path: Output path. A trailing ``.png`` is swapped for ``.jpg``
            because the figure is always written as JPEG. When falsy, the
            figure is rendered but not saved.
        jpg_quality: JPEG quality passed to Pillow (lower = smaller file).
        max_size_kb: Size budget used for the post-save size warning.

    Returns:
        The path the JPEG was written to, or None if nothing was saved.
    """
    split_names = ('train', 'val', 'test')
    image_dirs = {'train': TRAIN_IMAGES_DIR, 'val': VAL_IMAGES_DIR, 'test': TEST_IMAGES_DIR}
    # Sorted so a given random seed selects the same files regardless of the
    # filesystem's arbitrary directory-listing order.
    available = {
        split: sorted(name for name in os.listdir(image_dirs[split]) if name.endswith('.jpg'))
        for split in split_names
    }
    print(f"Available images - Train: {len(available['train'])}, "
          f"Val: {len(available['val'])}, Test: {len(available['test'])}")

    cells = _select_grid_cells(available, split_names, num_rows, num_cols)

    # Keep each cell at the 960x640 image aspect ratio so images are not stretched.
    img_aspect = 960 / 640
    cell_width = 3  # inches per cell
    cell_height = cell_width / img_aspect
    # squeeze=False keeps `axes` 2-D for every grid shape (including 1xN and Nx1).
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(num_cols * cell_width, num_rows * cell_height),
                             squeeze=False)

    images_shown = {split: 0 for split in split_names}
    class_counts = {split: {name: 0 for name in CLASS_NAMES} for split in split_names}
    jpg_path = None
    try:
        for (image_name, split), ax in zip(cells, axes.flat):
            ax.axis('off')
            if image_name is None:
                ax.text(0.5, 0.5, 'No image available',
                        ha='center', va='center', transform=ax.transAxes, fontsize=8)
                continue
            images_shown[split] += 1
            image, annotations = load_image_and_annotations(image_name, split)
            if image is None:
                ax.text(0.5, 0.5, 'Error loading image',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            plot_image_with_annotations(ax, image, annotations)
            # Count while drawing so every image and label file is read only once.
            for ann in annotations:
                class_counts[split][ann['class_name']] += 1

        # Remove padding between subplots; the negative wspace deliberately
        # pulls the columns tighter together.
        fig.subplots_adjust(left=0, right=1, top=0.95, bottom=0.08,
                            wspace=-0.35, hspace=0.05)

        legend_elements = [patches.Patch(color=color, label=f"Class {idx}: {name}")
                           for idx, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS))]
        fig.legend(handles=legend_elements, loc='lower center', ncol=len(CLASS_NAMES),
                   bbox_to_anchor=(0.5, 0.01), fontsize=12)

        if save_path:
            jpg_path = _save_grid_jpeg(fig, save_path, jpg_quality, max_size_kb)
    finally:
        # Close this specific figure even if drawing or saving fails.
        plt.close(fig)

    _print_grid_statistics(split_names, images_shown, class_counts)
    return jpg_path

def main():
    """Render the 3x4 preview grid, lowering JPEG quality until it fits 200 KB.

    A seed is drawn once and printed so the image selection can be reproduced.
    The random generator is re-seeded before every attempt, so each quality
    level renders the same images.
    """
    print("Creating 3x4 grid with random images from train (row 1), val (row 2), test (row 3)...")

    # Seeding from a fresh random value is only reproducible if the value is
    # known, so report it.
    seed = random.randint(0, 100)
    print(f"Random seed: {seed}")

    save_path = os.path.join(DATA_DIR, "annotation_preview.png")  # Will be saved as .jpg
    root, ext = os.path.splitext(save_path)
    jpg_path = root + '.jpg' if ext.lower() == '.png' else save_path
    max_size_kb = 200

    # Try progressively lower quality until the file is under the size budget.
    for quality in (20, 15, 10, 5):
        random.seed(seed)
        create_annotation_grid(num_rows=3, num_cols=4, save_path=save_path, jpg_quality=quality)
        if os.path.getsize(jpg_path) / 1024 <= max_size_kb:
            break
    else:
        print(f"Could not get the preview under {max_size_kb} KB (last quality: {quality})")
        


if __name__ == '__main__':
    # Build the preview only when executed as a script, not when imported.
    main()
