import argparse
import json
import sys
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageDraw

# Disable Pillow's decompression-bomb size check (None = no pixel limit) so
# that very large slide images such as data/histology.tif can be opened.
Image.MAX_IMAGE_PIXELS = None

VISUALIZATION_MAX_WIDTH = 2048 # Target max width (px) of the output visualization; used to derive the integer downscale factor.
# RGBA fill for foreground patches.
# NOTE(review): not referenced by the code visible here; its values mirror the
# CLI defaults (--overlay_color 65,105,225 and --overlay_alpha 140). Confirm
# whether anything still imports it before removing.
FOREGROUND_COLOR = (65, 105, 225, 140)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate visualization of detected foreground patches."
    )
    parser.add_argument("--sample_dir", required=True, type=str, help="Path to the sample directory.")
    parser.add_argument(
        "--output_name", 
        type=str, 
        default="foreground_segmentation.jpg",
        help="Name of the output visualization file."
    )
    parser.add_argument(
        "--overlay_alpha",
        type=int,
        default=140,
        help="Alpha transparency for foreground overlay (0-255)."
    )
    parser.add_argument(
        "--overlay_color",
        type=str,
        default="65,105,225",
        help="RGB color for foreground overlay (comma-separated, e.g., '255,0,0' for red)."
    )
    return parser.parse_args()

def create_foreground_visualization(sample_dir: Path, output_name: str, overlay_alpha: int, overlay_color: tuple):
    """Render the detected foreground patches over a downsampled slide image.

    Reads ``<sample_dir>/data/histology.tif`` and
    ``<sample_dir>/foreground_patches.json``. It then shrinks the image by an
    integer factor so its width does not exceed ``VISUALIZATION_MAX_WIDTH``.
    Every patch rectangle is filled on a transparent layer, the layer is
    alpha-composited onto the thumbnail, and the result is saved as
    ``<sample_dir>/<output_name>``.

    Args:
        sample_dir: Sample directory; its name is used as the sample id in log output.
        output_name: File name (inside ``sample_dir``) of the saved visualization.
        overlay_alpha: Alpha (0-255) of the patch fill.
        overlay_color: ``(R, G, B)`` fill color.

    Errors are reported on stdout rather than raised. A missing input file
    aborts early. Any exception while loading or rendering is printed with its
    traceback, and the function then returns normally.
    """
    sample_id = sample_dir.name
    print(f"\n{'='*40}\nCreating foreground visualization for sample: {sample_id}")

    data_dir = sample_dir / "data"
    image_path = data_dir / "histology.tif"
    foreground_json_path = sample_dir / "foreground_patches.json"

    if not image_path.exists():
        print(f"  Error: Histology image not found at '{image_path}'. Aborting.")
        return

    if not foreground_json_path.exists():
        print(f"  Error: Foreground patches JSON not found at '{foreground_json_path}'.")
        print("  Run detect_foreground.py first to generate the foreground data.")
        return

    try:
        # The JSON must be a list of entries, each holding a "coordinates"
        # mapping with left/top/right/bottom keys (see the loop below).
        with open(foreground_json_path, 'r', encoding='utf-8') as f:
            foreground_metadata = json.load(f)

        if not foreground_metadata:
            print("  Warning: No foreground patches found in the JSON file.")
            return

        print(f"  Loaded {len(foreground_metadata)} foreground patches from JSON.")

        # Only the downsampled copy is needed. The context manager closes the
        # slide file as soon as the thumbnail exists, even if resizing fails.
        with Image.open(image_path) as wsi:
            img_w, img_h = wsi.size
            print(f"  - Image: {img_w}x{img_h}")

            print("Creating segmentation visualization...")
            # Ceiling division guarantees a thumbnail width <= VISUALIZATION_MAX_WIDTH.
            # Floor division would let, e.g., a 4000 px wide image keep downscale=1.
            downscale = max(1, (img_w + VISUALIZATION_MAX_WIDTH - 1) // VISUALIZATION_MAX_WIDTH)
            # Clamp each side to >= 1 px so a very elongated slide cannot
            # produce a zero-sized thumbnail.
            thumb_size = (max(1, img_w // downscale), max(1, img_h // downscale))
            base_thumb = wsi.resize(thumb_size, Image.Resampling.LANCZOS).convert("RGBA")

        # Patches are painted opaquely onto a separate transparent layer and
        # blended once. Overlapping rectangles therefore do not accumulate
        # alpha, and overlay_alpha alone sets how much tissue shows through.
        overlay = Image.new("RGBA", base_thumb.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)
        overlay_rgba = (*overlay_color, overlay_alpha)

        for meta in foreground_metadata:
            coords = meta["coordinates"]
            # Map patch coordinates (presumably full-resolution pixels) into
            # thumbnail space using the same factor as the resize above.
            thumb_coords = [
                coords['left'] // downscale,
                coords['top'] // downscale,
                coords['right'] // downscale,
                coords['bottom'] // downscale,
            ]
            draw.rectangle(thumb_coords, fill=overlay_rgba)

        composite_img = Image.alpha_composite(base_thumb, overlay).convert("RGB")
        visualization_path = sample_dir / output_name
        composite_img.save(visualization_path, quality=95)
        print(f"Saved segmentation visualization to: {visualization_path}")

    except Exception as e:
        # Top-level boundary for one sample: report the failure with its
        # traceback instead of propagating it to the caller.
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
    finally:
        print(f"\nFinished creating visualization for {sample_id}.")
        print("="*40)

if __name__ == "__main__":
    args = parse_args()
    sample_directory = Path(args.sample_dir)
    
    if not sample_directory.is_dir():
        print(f"ERROR: Directory not found: '{sample_directory}'")
        exit(1)
    
    # Parse overlay color
    try:
        overlay_color = tuple(map(int, args.overlay_color.split(',')))
        if len(overlay_color) != 3 or any(c < 0 or c > 255 for c in overlay_color):
            raise ValueError("Invalid color format")
    except ValueError:
        print(f"ERROR: Invalid overlay color '{args.overlay_color}'. Use format 'R,G,B' with values 0-255.")
        exit(1)
    
    create_foreground_visualization(
        sample_directory, 
        args.output_name, 
        args.overlay_alpha, 
        overlay_color
    )