import os
import cv2
import numpy as np
import argparse
from collections import defaultdict
# Find the 3 largest masks of each category in the 'masks' folder and draw their contours on the matching source images, saving the results to a 'vision' folder.
def create_visualizations(data_path, *, top_k=3,
                          image_extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
    """
    Draw the contours of each category's largest masks onto their source images.

    Mask files in ``<data_path>/masks`` must be named ``<category>_<image name>.png``
    (the category is the text before the first underscore). For every category,
    the ``top_k`` masks with the largest foreground area (number of non-zero
    pixels) are selected. Their outer contours are drawn in green on the matching
    image from ``<data_path>/images``, and the result is written to
    ``<data_path>/vision/<category>_top<rank>.jpg``.

    Args:
        data_path (str): Root directory containing 'images' and 'masks' subfolders.
        top_k (int): Number of largest masks to visualize per category (default 3).
        image_extensions (tuple[str, ...]): Extensions tried, in order, when looking
            for a mask's source image. Upper-case variants are tried as well.
            '.jpg' comes first, so the original '.jpg' lookup wins whenever
            such a file exists.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    # 1. Define paths for each folder
    images_dir = os.path.join(data_path, 'images')
    masks_dir = os.path.join(data_path, 'masks')
    vision_dir = os.path.join(data_path, 'vision')

    if not os.path.isdir(images_dir) or not os.path.isdir(masks_dir):
        print(f"Error: 'images' and 'masks' folders not found in '{data_path}'.")
        return

    # 2. Create (or reuse) the output folder
    os.makedirs(vision_dir, exist_ok=True)
    print(f"Output folder: '{vision_dir}'")

    # 3. Group mask files by category
    masks_by_category = _group_mask_files(os.listdir(masks_dir))
    if not masks_by_category:
        print("Error: No .png mask files found in 'masks' folder.")
        return

    print(f"Found {len(masks_by_category)} categories.")

    # 4. Process each category one by one
    for category, mask_files in masks_by_category.items():
        print(f"--- Processing category '{category}' ---")

        top_masks = _largest_masks(masks_dir, mask_files, top_k)
        if not top_masks:
            print(f"No valid mask files found in category '{category}'.")
            continue

        # Report the real count: a category may have fewer than top_k readable masks.
        print(f"  - Found top {len(top_masks)} masks:")
        for rank, (mask_file, area) in enumerate(top_masks, 1):
            print(f"    {rank}. '{mask_file}' (area: {area} pixels)")

        # 5. Generate and save one visualization per selected mask
        for rank, (mask_file, _area) in enumerate(top_masks, 1):
            output_path = os.path.join(vision_dir, f"{category}_top{rank}.jpg")
            try:
                saved = _save_visualization(images_dir, masks_dir, mask_file,
                                            output_path, image_extensions)
            except Exception as e:  # best effort: one bad file must not abort the whole run
                print(f"Unknown error occurred while processing mask {rank} for category '{category}': {e}")
                continue
            if saved:
                print(f"  - Visualization result {rank} saved to: '{output_path}'")

    print("\nAll processing completed!")


def _split_mask_name(filename):
    """Return ``(category, image_base_name)`` for a ``<category>_<name>.png`` file, else None."""
    stem, ext = os.path.splitext(filename)
    if ext.lower() != '.png':
        return None
    # Split only at the first underscore so image names may themselves contain underscores.
    category, sep, base_name = stem.partition('_')
    if not sep or not category or not base_name:
        return None
    return category, base_name


def _group_mask_files(filenames):
    """Group '.png' mask filenames by category, warning about badly named ones."""
    masks_by_category = defaultdict(list)
    # Sorting makes category order and tie-breaking between equal areas deterministic.
    for filename in sorted(filenames):
        if not filename.lower().endswith('.png'):
            continue
        parsed = _split_mask_name(filename)
        if parsed is None:
            print(f"Warning: Skipping file with non-standard naming: {filename}")
            continue
        masks_by_category[parsed[0]].append(filename)
    return masks_by_category


def _largest_masks(masks_dir, mask_files, top_k):
    """Return up to ``top_k`` ``(mask_file, area)`` pairs, largest area first."""
    mask_areas = []
    for mask_file in mask_files:
        mask = cv2.imread(os.path.join(masks_dir, mask_file), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print(f"Warning: Cannot read mask file {mask_file}, skipped.")
            continue
        # The mask area is its number of non-zero (foreground) pixels.
        mask_areas.append((mask_file, cv2.countNonZero(mask)))
    # list.sort is stable even with reverse=True, so equal areas keep filename order.
    mask_areas.sort(key=lambda item: item[1], reverse=True)
    return mask_areas[:top_k]


def _find_source_image(images_dir, base_name, extensions):
    """Return the first existing ``<base_name><ext>`` path in images_dir, or None."""
    for ext in extensions:
        for variant in (ext, ext.upper()):
            candidate = os.path.join(images_dir, base_name + variant)
            if os.path.isfile(candidate):
                return candidate
    return None


def _save_visualization(images_dir, masks_dir, mask_file, output_path, image_extensions):
    """Draw one mask's outer contours on its source image and write the result.

    Prints an error and returns False if the image or mask cannot be found,
    read, or written. Returns True on success.
    """
    base_name = _split_mask_name(mask_file)[1]
    image_path = _find_source_image(images_dir, base_name, image_extensions)
    if image_path is None:
        print(f"  - Error: Source image '{base_name}' ({', '.join(image_extensions)}) "
              f"for mask '{mask_file}' not found.")
        return False

    mask_path = os.path.join(masks_dir, mask_file)
    image = cv2.imread(image_path)
    if image is None:
        print(f"  - Error: Cannot read image file '{image_path}'.")
        return False
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"  - Error: Cannot read mask file '{mask_path}'.")
        return False

    if mask.shape[:2] != image.shape[:2]:
        # Contour coordinates are in mask pixels; rescale the mask so they line up with the image.
        print(f"  - Warning: Mask '{mask_file}' size {mask.shape[1]}x{mask.shape[0]} differs from "
              f"image size {image.shape[1]}x{image.shape[0]}; resizing mask.")
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    # RETR_EXTERNAL keeps only the outermost contours; CHAIN_APPROX_SIMPLE keeps only segment
    # endpoints. Indexing with [-2] works with both the OpenCV 3.x (image, contours, hierarchy)
    # and OpenCV 4.x (contours, hierarchy) return signatures.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(image, contours, -1, (0, 255, 0), 2)  # green (BGR), 2-pixel line

    # cv2.imwrite can report failure through its return value instead of raising.
    if not cv2.imwrite(output_path, image):
        print(f"  - Error: Failed to write visualization to '{output_path}'.")
        return False
    return True


def main():
    """
    Entry point: read the dataset root from the command line and run the visualizer.
    """
    cli = argparse.ArgumentParser(
        description="Create contour visualization images based on image and mask data.",
        # Keep the help text laid out exactly as written.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument(
        'data_path',
        type=str,
        help="Dataset root directory path containing 'images' and 'masks' subfolders.",
    )
    parsed = cli.parse_args()
    create_visualizations(parsed.data_path)

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()