import argparse
import os
import yaml
import numpy as np
import torch
import rasterio
from tqdm import tqdm
import glob
from PIL import Image
import matplotlib.pyplot as plt

from model.unet_plus_plus import UnetPlusPlus

def get_device(gpus_str):
    """
    Choose the torch device used for inference.

    Args:
        gpus_str (str): Comma-separated GPU indices such as "0" or "0,1".
            Every entry must parse as an int, but only the first one is used.
            An empty string or None falls back to the CPU.

    Returns:
        torch.device: ``cuda:<first index>`` when CUDA is available and GPUs
        were requested, otherwise ``cpu``.
    """
    if not (torch.cuda.is_available() and gpus_str):
        cpu_device = torch.device("cpu")
        print("CUDA not available or no GPUs specified. Using CPU.")
        return cpu_device

    requested = list(map(int, gpus_str.split(',')))
    primary = requested[0]
    cuda_device = torch.device(f"cuda:{primary}")
    print(f"Using GPU: {primary}")
    return cuda_device

def calculate_iou(pred_mask, gt_mask):
    """
    Calculate Intersection over Union (IoU) between predicted and ground truth masks.

    Both inputs are binarized with ``> 0.5`` before comparison, so they may be
    probability maps, {0, 1} masks or booleans.

    Args:
        pred_mask (array-like): Predicted mask.
        gt_mask (array-like): Ground truth mask (same or broadcast-compatible shape).

    Returns:
        float: IoU score in [0, 1]. When both masks are empty the masks agree
        perfectly and 1.0 is returned.
    """
    pred = np.asarray(pred_mask) > 0.5
    gt = np.asarray(gt_mask) > 0.5

    intersection = np.count_nonzero(pred & gt)
    union = np.count_nonzero(pred | gt)

    # union == 0 implies intersection == 0: both masks are empty.
    if union == 0:
        return 1.0

    return float(intersection / union)

def create_visualization(input_image, gt_mask, pred_mask, iou_score=None):
    """
    Render input, ground truth and prediction side by side as one RGB image.

    Args:
        input_image (np.ndarray): Image as (H, W, C), (C, H, W) or (H, W).
            A 3-D array whose first axis is shorter than its last is treated
            as channels-first. Values are multiplied by 3 and clipped to
            [0, 1] for display, so inputs roughly in [0, 1] are expected.
        gt_mask (np.ndarray or None): Ground truth mask (H, W). When None, an
            empty panel titled "Ground Truth (Not Available)" is drawn.
        pred_mask (np.ndarray): Predicted mask (H, W).
        iou_score (float, optional): Shown in the prediction title and as a
            yellow text overlay when given.

    Returns:
        np.ndarray: uint8 array of shape (fig_height_px, fig_width_px, 3).
    """
    input_image = np.asarray(input_image)
    if input_image.ndim == 2:
        # Single-band image without a channel axis.
        input_image = input_image[:, :, np.newaxis]
    elif input_image.ndim == 3 and input_image.shape[0] < input_image.shape[-1]:
        # Heuristic: first axis shorter than last -> (C, H, W); convert to (H, W, C).
        input_image = np.transpose(input_image, (1, 2, 0))

    if input_image.shape[-1] >= 3:
        # The first three bands are reversed into display order, i.e. they are
        # assumed to be stored as B, G, R. NOTE(review): confirm band order.
        input_rgb = input_image[:, :, [2, 1, 0]]
    else:
        # 1 or 2 bands: show the first band as grayscale instead of failing on
        # the [2, 1, 0] band lookup.
        input_rgb = np.repeat(input_image[:, :, :1], 3, axis=2)
    # Fixed x3 brightness stretch, then clip to the [0, 1] range imshow expects.
    input_rgb = np.clip(input_rgb * 3.0, 0, 1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        title_kw = dict(fontsize=14, fontweight='bold')

        # Panel 1: input image
        axes[0].imshow(input_rgb)
        axes[0].set_title('Input Image', **title_kw)
        axes[0].axis('off')

        # Panel 2: ground truth, or a black placeholder when unavailable
        if gt_mask is not None:
            axes[1].imshow(gt_mask, cmap='gray', vmin=0, vmax=1)
            axes[1].set_title('Ground Truth', **title_kw)
        else:
            axes[1].imshow(np.zeros_like(pred_mask), cmap='gray', vmin=0, vmax=1)
            axes[1].set_title('Ground Truth\n(Not Available)', **title_kw)
        axes[1].axis('off')

        # Panel 3: prediction, annotated with the IoU when known
        axes[2].imshow(pred_mask, cmap='gray', vmin=0, vmax=1)
        if iou_score is not None:
            axes[2].set_title(f'Prediction\nIoU: {iou_score:.4f}', **title_kw)
            axes[2].text(0.02, 0.98, f'IoU: {iou_score:.4f}',
                         transform=axes[2].transAxes,
                         fontsize=16, fontweight='bold',
                         verticalalignment='top',
                         bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.8))
        else:
            axes[2].set_title('Prediction', **title_kw)
        axes[2].axis('off')

        fig.tight_layout()

        # Rasterize the figure. buffer_rgba() replaces tostring_rgb(), which was
        # deprecated in Matplotlib 3.8 and removed in 3.10. It is already shaped
        # (H, W, 4) in physical pixels, so no reshape from get_width_height()
        # (logical pixels, wrong on HiDPI canvases) is needed. Copy before the
        # figure is closed and its buffer released.
        fig.canvas.draw()
        vis_array = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)

    return vis_array

def find_image_files(folder_path):
    """
    Find all image files directly inside a folder (non-recursive).

    Extensions are matched case-insensitively (``.tif``, ``.TIF``, ``.Tif``, ...).
    Each file is returned exactly once, even on case-insensitive file systems,
    where globbing ``*.tif`` and ``*.TIF`` separately would list it twice.
    Hidden files (names starting with '.') and directories are skipped.

    Args:
        folder_path (str): Path to the folder

    Returns:
        list: Sorted list of image file paths (``os.path.join(folder_path, name)``).
        Empty if the folder cannot be listed.
    """
    image_extensions = {'.tif', '.tiff', '.png', '.jpg', '.jpeg'}

    try:
        entries = os.listdir(folder_path)
    except OSError:
        # Missing / unreadable folder: behave like an empty glob result.
        return []

    image_files = []
    for name in entries:
        if name.startswith('.'):
            continue
        if os.path.splitext(name)[1].lower() not in image_extensions:
            continue
        path = os.path.join(folder_path, name)
        if os.path.isfile(path):
            image_files.append(path)

    return sorted(image_files)

def find_ground_truth_mask(image_path, labels_folder):
    """
    Look up the ground truth mask file that belongs to an image.

    The image's base name (without extension) is combined with a fixed,
    ordered list of naming conventions (``<name>_mask``, ``<name>_label``,
    ``<name>_gt``, ``<name>``, ``mask_<name>``, ``label_<name>``,
    ``gt_<name>``). For each convention the extensions .png, .tif, .tiff,
    .jpg and .jpeg are tried in that order; the first existing file wins.

    Args:
        image_path (str): Path to the input image
        labels_folder (str): Directory that holds the masks

    Returns:
        str or None: Path to the mask, or None if the folder is missing or no
        candidate exists (diagnostics are printed in both cases).
    """
    if not os.path.exists(labels_folder):
        print(f"    Labels folder not found: {labels_folder}")
        return None

    stem = os.path.splitext(os.path.basename(image_path))[0]

    name_templates = (
        "{}_mask", "{}_label", "{}_gt", "{}",
        "mask_{}", "label_{}", "gt_{}",
    )
    suffixes = ('.png', '.tif', '.tiff', '.jpg', '.jpeg')

    # Lazily generated so the search stops at the first hit.
    candidates = (
        os.path.join(labels_folder, template.format(stem) + suffix)
        for template in name_templates
        for suffix in suffixes
    )
    match = next((path for path in candidates if os.path.exists(path)), None)
    if match is not None:
        return match

    # Nothing matched: show a few folder entries to help debug naming.
    listing = os.listdir(labels_folder)
    print(f"    No matching GT mask found for {stem}")
    print(f"    Available files in labels folder: {listing[:5]}...")  # first 5 only
    return None

def predict_large_image(model, image_path, patch_size, overlap, device):
    """
    Performs prediction on a large image using a sliding window approach.

    The raster is read with rasterio and scaled by 1/10000. It is then
    zero-padded on the bottom/right so that a whole number of windows covers
    it, and passed to the model window by window. Sigmoid probabilities from
    overlapping windows are averaged, cropped back to the original size and
    thresholded at 0.5.

    Args:
        model (torch.nn.Module): The trained model. Assumed to map a
            (1, C, patch_size, patch_size) tensor to single-channel logits of
            shape (1, 1, patch_size, patch_size). TODO confirm for classes > 1.
        image_path (str): Path to the input raster (e.g. a TIFF).
        patch_size (int): Side length of the square windows fed to the model.
        overlap (float): Fraction of overlap between adjacent windows. The
            derived stride is clamped to [1, patch_size], so every pixel is
            covered by at least one window.
        device (torch.device): The device to run inference on.

    Returns:
        tuple: A tuple containing:
            - np.ndarray: The final binary prediction mask, uint8, (height, width).
            - dict: The rasterio profile of the original image for saving.
            - np.ndarray: The original (unscaled) image data, (bands, height, width).
    """
    with rasterio.open(image_path) as src:
        original_image = src.read()  # (bands, height, width)
        profile = src.profile

    _, height, width = original_image.shape

    # astype() always returns a new array, so original_image stays untouched
    # for visualization. NOTE(review): the /10000 scaling presumably mirrors
    # the training-time normalization; confirm against the dataset code.
    model_input = original_image.astype(np.float32) / 10000.0

    # overlap >= 1 would give stride <= 0. A negative overlap would give
    # stride > patch_size and leave uncovered gaps that come out as background.
    stride = min(max(int(patch_size * (1 - overlap)), 1), patch_size)

    def _pad_amount(size):
        # Padding that makes (size + pad - patch_size) a non-negative multiple
        # of stride, so the last window ends exactly on the padded edge.
        # Images smaller than one patch are padded up to exactly patch_size.
        # Without this, a small image could end up smaller than one window,
        # no window would run, and the mask would silently be all zeros.
        if size <= patch_size:
            return patch_size - size
        return -(size - patch_size) % stride

    padded_image = np.pad(
        model_input,
        ((0, 0), (0, _pad_amount(height)), (0, _pad_amount(width))),
        mode='constant',
    )
    padded_height, padded_width = padded_image.shape[1:]

    # Running sum of probabilities and number of windows covering each pixel
    prediction_map = torch.zeros((1, 1, padded_height, padded_width), device=device, dtype=torch.float32)
    count_map = torch.zeros((1, 1, padded_height, padded_width), device=device, dtype=torch.float32)

    model.eval()
    with torch.no_grad():
        for y in tqdm(range(0, padded_height - patch_size + 1, stride), desc="Predicting Patches"):
            for x in range(0, padded_width - patch_size + 1, stride):
                patch = padded_image[:, y:y + patch_size, x:x + patch_size]
                patch_tensor = torch.from_numpy(patch).float().to(device).unsqueeze(0)  # (1, C, H, W)

                prediction = torch.sigmoid(model(patch_tensor))

                prediction_map[:, :, y:y + patch_size, x:x + patch_size] += prediction
                count_map[:, :, y:y + patch_size, x:x + patch_size] += 1

    # Average overlapping windows. Epsilon guards against division by zero.
    averaged_prediction = prediction_map / (count_map + 1e-6)

    # Crop the padding away. Index [0, 0] instead of squeeze() so that a
    # 1-pixel-high or -wide image still yields a 2-D (height, width) mask.
    final_prediction = averaged_prediction[0, 0, :height, :width]
    binary_mask = (final_prediction > 0.5).cpu().numpy().astype(np.uint8)

    return binary_mask, profile, original_image

def _load_binary_gt_mask(gt_path, target_shape):
    """Load a ground truth mask as a {0, 1} uint8 array resized to ``target_shape``.

    Exceptions from rasterio / PIL for unreadable files propagate to the caller.
    """
    if gt_path.lower().endswith(('.tif', '.tiff')):
        with rasterio.open(gt_path) as src:
            gt_mask = src.read(1)  # first band only
    else:
        # Image.open is lazy and keeps the file handle open until the image is
        # closed. The context manager releases it, so processing many images
        # in folder mode does not accumulate open files.
        with Image.open(gt_path) as gt_image:
            gt_mask = np.array(gt_image)

    if gt_mask.ndim == 3:
        gt_mask = gt_mask[:, :, 0]  # multi-channel mask: keep the first channel

    # Masks whose maximum exceeds 1 are treated as 0-255 and thresholded at
    # 127; otherwise they are treated as 0-1 and thresholded at 0.5.
    threshold = 127 if gt_mask.max() > 1 else 0.5
    gt_mask = (gt_mask > threshold).astype(np.uint8)

    if gt_mask.shape != tuple(target_shape):
        # Nearest-neighbour keeps the mask binary when resizing.
        resized = Image.fromarray(gt_mask * 255).resize(
            (target_shape[1], target_shape[0]),
            Image.NEAREST
        )
        gt_mask = (np.array(resized) > 127).astype(np.uint8)

    return gt_mask


def process_single_image(model, image_path, labels_folder, results_folder, patch_size, overlap, device):
    """
    Predict one image, optionally score it against ground truth, and save a visualization.

    The visualization PNG is written to ``results_folder`` as
    ``<name>_iou_<score>.png`` when an IoU was computed, otherwise as
    ``<name>_prediction.png``.

    Args:
        model: The trained model
        image_path (str): Path to the input image
        labels_folder (str): Path to the labels folder (can be None)
        results_folder (str): Path to save results
        patch_size (int): Size of patches for prediction
        overlap (float): Overlap ratio
        device: Device for inference

    Returns:
        float or None: IoU score if a ground truth mask was found and loaded
    """
    prediction_mask, _profile, original_image = predict_large_image(
        model=model,
        image_path=image_path,
        patch_size=patch_size,
        overlap=overlap,
        device=device
    )

    gt_mask = None
    iou_score = None

    if labels_folder:
        gt_path = find_ground_truth_mask(image_path, labels_folder)
        if gt_path:
            try:
                gt_mask = _load_binary_gt_mask(gt_path, prediction_mask.shape)
                iou_score = calculate_iou(prediction_mask, gt_mask)
            except Exception as e:
                # Best effort: a broken GT file must not stop the prediction
                # from being saved, so report it and continue without IoU.
                print(f"    Error loading GT mask: {e}")
                gt_mask = None
                iou_score = None

    base_name = os.path.splitext(os.path.basename(image_path))[0]
    if iou_score is not None:
        output_name = f"{base_name}_iou_{iou_score:.4f}.png"
    else:
        output_name = f"{base_name}_prediction.png"

    # Same 1/10000 scaling as the model input, clipped to [0, 1] for display.
    vis_image = np.clip(original_image.astype(np.float32) / 10000.0, 0, 1)

    visualization = create_visualization(vis_image, gt_mask, prediction_mask, iou_score)

    vis_output_path = os.path.join(results_folder, output_name)
    Image.fromarray(visualization).save(vis_output_path)

    return iou_score

def _candidate_labels_folders(images_dir):
    """Return the label-folder locations searched for images in ``images_dir``, in priority order."""
    images_dir = os.path.normpath(images_dir)
    parent_dir = os.path.dirname(images_dir)
    return [
        os.path.join(parent_dir, 'labels'),                   # data/images -> data/labels
        os.path.join(os.path.dirname(parent_dir), 'labels'),  # one level further up
        images_dir.replace('images', 'labels'),               # data/images/test -> data/labels/test
    ]


def _find_labels_folder(images_dir, extra_candidates=()):
    """Return the first existing labels directory for ``images_dir``, or None.

    ``extra_candidates`` are tried before the default locations. A candidate
    that resolves to the image directory itself is skipped. This happens when
    the path contains no 'images' component, so replacing 'images' with
    'labels' is a no-op. Without the check, every image would be matched
    against itself as "ground truth", producing a meaningless IoU.
    """
    images_dir_abs = os.path.abspath(images_dir)
    for candidate in [*extra_candidates, *_candidate_labels_folders(images_dir)]:
        if os.path.abspath(candidate) == images_dir_abs:
            continue
        if os.path.isdir(candidate):
            return candidate
    return None


def main():
    """Command-line entry point.

    Loads the model described by ``<log_dir>/config.yaml`` with weights from
    ``<log_dir>/weights/last.pt`` and predicts a single image or every image
    in a folder. Visualizations are written to ``<log_dir>/results``. IoU is
    reported when a labels folder with matching masks is found.
    """
    parser = argparse.ArgumentParser(description="Mangrove Segmentation Prediction")
    parser.add_argument(
        '--log_dir', type=str, required=True,
        help='Path to the directory containing trained model logs and config.'
    )
    parser.add_argument(
        '--input_path', type=str, required=True,
        help='Path to the input image file or folder containing images for prediction.'
    )
    parser.add_argument(
        '--overlap', type=float, default=0.25,
        help='Overlap ratio between patches (0 to 1). Default: 0.25'
    )
    parser.add_argument(
        '--gpus', type=str, default='0',
        help='GPU IDs to use for inference, comma-separated (e.g., "0"). Uses CPU if empty or not available.'
    )
    parser.add_argument(
        '--patch_size', type=int, default=256,
        help='Side length of the square patches fed to the model. Default: 256'
    )
    args = parser.parse_args()

    # --- Load Config ---
    config_path = os.path.join(args.log_dir, 'config.yaml')
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")

    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # --- Setup Device ---
    device = get_device(args.gpus)

    # --- Load Model ---
    model_args = config['model']['args']
    model = UnetPlusPlus(
        in_channels=model_args['in_channels'],
        classes=model_args['classes'],
        encoder_name=model_args['encoder_name']
    )

    checkpoint_path = os.path.join(args.log_dir, 'weights', 'last.pt')
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # trusted log directories.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Training checkpoints store the weights under 'model'. Also accept a bare
    # state_dict.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    print("Model loaded successfully.")

    # --- Setup Results Directory ---
    results_folder = os.path.join(args.log_dir, 'results')
    os.makedirs(results_folder, exist_ok=True)

    # --- Determine Input Type and Process ---
    patch_size = args.patch_size

    if os.path.isfile(args.input_path):
        # Look for <image dir>/labels first, then the same locations folder
        # mode uses, so both modes find e.g. data/labels for data/images/x.tif.
        image_dir = os.path.dirname(args.input_path)
        labels_folder = _find_labels_folder(
            image_dir, extra_candidates=[os.path.join(image_dir, 'labels')]
        )
        if labels_folder is not None:
            print(f"Found labels folder: {labels_folder}")

        iou_score = process_single_image(
            model=model,
            image_path=args.input_path,
            labels_folder=labels_folder,
            results_folder=results_folder,
            patch_size=patch_size,
            overlap=args.overlap,
            device=device
        )

        if iou_score is not None:
            print(f"IoU: {iou_score:.4f}")
        print("Processing complete.")
        print(f"Results saved to: {results_folder}")

    elif os.path.isdir(args.input_path):
        print(f"Processing folder: {args.input_path}")

        image_files = find_image_files(args.input_path)
        if not image_files:
            print(f"No image files found in {args.input_path}")
            return

        print(f"Found {len(image_files)} image files")

        labels_folder = _find_labels_folder(args.input_path)
        if labels_folder is None:
            print("No labels folder found. Proceeding without IoU calculation.")
        else:
            print(f"Found labels folder: {labels_folder}")

        iou_scores = []
        for image_path in tqdm(image_files, desc="Processing images"):
            iou_score = process_single_image(
                model=model,
                image_path=image_path,
                labels_folder=labels_folder,
                results_folder=results_folder,
                patch_size=patch_size,
                overlap=args.overlap,
                device=device
            )
            if iou_score is not None:
                iou_scores.append(iou_score)

        if iou_scores:
            mean_iou = np.mean(iou_scores)
            print("\nSummary:")
            print(f"  Processed {len(image_files)} images")
            print(f"  Average IoU: {mean_iou:.4f}")
            print(f"  IoU range: {min(iou_scores):.4f} - {max(iou_scores):.4f}")
        else:
            print(f"\nProcessed {len(image_files)} images (no IoU calculated)")

        print(f"All results saved to: {results_folder}")

    else:
        raise ValueError(f"Input path {args.input_path} is neither a file nor a directory")

if __name__ == "__main__":
    # Run the command-line interface only when executed directly, not on import.
    main()