#!/usr/bin/env python3

'''
visualize grid images from classwise_images folder
'''

import argparse
import os
import random
from pathlib import Path
from PIL import Image
import numpy as np
import re

def is_valid_image(file_path):
    """Report whether ``file_path`` opens as an image and passes PIL's ``verify()``.

    Any exception raised while opening or verifying the file is treated as
    "not a valid image" and yields ``False``.
    """
    try:
        with Image.open(file_path) as candidate:
            candidate.verify()
    except Exception:
        return False
    else:
        return True

def get_class_folders(classwise_images_dir):
    """Return the sorted names of the sub-directories of ``classwise_images_dir``.

    Only immediate children that are directories are reported; plain files
    are ignored. Names (not full paths) are returned.
    """
    return sorted(
        entry
        for entry in os.listdir(classwise_images_dir)
        if os.path.isdir(os.path.join(classwise_images_dir, entry))
    )

def get_valid_images(class_folder_path):
    """Return the paths of all valid PNG images directly inside a class folder.

    Args:
        class_folder_path: Directory to scan (not recursive).

    Returns:
        Paths (``class_folder_path`` joined with the file name) of every file
        whose name ends in ``.png`` (case-insensitive) and that passes
        ``is_valid_image``, ordered by file name. An empty list is returned
        when ``class_folder_path`` does not exist or is not a directory.
    """
    # isdir (rather than exists) so that a stray regular file yields [] instead
    # of making os.listdir raise NotADirectoryError.
    if not os.path.isdir(class_folder_path):
        return []

    valid_images = []
    # Sorted so the result does not depend on the filesystem's listing order.
    for file_name in sorted(os.listdir(class_folder_path)):
        if not file_name.lower().endswith('.png'):
            continue
        file_path = os.path.join(class_folder_path, file_name)
        if is_valid_image(file_path):
            valid_images.append(file_path)

    return valid_images

def parse_conf_from_filename(file_path):
    """Read the confidence encoded in a file name as ``_conf<number>``.

    Only the base name of ``file_path`` is inspected. The number may be an
    integer or a decimal (e.g. ``generated_000001_conf0.9234.png`` -> 0.9234).
    Returns the confidence as a float, or None when the name carries none.
    """
    match = re.search(r"_conf([0-9]+(?:\.[0-9]+)?)", os.path.basename(file_path))
    if match is None:
        return None
    try:
        return float(match.group(1))
    except Exception:
        return None

def select_images(image_paths, mode="conf_top", k=64, conf_threshold=0.5):
    """Choose up to ``k`` entries from ``image_paths`` according to ``mode``.

    - 'conf_top': the ``k`` paths with the highest filename confidence. Paths
      without a confidence are scored -1.0 and therefore rank last. With fewer
      than ``k`` inputs, all of them are returned (ranked).
    - 'conf_thrand': a random sample of the paths whose confidence is
      >= ``conf_threshold``. The result always has ``k`` entries: when fewer
      than ``k`` paths qualify, the remainder is filled with None (callers
      render None as a black placeholder).
    - anything else (including 'random'): a random sample of ``k`` paths, or a
      copy of ``image_paths`` in its original order when it has fewer than ``k``.
    """
    if mode == "conf_top":
        def ranking_score(path):
            conf = parse_conf_from_filename(path)
            return -1.0 if conf is None else conf

        # Stable descending sort: ties keep their input order.
        ranked = sorted(image_paths, key=ranking_score, reverse=True)
        return ranked[:k]

    if mode == "conf_thrand":
        def meets_threshold(path):
            conf = parse_conf_from_filename(path)
            return conf is not None and conf >= conf_threshold

        eligible = [path for path in image_paths if meets_threshold(path)]
        if not eligible:
            # Nothing qualifies: the whole selection is placeholders.
            return [None] * k

        picked = random.sample(eligible, min(k, len(eligible)))
        # Top up with None placeholders so the caller always gets k entries.
        return picked + [None] * (k - len(picked))

    # random
    if len(image_paths) < k:
        return list(image_paths)
    return random.sample(image_paths, k)

def calculate_avg_confidence(image_paths):
    """Return the mean of the confidences encoded in the given file names.

    None entries (placeholders) and paths whose name carries no confidence are
    skipped. Returns None when no confidence value could be read at all.
    """
    parsed = (
        parse_conf_from_filename(path)
        for path in image_paths
        if path is not None  # None guard
    )
    found = [conf for conf in parsed if conf is not None]

    if found:
        return sum(found) / len(found)
    return None

def create_grid_image(image_paths, grid_size=8, cell_size=64):
    """Tile images into a ``grid_size`` x ``grid_size`` RGB grid, row by row.

    Args:
        image_paths: Sequence of image file paths. A None entry, or a path
            that fails to load, is rendered as a black cell. Only the first
            ``grid_size * grid_size`` entries are used; a shorter sequence is
            padded with black cells. The sequence itself is never modified.
        grid_size: Number of cells per row and per column.
        cell_size: Edge length in pixels that every image is resized to.

    Returns:
        A new PIL RGB image of ``grid_size * cell_size`` pixels per side.
    """
    cell_count = grid_size * grid_size
    if len(image_paths) < cell_count:
        print(f"Warning: Only {len(image_paths)} valid images found, but {cell_count} are needed.")

    # Work on a private copy so padding never leaks into the caller's list.
    cells = list(image_paths[:cell_count])
    cells.extend([None] * (cell_count - len(cells)))

    # Image.new fills with black by default, so placeholder cells need no paste.
    grid_edge = grid_size * cell_size
    grid_image = Image.new('RGB', (grid_edge, grid_edge))

    for index, img_path in enumerate(cells):
        if img_path is None:
            continue  # black placeholder: the canvas is already black here

        try:
            # The context manager closes the underlying file; resize() returns
            # an independent in-memory image that stays usable afterwards.
            with Image.open(img_path) as img:
                # Resize to a standard size (assuming square images)
                tile = img.resize((cell_size, cell_size), Image.LANCZOS)
        except Exception as e:
            # Best effort: report the failure and leave this cell black.
            print(f"Error loading image {img_path}: {e}")
            continue

        row, col = divmod(index, grid_size)
        grid_image.paste(tile, (col * cell_size, row * cell_size))

    return grid_image

def _conf_suffix(selected_images):
    """Return the filename fragment describing the selection's mean confidence."""
    avg_conf = calculate_avg_confidence(selected_images)
    return f"_avgconf{avg_conf:.3f}" if avg_conf is not None else "_no_conf"


def _save_all_classes_grids(all_valid_images, args, grid_size, output_subdir, num_random_grids=5):
    """Save ``num_random_grids`` grids of images drawn at random from all classes."""
    total_images = grid_size * grid_size

    # Shuffle once so consecutive slices give unique, non-overlapping grids
    # for as long as enough unused images remain.
    shuffled_images = list(all_valid_images)
    random.shuffle(shuffled_images)

    for grid_idx in range(num_random_grids):
        start = grid_idx * total_images
        end = start + total_images

        if end <= len(shuffled_images):
            selected_images = shuffled_images[start:end]
        else:
            # Not enough unused images left; fall back to independent sampling.
            selected_images = select_images(
                all_valid_images,
                mode="random",
                k=total_images,
                conf_threshold=args.conf_threshold,
            )
            if len(all_valid_images) < total_images:
                print(
                    f"Warning: Only {len(all_valid_images)} valid images available, using all of them (grid {grid_idx+1}/{num_random_grids})"
                )

        conf_suffix = _conf_suffix(selected_images)

        print(f"Creating {grid_size}x{grid_size} grid from all classes... ({grid_idx+1}/{num_random_grids})")
        grid_image = create_grid_image(selected_images, grid_size=grid_size)

        # The index in the name keeps the grids from overwriting each other.
        output_path = output_subdir / f"all_classes_random_{grid_idx+1}{conf_suffix}_grid.png"
        grid_image.save(output_path)

        print(f"All-classes grid image saved as: {output_path}")


def _save_class_grid(valid_images, file_stem, args, grid_size, output_subdir):
    """Select, tile and save the grid of one class as ``<file_stem><conf>_grid.png``."""
    total_images = grid_size * grid_size

    print(f"Found {len(valid_images)} valid images")

    # Select images by the mode requested on the command line.
    selected_images = select_images(
        valid_images,
        mode=args.select,
        k=total_images,
        conf_threshold=args.conf_threshold,
    )

    # Check the threshold padding first: in conf_thrand mode a short grid is
    # caused by the threshold, so "using all of them" would be misleading.
    if args.select == "conf_thrand" and any(img is None for img in selected_images):
        threshold_count = sum(1 for img in selected_images if img is not None)
        print(f"Warning: Only {threshold_count} images found with confidence >= {args.conf_threshold}, padding with black images")
    elif len(valid_images) < total_images:
        print(f"Warning: Only {len(valid_images)} valid images available, using all of them")

    conf_suffix = _conf_suffix(selected_images)

    print(f"Creating {grid_size}x{grid_size} grid...")
    grid_image = create_grid_image(selected_images, grid_size=grid_size)

    output_path = output_subdir / f"{file_stem}{conf_suffix}_grid.png"
    grid_image.save(output_path)

    print(f"Grid image saved as: {output_path}")


def main():
    """Command-line entry point: build image grids from ``<input_dir>/classwise_images``.

    Steps:
      1. Parse the arguments and check that the input folders exist.
      2. Write five random grids mixing images from all classes.
      3. Write one grid for the class given by ``--class_idx``, or one grid per
         class when ``--class_idx`` is omitted.

    Grids are saved in ``<input_dir>/grid_images_<select>_<conf_threshold>``.
    Problems with the input folders or the class index are reported with an
    "Error:" message and the function returns early; an invalid
    ``--grid_size`` is rejected through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description='Create a grid visualization of images from a specific class or all classes')
    parser.add_argument('--input_dir', type=str, required=True,
                        help='Input directory containing classwise_images folder')
    parser.add_argument('--class_idx', type=int, default=None,
                        help='Class index to visualize (if not provided, creates grids for all classes)')
    parser.add_argument('--select', type=str, default='conf_thrand', choices=['conf_top', 'conf_thrand', 'random'],
                        help="selection method")
    # %(default)s keeps the help text in sync with the actual default values.
    parser.add_argument('--conf_threshold', type=float, default=0.7,
                        help='Confidence threshold for conf_thrand mode. Default: %(default)s')
    parser.add_argument('--grid_size', type=int, default=5,
                        help='Grid size (e.g., 8 for 8x8=64 images, 5 for 5x5=25 images). Default: %(default)s')

    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    class_idx = args.class_idx
    grid_size = args.grid_size

    if grid_size < 1:
        parser.error(f"--grid_size must be a positive integer, got {grid_size}")

    # Check if input directory exists
    if not input_dir.exists():
        print(f"Error: Input directory {input_dir} does not exist.")
        return

    # Check if classwise_images folder exists
    classwise_images_dir = input_dir / 'classwise_images'
    if not classwise_images_dir.exists():
        print(f"Error: classwise_images folder not found in {input_dir}")
        return

    # Get class folders
    class_folders = get_class_folders(classwise_images_dir)
    if not class_folders:
        print(f"Error: No class folders found in {classwise_images_dir}")
        return

    # Create output subdirectory for grid images
    output_subdir = input_dir / f"grid_images_{args.select}_{args.conf_threshold}"
    output_subdir.mkdir(exist_ok=True)

    # Always create multiple all-classes random grids first
    print("Creating all-classes random grids...")

    # Validate every class folder exactly once; the per-class step below
    # reuses these lists instead of re-verifying each image file.
    images_by_class = {
        target_class: get_valid_images(classwise_images_dir / target_class)
        for target_class in class_folders
    }
    all_valid_images = [
        image_path
        for target_class in class_folders
        for image_path in images_by_class[target_class]
    ]

    if all_valid_images:
        print(f"Found {len(all_valid_images)} total valid images across all classes")
        _save_all_classes_grids(all_valid_images, args, grid_size, output_subdir)
    else:
        print("Warning: No valid images found for all-classes grid")

    # Then process individual classes if requested
    if class_idx is not None:
        # Process single class
        if class_idx < 0 or class_idx >= len(class_folders):
            print(f"Error: class_idx {class_idx} is out of range. Available classes: 0-{len(class_folders)-1}")
            print(f"Available classes: {class_folders}")
            return

        target_class = class_folders[class_idx]
        class_folder_path = classwise_images_dir / target_class

        print(f"Processing class: {target_class} (index: {class_idx})")

        valid_images = images_by_class[target_class]
        if not valid_images:
            print(f"Error: No valid PNG images found in {class_folder_path}")
            return

        # Same naming scheme as the all-classes branch: index, then class name.
        _save_class_grid(valid_images, f"class_{class_idx}_{target_class}", args, grid_size, output_subdir)

    else:
        # Process all classes
        print(f"Processing all {len(class_folders)} classes...")

        for idx, target_class in enumerate(class_folders):
            class_folder_path = classwise_images_dir / target_class

            print(f"Processing class {idx+1}/{len(class_folders)}: {target_class}")

            valid_images = images_by_class[target_class]
            if not valid_images:
                print(f"Warning: No valid PNG images found in {class_folder_path}, skipping...")
                continue

            _save_class_grid(valid_images, f"class_{idx}_{target_class}", args, grid_size, output_subdir)

        print(f"All grid images saved in: {output_subdir}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

"""
Usage examples:

Required folder structure:
/path/to/your/data/
└── classwise_images/
    ├── class_00_airplane/
    │   ├── generated_000001_conf0.9234.png
    │   ├── generated_000002_conf0.8567.png
    │   └── ...
    ├── class_01_automobile/
    │   ├── generated_000001_conf0.9123.png
    │   └── ...
    ├── class_02_bird/
    └── ...

1) Basic usage (defaults: 5x5 grid, conf_thrand mode with threshold 0.7):
   python visualize_grid.py --input_dir /path/to/your/data

2) Create 5x5 grid:
   python visualize_grid.py --input_dir /path/to/your/data --grid_size 5

3) Single class only:
   python visualize_grid.py --input_dir /path/to/your/data --class_idx 0

4) conf_thrand mode (random among confidence >= 0.92):
   python visualize_grid.py --input_dir /path/to/your/data --select conf_thrand --conf_threshold 0.92

5) Random selection mode:
   python visualize_grid.py --input_dir /path/to/your/data --select random

6) 10x10 grid, conf_thrand mode, threshold 0.8:
   python visualize_grid.py --input_dir /path/to/your/data --grid_size 10 --select conf_thrand --conf_threshold 0.8

Selection modes:
- conf_top: select by highest confidence
- conf_thrand: random among images with confidence >= threshold (default)
- random: random among all images

Output: grid images are saved under <input_dir>/grid_images_<select>_<conf_threshold>/.
- five all_classes_random_<n>_avgconf0.xxx_grid.png grids (n = 1..5) are created first.
- then class-wise grid images are generated.
"""
