#!/usr/bin/env python3
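"""
Image puzzle dataset builder.

Resizes images to squares, cuts them into 2x2, 4x4 or 6x6 grids and shuffles the
tiles, producing 20, 40 and 70 puzzles respectively. Shuffled images are written
to ``question/``, the unshuffled (resized) images to ``solution/``, and per-image
metadata to ``prompt.json``.
"""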

import os
import json
import random
import shutil
from PIL import Image
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Any
from collections import defaultdict

class PuzzleDatasetBuilder:
    """Build an image puzzle-restoration dataset.

    Each selected image is resized to a square, cut into a
    ``grid_size x grid_size`` grid and its pieces are shuffled. The shuffled
    image is written to ``<output_root>/question/NNN.png``, the unshuffled
    square image to ``<output_root>/solution/NNN.png`` and per-image metadata
    to ``<output_root>/prompt.json``. Every image comes from a different
    sub-category folder, across all puzzle types.
    """

    # Top-level folders under images_root, and prompt files (<name>.json) under prompt_root.
    CATEGORY_DIRS = ("animal", "flower", "landmark")
    # Lower-case file suffixes treated as images.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".webp")
    # Image filter: shorter side >= MIN_RESOLUTION and long/short ratio < MAX_ASPECT_RATIO.
    MIN_RESOLUTION = 1024
    MAX_ASPECT_RATIO = 1.2
    # Smallest allowed puzzle-piece edge, in pixels.
    MIN_PIECE_SIZE = 64
    # How many mismatched category names to list in the statistics report.
    MAX_LISTED_CATEGORIES = 10

    def __init__(self, images_root: str, prompt_root: str, output_root: str):
        self.images_root = Path(images_root)
        self.prompt_root = Path(prompt_root)
        self.output_root = Path(output_root)

        # Puzzle type -> (grid_size, number of images to generate)
        self.puzzle_configs = {
            "2x2": (2, 20),
            "4x4": (4, 40),
            "6x6": (6, 70),
        }

        # Prompt entries keyed by their 'category' field.
        self.prompt_data = self._load_prompt_data()

        # Paths (as str) of every image selected so far.
        self.used_images = set()

        # Sub-category names already used; each contributes at most one image.
        self.used_categories = set()

    # ------------------------------------------------------------------ loading

    def _load_prompt_data(self) -> Dict[str, Dict]:
        """Load prompt entries from ``<prompt_root>/<category>.json``.

        Each file is expected to hold a JSON list of objects carrying at least
        ``category`` and ``prompt`` keys. Missing files are skipped. A later
        entry with the same ``category`` overwrites an earlier one.

        Returns:
            Mapping from each entry's ``category`` value to the entry itself.
        """
        prompt_data = {}
        for category in self.CATEGORY_DIRS:
            prompt_file = self.prompt_root / f"{category}.json"
            if not prompt_file.exists():
                continue
            with open(prompt_file, 'r', encoding='utf-8') as f:
                for item in json.load(f):
                    prompt_data[item['category']] = item
        return prompt_data

    def _get_available_images(self) -> Dict[str, List[Path]]:
        """Collect usable images grouped by sub-category name.

        Scans ``<images_root>/<top-level>/<sub-folder>/`` for image files. The
        sub-category name is the sub-folder name with everything up to the
        first ``_`` removed and remaining underscores replaced by spaces, so it
        matches the ``category`` values in the prompt files. Only images that
        pass ``_is_usable_image`` are kept.

        Returns:
            Mapping from sub-category name to its usable image paths. Categories
            without any usable image are absent.
        """
        available_images = defaultdict(list)

        for top_level in self.CATEGORY_DIRS:
            category_path = self.images_root / top_level
            if not category_path.exists():
                continue

            for subdir in category_path.iterdir():
                if not subdir.is_dir():
                    continue
                # e.g. "012_golden_retriever" -> "golden retriever"
                folder_name = subdir.name.split('_', 1)[-1]
                category_name = folder_name.replace('_', ' ')

                for img_file in subdir.iterdir():
                    if img_file.suffix.lower() not in self.IMAGE_SUFFIXES:
                        continue
                    if self._is_usable_image(img_file):
                        available_images[category_name].append(img_file)

        return available_images

    def _is_usable_image(self, img_file: Path) -> bool:
        """Return True if the image is readable, large enough and nearly square."""
        try:
            with Image.open(img_file) as img:
                width, height = img.size
        except Exception:
            # Best-effort scan: unreadable or corrupt files are simply skipped.
            return False
        min_size = min(width, height)
        if min_size < self.MIN_RESOLUTION:
            return False
        return max(width, height) / min_size < self.MAX_ASPECT_RATIO

    # ------------------------------------------------------------ image helpers

    def _prepare_square_image(self, image: "Image.Image", grid_size: int) -> "Image.Image":
        """Resize an image to a square whose side is divisible by ``grid_size``.

        The side length is the shorter input side rounded down to a multiple of
        ``grid_size``. It is never less than ``grid_size * MIN_PIECE_SIZE``. The
        image is resized directly (no cropping), so a non-square input is
        slightly stretched. The loader only admits images whose aspect ratio is
        below ``MAX_ASPECT_RATIO``.

        Args:
            image: Input image.
            grid_size: Number of pieces per side (2, 4 or 6 in the default config).

        Returns:
            The resized square image.
        """
        side = (min(image.size) // grid_size) * grid_size
        side = max(side, grid_size * self.MIN_PIECE_SIZE)
        return image.resize((side, side), Image.LANCZOS)

    @staticmethod
    def _count_moved(order: List[int], reference: List[int]) -> int:
        """Return the number of positions where ``order`` differs from ``reference``."""
        return sum(1 for got, want in zip(order, reference) if got != want)

    @staticmethod
    def _min_moved(length: int) -> int:
        """Return the minimum moved positions for an order of ``length`` to count as shuffled."""
        return max(2, length // 2)

    def _ensure_shuffled(self, original_order: List[int], max_attempts: int = 100) -> List[int]:
        """Return a permutation of ``original_order`` that is clearly shuffled.

        "Clearly shuffled" means at least ``max(2, len // 2)`` positions differ
        from ``original_order``. Up to ``max_attempts`` random shuffles are
        tried. If none qualifies, the order is rotated left by one. That moves
        every element when all values are distinct, as for ``list(range(n))``.

        Args:
            original_order: Order to shuffle. It is not modified.
            max_attempts: Number of random shuffles to try before falling back.

        Returns:
            A new list. Orders with fewer than two elements cannot be shuffled
            and are returned as an unchanged copy.
        """
        if len(original_order) < 2:
            return list(original_order)

        needed = self._min_moved(len(original_order))
        for _ in range(max_attempts):
            candidate = list(original_order)
            random.shuffle(candidate)
            if self._count_moved(candidate, original_order) >= needed:
                return candidate

        # Deterministic fallback. A rotation by one is a derangement for
        # distinct values, so it always meets the threshold. Swapping only a
        # few elements would move too few pieces on 4x4 and larger grids.
        return original_order[1:] + original_order[:1]

    def _verify_shuffled(self, shuffled_order: List[int]) -> bool:
        """Check that ``shuffled_order`` is a real shuffle of ``range(len(shuffled_order))``.

        Returns:
            True when at least ``max(2, len // 2)`` positions hold a different
            index than in the identity order, otherwise False.
        """
        identity = list(range(len(shuffled_order)))
        moved = self._count_moved(shuffled_order, identity)
        return moved >= self._min_moved(len(shuffled_order))

    def _shuffle_image(self, image: "Image.Image", grid_size: int) -> Tuple["Image.Image", List[int]]:
        """Cut a square image into ``grid_size x grid_size`` pieces and shuffle them.

        Args:
            image: Square input image, as produced by ``_prepare_square_image``.
            grid_size: Number of pieces per side.

        Returns:
            ``(shuffled_image, shuffled_order)``. ``shuffled_order[k]`` is the
            row-major index of the original piece placed at row-major
            position ``k``.

        Raises:
            ValueError: If the generated order does not pass ``_verify_shuffled``.
        """
        side = image.size[0]  # input is square, width == height
        piece_size = side // grid_size

        def top_left(index: int) -> Tuple[int, int]:
            # Row-major index -> (x, y) pixel offset of that grid cell.
            row, col = divmod(index, grid_size)
            return col * piece_size, row * piece_size

        pieces = []
        for index in range(grid_size * grid_size):
            left, top = top_left(index)
            pieces.append(image.crop((left, top, left + piece_size, top + piece_size)))

        shuffled_order = self._ensure_shuffled(list(range(len(pieces))))
        if not self._verify_shuffled(shuffled_order):
            raise ValueError(f"Image shuffling failed! Grid size: {grid_size}x{grid_size}")

        shuffled_image = Image.new('RGB', (side, side))
        for position, piece_index in enumerate(shuffled_order):
            shuffled_image.paste(pieces[piece_index], top_left(position))

        return shuffled_image, shuffled_order

    # --------------------------------------------------------------- building

    def _create_output_structure(self):
        """Create the ``question/`` and ``solution/`` output directories."""
        (self.output_root / "question").mkdir(parents=True, exist_ok=True)
        (self.output_root / "solution").mkdir(parents=True, exist_ok=True)

    @classmethod
    def _print_category_list(cls, title: str, categories: set) -> None:
        """Print a sorted, truncated list of category names under ``title``."""
        if not categories:
            return
        print(f"\n{title} ({len(categories)}):")
        ordered = sorted(categories)  # sort first so the shown subset is deterministic
        for cat in ordered[:cls.MAX_LISTED_CATEGORIES]:
            print(f"  - {cat}")
        if len(ordered) > cls.MAX_LISTED_CATEGORIES:
            print(f"  ... and {len(ordered) - cls.MAX_LISTED_CATEGORIES} more")

    def _report_data_statistics(self, available_images: Dict[str, List[Path]]) -> None:
        """Print how the loaded images and prompt entries line up by category."""
        image_categories = set(available_images)
        prompt_categories = set(self.prompt_data)

        print("\nData loading statistics:")
        print(f"  Prompt data: {len(self.prompt_data)} categories")
        print(f"  Image data: {len(available_images)} categories")
        print(f"  Matched categories: {len(image_categories & prompt_categories)}")

        total_valid_images = sum(len(imgs) for imgs in available_images.values())
        print(f"  Total valid images (resolution≥{self.MIN_RESOLUTION} and "
              f"aspect ratio<{self.MAX_ASPECT_RATIO}): {total_valid_images}")

        self._print_category_list("Categories with images but no prompts",
                                  image_categories - prompt_categories)
        self._print_category_list("Categories with prompts but no images",
                                  prompt_categories - image_categories)

        if available_images:
            print(f"\nImage category examples: {list(available_images.keys())[:3]}")
        if self.prompt_data:
            print(f"Prompt category examples: {list(self.prompt_data.keys())[:3]}")

    def _select_images(self, available_images: Dict[str, List[Path]],
                       puzzle_type: str, count: int) -> List[Tuple[Path, str]]:
        """Pick up to ``count`` images, each from a distinct eligible sub-category.

        A sub-category is eligible when it has images, has a prompt entry and
        has not been used yet. The chosen categories and image paths are
        recorded in ``used_categories`` and ``used_images``.

        Returns:
            ``(image_path, category_name)`` pairs. The list is shorter than
            ``count`` when not enough eligible categories remain.
        """
        # Filter up front so prompt-less categories can never be drawn.
        candidates = [cat for cat, imgs in available_images.items()
                      if imgs and cat in self.prompt_data and cat not in self.used_categories]
        if len(candidates) < count:
            print(f"Warning: No more available categories for {puzzle_type}")

        selected = []
        for category_name in random.sample(candidates, min(count, len(candidates))):
            img_path = random.choice(available_images[category_name])
            selected.append((img_path, category_name))
            self.used_images.add(str(img_path))
            self.used_categories.add(category_name)
        return selected

    def _process_image(self, img_path: Path, category_name: str, grid_size: int,
                       image_id: int, question_dir: Path, solution_dir: Path) -> Dict[str, Any]:
        """Create one puzzle: save its question/solution PNGs and return its metadata."""
        with Image.open(img_path) as img:
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            # Step 1: resize to a square whose side is divisible by grid_size
            square_img = self._prepare_square_image(rgb, grid_size)

        # Step 2: shuffle the square image's pieces
        shuffled_img, shuffled_order = self._shuffle_image(square_img, grid_size)
        identity = list(range(grid_size * grid_size))
        shuffle_ratio = self._count_moved(shuffled_order, identity) / len(shuffled_order)

        filename = f"{image_id:03d}.png"
        shuffled_img.save(question_dir / filename)  # question: shuffled square
        square_img.save(solution_dir / filename)    # solution: resized, not shuffled

        prompt = self.prompt_data[category_name]
        prompt_info = {
            'category': prompt['category'],
            'prompt': prompt['prompt'],
            'grid_size': f"{grid_size}x{grid_size}",
            'image_id': image_id,
            'shuffled_order': shuffled_order,
            'shuffle_ratio': round(shuffle_ratio, 2)
        }
        print(f"  Processing complete: {filename} ({category_name}) - {grid_size}x{grid_size} - Square {square_img.size[0]}x{square_img.size[1]} - Shuffle ratio: {shuffle_ratio:.1%}")
        return prompt_info

    def build_dataset(self):
        """Build the dataset under ``output_root``.

        Writes ``question/NNN.png`` (shuffled), ``solution/NNN.png``
        (unshuffled square) and ``prompt.json`` (list of per-image metadata).
        Image IDs are global across puzzle types. Images that fail to process
        are reported and skipped.
        """
        print("Starting to build image puzzle-restoration dataset...")

        self._create_output_structure()

        available_images = self._get_available_images()
        self._report_data_statistics(available_images)

        question_dir = self.output_root / "question"
        solution_dir = self.output_root / "solution"

        all_prompts = []
        global_image_id = 1

        for puzzle_type, (grid_size, count) in self.puzzle_configs.items():
            print(f"\nGenerating {puzzle_type} dataset (need {count} images)...")

            selected_images = self._select_images(available_images, puzzle_type, count)
            if len(selected_images) < count:
                print(f"Warning: {puzzle_type} only found {len(selected_images)} available images, less than needed {count}")

            generated = 0
            for img_path, category_name in selected_images:
                try:
                    prompt_info = self._process_image(img_path, category_name, grid_size,
                                                      global_image_id, question_dir, solution_dir)
                except Exception as e:
                    # Best-effort: report the failure and continue with the next image.
                    print(f"  Error processing image {img_path}: {e}")
                    continue
                all_prompts.append(prompt_info)
                global_image_id += 1
                generated += 1

            # Report images actually written, not merely selected.
            print(f"  {puzzle_type} dataset generation complete, total {generated} images")

        prompt_file = self.output_root / "prompt.json"
        with open(prompt_file, 'w', encoding='utf-8') as f:
            json.dump(all_prompts, f, ensure_ascii=False, indent=2)

        print(f"\nUnified prompt.json file saved, total {len(all_prompts)} images")

        print(f"\nDataset construction complete! Output directory: {self.output_root}")
        print(f"Total unique images used: {len(self.used_images)}")
        print(f"Total different categories used: {len(self.used_categories)}")


def main():
    """Script entry point: build the puzzle dataset from fixed local paths."""
    # NOTE(review): the output folder is spelled "pullze_dataset" on disk;
    # kept verbatim because existing outputs presumably live there.
    paths = {
        "images_root": "/home/hlihg/HDD/lhx/bench/similarity/images",
        "prompt_root": "/home/hlihg/HDD/lhx/bench/benchmark/prompt",
        "output_root": "/home/hlihg/HDD/lhx/bench/pullze_dataset",
    }
    PuzzleDatasetBuilder(**paths).build_dataset()


# Build the dataset only when run as a script, not when imported.
if __name__ == "__main__":
    main()
