#!/usr/bin/env python3
"""
TIPIR.py - Text Image Partition Isolation Recombination Algorithm
Core functionality: Create 64-block image with text and image recombination
Text: 48 blocks distributed across 4 regions for maximum isolation
Image: 16 blocks with letter labels A-P in fixed pattern
"""

import os
import random
from collections import deque
from typing import List, Tuple, Dict

import numpy as np
from PIL import Image, ImageDraw, ImageFont


def get_font(font_size):
    """
    Get font for text rendering

    Tries a fixed list of Windows font files (CJK-capable fonts first, then
    Arial) and falls back to Pillow's built-in default font when none of them
    can be loaded.

    Args:
        font_size: Requested size in pixels, passed to ImageFont.truetype.

    Returns:
        A Pillow font object, or None if even the default font cannot be loaded.
    """
    font_paths = [
        "C:/Windows/Fonts/msyh.ttc",  # Microsoft YaHei
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "C:/Windows/Fonts/arial.ttf",  # Arial
    ]

    for font_path in font_paths:
        if not os.path.exists(font_path):
            continue
        try:
            return ImageFont.truetype(font_path, font_size)
        except (OSError, ValueError, ImportError):
            # OSError: unreadable/unsupported font file; ValueError: rejected size
            # (e.g. 0 for very small blocks); ImportError: Pillow without FreeType.
            # Any of these just means "try the next candidate".
            continue

    # If all fail, use default font
    try:
        return ImageFont.load_default()
    except (OSError, ImportError):
        return None


def create_image_16_parts_with_letters(image_path: str) -> Tuple[List[Tuple], str]:
    """
    Split image into 16 blocks (4x4) and reorganize with A-P letter labels

    Each cell of the 4x4 grid (row-major order) receives one letter from a
    shuffled A-P sequence, the letter is drawn onto the cell, and the labelled
    cells are then shuffled again. Both shuffles use the module-level
    ``random`` generator, so seeding it makes the result reproducible.

    Args:
        image_path: Original image path

    Returns:
        tuple: (image_blocks list, image_digital_list str)
               image_blocks: List of 16 image blocks, each element is
                             (image_block, letter) with image_block an RGB PIL image
               image_digital_list: the letters of image_blocks joined with '-'
                                   in their shuffled order, like "A-C-B-D-..."
               On any failure the error is printed and ([], "") is returned.
    """
    try:
        # Load inside the context manager so the file handle is released, and
        # normalise to RGB once: the letter is drawn with black ink and the final
        # canvas is RGB, so palette/grayscale/alpha inputs are converted up front.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        # Resize so both sides are multiples of 4 and the 4x4 grid divides evenly
        width = (image.width // 4) * 4
        height = (image.height // 4) * 4
        if width == 0 or height == 0:
            raise ValueError(f"image is {image.width}x{image.height}, need at least 4x4 pixels")
        image = image.resize((width, height), Image.Resampling.LANCZOS)

        patch_w = width // 4
        patch_h = height // 4

        # Letter assigned to each grid cell; index 0 is the top-left cell and
        # indices run left to right, top to bottom.
        letters = [chr(ord('A') + i) for i in range(16)]  # A, B, C, ..., P
        random.shuffle(letters)

        patches_with_letters = []
        for index, letter in enumerate(letters):
            row, col = divmod(index, 4)
            box = (col * patch_w, row * patch_h, (col + 1) * patch_w, (row + 1) * patch_h)
            # crop() returns an independent image, so drawing on it leaves `image` intact
            patch = image.crop(box)
            patches_with_letters.append((add_black_letter_to_patch(patch, letter), letter))

        # Shuffle the labelled patches; each patch keeps its own letter
        random.shuffle(patches_with_letters)

        image_digital_list = '-'.join(letter for _, letter in patches_with_letters)

        print(f"Image 16-block split successful, letter sequence: {image_digital_list}")

        return patches_with_letters, image_digital_list

    except Exception as e:
        print(f"Image 16-block split failed: {e}")
        return [], ""


def add_black_letter_to_patch(patch_img: Image.Image, letter: str) -> Image.Image:
    """
    Draw ``letter`` in black, centered on ``patch_img``.

    The patch is modified in place; the same object is returned for convenience.

    Args:
        patch_img: PIL Image object
        letter: Letter to add (A-P); any value is rendered via str()

    Returns:
        PIL Image object with letter added (the same object as ``patch_img``)
    """
    draw = ImageDraw.Draw(patch_img)

    # Letter height is a quarter of the patch's shorter side. truetype rejects
    # sizes below 1, which tiny patches would otherwise produce.
    width, height = patch_img.size
    font_size = max(1, min(width, height) // 4)

    # Try Arial by name (resolved via the system font path), then by its
    # Windows location, then Pillow's built-in default font.
    font = None
    for font_name in ("arial.ttf", "C:/Windows/Fonts/arial.ttf"):
        try:
            font = ImageFont.truetype(font_name, font_size)
            break
        except (OSError, ValueError):
            continue
    if font is None:
        font = ImageFont.load_default()

    text = str(letter)

    # textbbox() measured from (0, 0) includes the glyph's left bearing and top
    # offset; subtract them so the visible ink, not the origin, is centered.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    x = (width - (right - left)) // 2 - left
    y = (height - (bottom - top)) // 2 - top

    # A color name lets Pillow choose the right ink for the patch's mode (RGB, RGBA, L, P, ...)
    draw.text((x, y), text, fill="black", font=font)

    return patch_img


def define_partition_regions() -> Dict[str, List[Tuple[int, int]]]:
    """
    Define 4 partition region position coordinates

    Coordinates are 0-indexed (row, col) cells of the 8x8 grid. Together the
    four regions cover the 48 cells not listed by get_image_positions(), with
    12 disjoint cells per region. Comments below use 1-based column numbers.

    Returns:
        dict: region name ('A'-'D') -> list of 12 (row, col) tuples
    """
    regions = {
        'A': [  # Region A: Column 2 (all rows) + Column 4 upper half
            # Column 2 (column index 1, rows 0-7)
            (0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1),
            # Column 4 upper half (column index 3, rows 0-3)
            (0, 3), (1, 3), (2, 3), (3, 3)
        ],
        'B': [  # Region B: Column 4 lower half + Column 6 upper half + Column 8 upper half
            # Column 4 lower half (column index 3, rows 4-7)
            (4, 3), (5, 3), (6, 3), (7, 3),
            # Column 6 upper half (column index 5, rows 0-3)
            (0, 5), (1, 5), (2, 5), (3, 5),
            # Column 8 upper half (column index 7, rows 0-3)
            (0, 7), (1, 7), (2, 7), (3, 7)
        ],
        'C': [  # Region C: Column 6 lower half + Column 1 odd rows + Column 3 even rows
            # Column 6 lower half (column index 5, rows 4-7)
            (4, 5), (5, 5), (6, 5), (7, 5),
            # Column 1 cells not used by images (column index 0, odd rows)
            (1, 0), (3, 0), (5, 0), (7, 0),
            # Column 3 cells not used by images (column index 2, even rows)
            (0, 2), (2, 2), (4, 2), (6, 2)
        ],
        'D': [  # Region D: Column 5 odd rows + Column 7 even rows + Column 8 lower half
            # Column 5 cells not used by images (column index 4, odd rows)
            (1, 4), (3, 4), (5, 4), (7, 4),
            # Column 7 cells not used by images (column index 6, even rows)
            (0, 6), (2, 6), (4, 6), (6, 6),
            # Column 8 lower half (column index 7, rows 4-7)
            (4, 7), (5, 7), (6, 7), (7, 7)
        ]
    }

    return regions


def get_image_positions() -> List[Tuple[int, int]]:
    """
    Return the 16 fixed (row, col) cells reserved for image blocks on the 8x8 grid.

    Image cells sit in the 0-indexed columns 0, 2, 4 and 6 in a staggered
    pattern: columns 0 and 4 use the even rows, columns 2 and 6 the odd rows.
    Cells are listed column by column, top to bottom; a new list is built on
    every call, so callers may shuffle it freely.

    Returns:
        list: 16 (row, col) tuples
    """
    # (col // 2) % 2 alternates 0, 1, 0, 1 across the image columns, which picks
    # the starting row (even or odd) for each column.
    return [
        (row, col)
        for col in range(0, 8, 2)
        for row in range((col // 2) % 2, 8, 2)
    ]


def merge_overflow_text_items(text_mapping_table: Dict[str, str], text_digital_list: str,
                              max_blocks: int = 48) -> Tuple[List[Tuple], str]:
    """
    Handle text merging logic for more than 48 blocks: pair first and last for merging

    Groups are merged by repeatedly joining the first and the last unmerged
    group until the total number of blocks (merged + unmerged) is at most
    ``max_blocks``. If every group has already been merged once and there are
    still too many blocks, a new round starts over the merged groups, so any
    input size ends up within the limit.

    Args:
        text_mapping_table: Text mapping table dict {key: value, ...}
        text_digital_list: Text digit string "1-2-3-..."
        max_blocks: Maximum number of blocks to return (default 48, the number
            of text cells in the 8x8 layout). Must be at least 1.

    Returns:
        tuple: (merged_text_blocks, updated_digital_list)
            merged_text_blocks: list of (key, [content, ...]). Merged keys join
                the original numbers with '+' (e.g. "1+50") and list their
                contents in the same order. Merged blocks come first, followed
                by the unmerged ones in their original order.
            updated_digital_list: the block keys joined with '-'. When no merge
                is needed, the input string is returned unchanged.
            A blank ``text_digital_list`` yields ([], text_digital_list).

    Raises:
        KeyError: if a number in text_digital_list is missing from the table.
        ValueError: if a part of text_digital_list is not an integer, or if
            max_blocks < 1.
    """
    if max_blocks < 1:
        raise ValueError(f"max_blocks must be at least 1, got {max_blocks}")
    if not text_digital_list.strip():
        return [], text_digital_list

    # Parse digital_list
    digital_numbers = [int(x) for x in text_digital_list.split('-')]

    if len(digital_numbers) <= max_blocks:
        # Within the limit: one block per number, original list returned unchanged
        merged_blocks = [(str(num), [text_mapping_table[str(num)]]) for num in digital_numbers]
        return merged_blocks, text_digital_list

    print(f"Start merging: total {len(digital_numbers)} text blocks, need to merge to within {max_blocks} blocks")

    # Each group is the list of original numbers that will share one block
    pending = deque([num] for num in digital_numbers)
    merged_groups = []

    # Every merge turns two groups into one, lowering the total block count by
    # exactly one. The total (merged + unmerged) is what must fit; checking only
    # the unmerged count would leave too many blocks.
    while len(pending) + len(merged_groups) > max_blocks:
        if len(pending) < 2:
            # Every group was merged once and we are still over the limit:
            # start another round over the merged groups.
            pending = deque(merged_groups + list(pending))
            merged_groups = []
        group = pending.popleft() + pending.pop()
        merged_groups.append(group)
        description = " + ".join(f"{num}('{text_mapping_table[str(num)]}')" for num in group)
        print(f"Merged: {description}")

    merged_blocks = [
        ('+'.join(str(num) for num in group), [text_mapping_table[str(num)] for num in group])
        for group in merged_groups + list(pending)
    ]

    # Update digital_list
    updated_digital_list = '-'.join(key for key, _ in merged_blocks)

    print(f"Text block merging completed: {len(digital_numbers)} -> {len(merged_blocks)} blocks")
    print(f"Updated digital_list: {updated_digital_list}")

    return merged_blocks, updated_digital_list


def create_tipir_image(image_path: str, word_table: Dict[str, str], text_digital_list: str, dirname: str,
                       output_dir: str = "../imgs") -> Tuple[str, str]:
    """
    Create TIPIR algorithm processed image
    Main external interface function

    Args:
        image_path: Original image path
        word_table: Text mapping table from character-level splitting
        text_digital_list: Text digit list from character-level splitting
        dirname: Dataset name for output filename
        output_dir: Directory for the result (created if missing). The default
            "../imgs" is resolved relative to the current working directory.

    Returns:
        tuple: (output_image_path, image_digital_list), or ("", "") if any
               step failed, including when no output image was written.
    """
    try:
        # 1. Process image: split into 16 blocks with letter labels
        image_blocks, image_digital_list = create_image_16_parts_with_letters(image_path)

        if not image_blocks or not image_digital_list:
            print("Image processing failed")
            return "", ""

        # 2. Generate output path
        os.makedirs(output_dir, exist_ok=True)
        output_image_path = os.path.join(output_dir, f"{dirname}_tipir.png")

        # create_partition_isolation_image reports its own errors and returns
        # nothing. Remove any stale result first, then confirm the file was
        # actually written, so a failed run is never reported as a success.
        if os.path.exists(output_image_path):
            os.remove(output_image_path)

        # 3. Create 64-block partition isolation image
        create_partition_isolation_image(image_path, word_table, text_digital_list, image_blocks, output_image_path)

        if not os.path.isfile(output_image_path):
            print(f"TIPIR image creation failed: no output written to {output_image_path}")
            return "", ""

        print(f"TIPIR image created: {output_image_path}")
        print(f"Image digital list: {image_digital_list}")

        return output_image_path, image_digital_list

    except Exception as e:
        print(f"TIPIR image creation failed: {e}")
        return "", ""


def create_partition_isolation_image(img_path: str, text_mapping_table: Dict[str, str], text_digital_list: str, image_blocks: List[Tuple], output_path: str):
    """
    Create 64-block partition isolation image: 48 text blocks distributed across 4 regions + 16 image blocks in fixed pattern

    The canvas is an 8x8 grid whose size comes from the original image (each
    side rounded down to a multiple of 8). Text blocks go to randomly chosen
    cells of their region from define_partition_regions(). Image blocks go to
    randomly chosen cells from get_image_positions(). Errors are printed, not
    raised.

    Args:
        img_path: Original image path (for size reference)
        text_mapping_table: Text mapping table dict
        text_digital_list: Text digit string
        image_blocks: Image block list [(image_block, letter), ...]
        output_path: Output image path
    """
    try:
        # Only the size of the original is needed; close the file right away
        with Image.open(img_path) as original_image:
            orig_width, orig_height = original_image.size

        # Block size of the 8x8 grid; the canvas is trimmed to 8 * block size
        block_width = orig_width // 8
        block_height = orig_height // 8
        if block_width == 0 or block_height == 0:
            print(f"Create partition isolation image failed: image {orig_width}x{orig_height} is smaller than 8x8 pixels")
            return
        new_width = block_width * 8
        new_height = block_height * 8

        result_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))

        # Base text font; individual crowded blocks may shrink their own copy
        base_font_size = min(block_width, block_height) // 6
        base_font = get_font(base_font_size)

        regions = define_partition_regions()
        image_positions = get_image_positions()

        # Handle cases with more than 48 blocks
        text_blocks, updated_digital_list = merge_overflow_text_items(text_mapping_table, text_digital_list)

        print(f"Final text block count: {len(text_blocks)}")
        print(f"Updated digital_list: {updated_digital_list}")

        region_assignments = _assign_text_blocks_to_regions(text_blocks, updated_digital_list)

        print("Region allocation results:")
        for region_name, blocks in region_assignments.items():
            print(f"  Region {region_name}: {len(blocks)} text blocks")

        # Randomly allocate positions within each region
        for region_name, blocks in region_assignments.items():
            available_positions = regions[region_name].copy()
            random.shuffle(available_positions)

            if len(blocks) > len(available_positions):
                dropped = len(blocks) - len(available_positions)
                print(f"Warning: region {region_name} has {len(blocks)} text blocks but only "
                      f"{len(available_positions)} cells; {dropped} block(s) not drawn")

            # zip stops at the shorter list, so overflowing blocks are skipped
            for (row, col), (block_key, block_content) in zip(available_positions, blocks):
                block_img = _render_text_block(block_key, block_content, block_width, block_height,
                                               base_font_size, base_font)
                result_image.paste(block_img, (col * block_width, row * block_height))

        _paste_image_blocks(result_image, image_blocks, image_positions, block_width, block_height)

        # Save result
        result_image.save(output_path)
        print(f"Partition isolation image saved: {output_path}")

    except Exception as e:
        print(f"Create partition isolation image failed: {e}")


def _assign_text_blocks_to_regions(text_blocks: List[Tuple], digital_list: str,
                                   region_names: Tuple[str, ...] = ('A', 'B', 'C', 'D')) -> Dict[str, List[Tuple]]:
    """Distribute text blocks over the regions round-robin, following digital_list order."""
    # Key -> content map; the first occurrence wins, like a front-to-back search
    blocks_by_key = {}
    for block_key, block_content in text_blocks:
        blocks_by_key.setdefault(block_key, block_content)

    region_count = len(region_names)

    # First pass: the i-th part of digital_list goes to region i mod N. Parts
    # with no matching block are skipped but still advance the rotation.
    first_pass = {name: [] for name in region_names}
    for i, digital_part in enumerate(digital_list.split('-')):
        if digital_part in blocks_by_key:
            first_pass[region_names[i % region_count]].append((digital_part, blocks_by_key[digital_part]))

    # Target sizes: sizes differ by at most one, extras go to the earlier regions
    per_region, extra_blocks = divmod(len(text_blocks), region_count)
    region_targets = {name: per_region + (1 if i < extra_blocks else 0)
                      for i, name in enumerate(region_names)}
    print(f"Even allocation targets: {region_targets}")

    # Second pass: concatenate the regions in order and deal round-robin again.
    # Region sizes end up matching region_targets whenever every part was found.
    concatenated = [item for name in region_names for item in first_pass[name]]
    assignments = {name: [] for name in region_names}
    for i, item in enumerate(concatenated):
        assignments[region_names[i % region_count]].append(item)
    return assignments


def _format_text_block_lines(block_key: str, block_content) -> List[str]:
    """Build the display lines "<number>:'<text>'" for one (possibly merged) text block."""
    if not isinstance(block_content, list):
        return [f"{block_key}:'{block_content}'"]
    if '+' in block_key:
        # Merged block: pair each original number with its own fragment
        return [f"{number}:'{content}'" for number, content in zip(block_key.split('+'), block_content)]
    return [f"{block_key}:'{content}'" for content in block_content]


def _render_text_block(block_key: str, block_content, block_width: int, block_height: int,
                       base_font_size: int, base_font) -> Image.Image:
    """Render one text cell as a white block_width x block_height RGB image with centered lines."""
    block_img = Image.new('RGB', (block_width, block_height), (255, 255, 255))
    draw = ImageDraw.Draw(block_img)
    lines = _format_text_block_lines(block_key, block_content)

    # Shrink the font (not below 8) only for this block when its lines do not fit
    # vertically. Local copies keep later blocks at the base size.
    font_size = base_font_size
    font = base_font
    available_height = block_height - 4
    line_height = font_size + 2
    while len(lines) * line_height > available_height and font_size > 8:
        font_size -= 1
        line_height = font_size + 1
        font = get_font(font_size)

    total_height = len(lines) * line_height
    start_y = max(2, (block_height - total_height) // 2)
    # Rough per-character width estimate, used only to decide on truncation
    char_width = font_size // 2

    for j, line in enumerate(lines):
        y = start_y + j * line_height
        # Lines that would spill past the bottom edge are not drawn
        if y + line_height > block_height:
            continue

        # If text too long, shorten display
        if len(line) * char_width > block_width * 0.9:
            max_chars = int(block_width * 0.9 / char_width) - 3
            if max_chars > 0:
                line = line[:max_chars] + "..."

        try:
            bbox = draw.textbbox((0, 0), line, font=font)
            x = max(2, (block_width - (bbox[2] - bbox[0])) // 2)
        except Exception:
            # Measuring is best-effort; fall back to the left margin
            x = 2

        draw.text((x, y), line, fill=(0, 0, 0), font=font)

    return block_img


def _paste_image_blocks(canvas: Image.Image, image_blocks: List[Tuple], image_positions: List[Tuple[int, int]],
                        block_width: int, block_height: int) -> None:
    """Paste image blocks, resized to one grid cell, at randomly ordered fixed image positions."""
    shuffled_positions = list(image_positions)
    random.shuffle(shuffled_positions)
    # zip stops at the shorter list: surplus blocks or positions are ignored
    for (row, col), (image_block, _letter) in zip(shuffled_positions, image_blocks):
        resized_block = image_block.resize((block_width, block_height), Image.Resampling.LANCZOS)
        canvas.paste(resized_block, (col * block_width, row * block_height))


if __name__ == "__main__":
    # Manual smoke run of create_tipir_image on a toy word table.
    # Point the first argument at a real image file before running.
    sample_args = (
        "test_image.jpg",  # Replace with actual image path
        {"1": "te", "2": "st", "3": " t", "4": "ex", "5": "t"},
        "1-2-3-4-5",
        "test",
    )

    result_path, letter_sequence = create_tipir_image(*sample_args)
    print(f"Output: {result_path}, Image digits: {letter_sequence}")