import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import binary_erosion
from skimage import measure
from scipy.ndimage import binary_fill_holes

def create_keypoint_image_3d(img_size, keypoints, radius=2, intensity=1.0):
    """
    Create a 3D image with keypoints represented as spheres.

    Each keypoint is rounded to the nearest voxel and a solid sphere is
    stamped around it. Spheres reaching past the volume are clipped, spheres
    lying completely outside are skipped, and overlapping spheres are merged
    with a voxel-wise maximum.

    Args:
        img_size: Tuple or array of (depth, height, width) for the 3D image
        keypoints: Keypoint coordinates, shape (N, 3) where each row is
            (z, y, x). A single keypoint of shape (3,) and an empty
            collection are accepted as well. Objects exposing ``detach``
            (e.g. torch tensors) are converted to numpy first.
        radius: Radius of the spheres in voxels; may be fractional
        intensity: Intensity value for the spheres (default 1.0)

    Returns:
        np.ndarray: float32 image of shape ``img_size`` with keypoints as spheres

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")

    img = np.zeros(img_size, dtype=np.float32)

    # Convert keypoints to numpy if needed
    if hasattr(keypoints, 'detach'):
        kp = keypoints.detach().cpu().numpy()
    else:
        kp = np.asarray(keypoints)

    if kp.size == 0:
        return img
    if kp.ndim == 1:
        # A lone (z, y, x) triple is treated as one keypoint.
        kp = kp[np.newaxis, :]

    # Voxel indices must be integers
    kp = np.round(kp).astype(int)

    # Build the sphere stamp once. Squared distances keep the comparison
    # exact for integer radii; `half` is the stamp half-width in voxels.
    half = int(np.ceil(radius))
    offsets = np.arange(-half, half + 1)
    dz, dy, dx = np.meshgrid(offsets, offsets, offsets, indexing='ij')
    stamp = (dz**2 + dy**2 + dx**2 <= radius**2) * intensity

    shape = np.array(img.shape)
    for point in kp:
        lowest = point - half
        start = np.maximum(lowest, 0)
        stop = np.minimum(point + half + 1, shape)
        if np.any(start >= stop):
            # Sphere does not intersect the volume along at least one axis.
            continue

        target = tuple(slice(a, b) for a, b in zip(start, stop))
        # Same window expressed in stamp coordinates (handles clipping).
        source = tuple(slice(a - lo, b - lo) for a, b, lo in zip(start, stop, lowest))
        img[target] = np.maximum(img[target], stamp[source])

    return img


def visualize_keypoint_image_3d(keypoint_img, title="Keypoints 3D", axis=1, voxel_size=None):
    """
    Show the mean projection of a 3D keypoint volume in a single figure.

    Args:
        keypoint_img: 3D numpy array with keypoints as spheres
        title: Text placed in front of the projection note in the plot title
        axis: Axis that is averaged away (0, 1, or 2)
        voxel_size: Optional per-axis voxel size; the ratio of the two kept
            axes becomes the display aspect ratio
    """
    labels = ('Z', 'Y', 'X')

    # Collapse the volume along the requested axis
    flattened = np.mean(keypoint_img, axis=axis)

    # The axes that survive the projection: first -> image rows, second -> columns
    kept = [idx for idx in range(3) if idx != axis]
    row_axis, col_axis = kept[0], kept[1]

    # Pixel aspect ratio derived from the voxel spacing of the kept axes
    aspect = 1.0
    if voxel_size is not None:
        spacing = np.array(voxel_size)[kept]
        aspect = spacing[0] / spacing[1]

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    image = ax.imshow(flattened, cmap='hot', interpolation='nearest')
    ax.set_aspect(aspect)

    ax.set_title(f'{title} (Projection along {labels[axis]} axis)')
    ax.set_xlabel(f'Axis {col_axis}')
    ax.set_ylabel(f'Axis {row_axis}')

    fig.colorbar(image, ax=ax, shrink=0.8)
    fig.tight_layout()
    plt.show()


def visualize_dirlab(warped_moving, fixed, case_id, kp_fixed, kp_moved, kp_moving, axis=0, voxel_size=None):
    """
    Plot moving, moved and fixed keypoints on mean projections of the images.

    The left panel shows the warped moving image with all three keypoint sets
    and the connecting segments; the right panel shows the fixed image with
    the fixed keypoints only. A textual line-colour legend is printed after
    the figure is shown.

    Args:
        warped_moving: Warped moving image (array, or object with ``detach``)
        fixed: Fixed image (array, or object with ``detach``)
        case_id: Case identifier used in the panel titles
        kp_fixed: fixed keypoints, one (z, y, x)-ordered row per point
        kp_moved: moved keypoints
        kp_moving: moving keypoints
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size for proper axis scaling
    """
    def _image(obj):
        # Detach tensor-like inputs; anything else is handed to numpy as-is
        return obj.detach().cpu().numpy() if hasattr(obj, 'detach') else obj

    def _points(obj):
        # Keypoints are always materialised as numpy arrays
        return obj.detach().cpu().numpy() if hasattr(obj, 'detach') else np.array(obj)

    warped_img = _image(warped_moving)
    fixed_img = _image(fixed)
    kp_fixed = _points(kp_fixed)
    kp_moved = _points(kp_moved)
    kp_moving = _points(kp_moving)

    # Average projections of both volumes
    warped_proj = np.mean(warped_img, axis=axis)
    fixed_proj = np.mean(fixed_img, axis=axis)

    # Dropping the projected coordinate leaves (row, col) pairs
    labels = ('Z', 'Y', 'X')
    kept = [idx for idx in range(3) if idx != axis]
    moving_2d = kp_moving[:, kept]
    moved_2d = kp_moved[:, kept]
    fixed_2d = kp_fixed[:, kept]

    # Display aspect ratio from the spacing of the kept axes
    aspect = 1.0
    if voxel_size is not None:
        spacing = np.array(voxel_size)[kept]
        aspect = spacing[0] / spacing[1]

    fig, (left, right) = plt.subplots(1, 2, figsize=(15, 6))

    # --- left panel: warped moving image with every keypoint set ---
    left.imshow(warped_proj, cmap='gray', alpha=0.7)
    left.set_aspect(aspect)
    left.set_title(f'Warped Moving - Case {case_id} (Projection along {labels[axis]} axis)')

    markers = (
        (moving_2d, 'cyan', 'o', 'Moving'),
        (moved_2d, 'yellow', 's', 'Warped Moving'),
        (fixed_2d, 'magenta', '^', 'Fixed'),
    )
    for pts, colour, shape, name in markers:
        left.scatter(pts[:, 1], pts[:, 0], c=colour, marker=shape, s=10, label=name, alpha=0.8)

    # Segments per keypoint: moving->fixed (green), moving->moved (blue),
    # moved->fixed (red), drawn in that order
    segments = (
        (moving_2d, fixed_2d, 'g-'),
        (moving_2d, moved_2d, 'b-'),
        (moved_2d, fixed_2d, 'r-'),
    )
    for idx in range(len(moving_2d)):
        for src, dst, fmt in segments:
            left.plot([src[idx, 1], dst[idx, 1]],
                      [src[idx, 0], dst[idx, 0]], fmt, alpha=0.6, linewidth=2)

    left.legend()
    left.set_xlabel(f'Axis {kept[1]}')
    left.set_ylabel(f'Axis {kept[0]}')

    # --- right panel: fixed image with its own keypoints for reference ---
    right.imshow(fixed_proj, cmap='gray', alpha=0.7)
    right.set_aspect(aspect)
    right.set_title(f'Fixed Image - Case {case_id} (Projection along {labels[axis]} axis)')
    right.scatter(fixed_2d[:, 1], fixed_2d[:, 0], c='magenta', marker='^', s=10, label='Fixed', alpha=0.8)
    right.legend()
    right.set_xlabel(f'Axis {kept[1]}')
    right.set_ylabel(f'Axis {kept[0]}')

    fig.tight_layout()
    plt.show()

    # Textual legend for the segment colours
    for line in ("Line colors:",
                 "- Green: Moving -> Fixed",
                 "- Blue: Moving -> Warped Moving",
                 "- Red: Warped Moving -> Fixed"):
        print(line)


def visualize_dirlab_warp_fixed(warped_moving, fixed, case_id, kp_fixed, kp_fixed_warped, kp_moving, axis=0, voxel_size=None, visualize=True, save_path=None, error_threshold=None):
    """
    Visualize keypoints overlaid on warped moving image with average projection.

    Moving (cyan circles), warped-fixed (yellow squares) and fixed (magenta
    triangles) keypoints are drawn on the mean projection of the warped
    moving image, connected per keypoint by green (moving -> fixed), blue
    (fixed -> warped fixed) and red (warped fixed -> moving) segments.

    Args:
        warped_moving: Warped moving image (array, or object with ``detach``)
        fixed: Fixed image; not drawn, kept for interface compatibility
        case_id: Case identifier; not drawn, kept for interface compatibility
        kp_fixed: fixed keypoints, shape (N, 3)
        kp_fixed_warped: fixed keypoints warped with the df, shape (N, 3)
        kp_moving: moving keypoints, shape (N, 3)
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size, either one value per axis (same axis order
            as the keypoint columns) or a single scalar for isotropic voxels
        visualize: Whether to display the plot
        save_path: Path to save the plot
        error_threshold: If provided together with ``voxel_size``, keypoints
            whose fixed-to-warped distance exceeds this value (in mm) are
            highlighted and the rest are dimmed. Ignored when ``voxel_size``
            is None because the error cannot be converted to mm.
    """
    def to_numpy(data):
        if hasattr(data, 'detach'):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    warped_img = to_numpy(warped_moving)
    kp_fixed = to_numpy(kp_fixed)
    kp_fixed_warped = to_numpy(kp_fixed_warped)
    kp_moving = to_numpy(kp_moving)

    # Per-axis spacing; a scalar voxel size is expanded to all three axes so
    # that both the error computation and the aspect ratio can use it.
    spacing = None
    if voxel_size is not None:
        spacing = np.asarray(voxel_size, dtype=float)
        if spacing.ndim == 0:
            spacing = np.full(3, float(spacing))

    # Physical error per keypoint: scale every displacement component by the
    # spacing of its own axis before taking the norm, which is required for
    # anisotropic voxels.
    keypoint_mask = None
    if error_threshold is not None and spacing is not None:
        offsets_mm = (kp_fixed - kp_fixed_warped) * spacing
        errors_mm = np.sqrt(np.sum(offsets_mm**2, axis=1))
        keypoint_mask = errors_mm > error_threshold

        print(f"Highlighting {np.sum(keypoint_mask)}/{len(keypoint_mask)} keypoints with error > {error_threshold} mm")

    # Compute average projection along specified axis
    warped_proj = np.mean(warped_img, axis=axis)

    # Project keypoints by removing the projection axis coordinate
    remaining_axes = [i for i in range(3) if i != axis]
    kp_moving_2d = kp_moving[:, remaining_axes]
    kp_fixed_warped_2d = kp_fixed_warped[:, remaining_axes]
    kp_fixed_2d = kp_fixed[:, remaining_axes]

    # Aspect ratio from the spacing of the two displayed axes
    if spacing is not None:
        aspect_ratio = spacing[remaining_axes[0]] / spacing[remaining_axes[1]]
    else:
        aspect_ratio = 1.0

    fig, ax1 = plt.subplots(1, 1, figsize=(6, 6))
    ax1.imshow(warped_proj, cmap='gray', alpha=0.7)
    ax1.set_aspect(aspect_ratio)

    def scatter_group(selector, size, alpha):
        # Draw the three keypoint sets for the selected rows (x = col, y = row).
        ax1.scatter(kp_moving_2d[selector, 1], kp_moving_2d[selector, 0],
                    c='cyan', marker='o', s=size, alpha=alpha)
        ax1.scatter(kp_fixed_warped_2d[selector, 1], kp_fixed_warped_2d[selector, 0],
                    c='yellow', marker='s', s=size, alpha=alpha)
        ax1.scatter(kp_fixed_2d[selector, 1], kp_fixed_2d[selector, 0],
                    c='magenta', marker='^', s=size, alpha=alpha)

    n_points = len(kp_moving_2d)
    if keypoint_mask is not None:
        low_error_mask = ~keypoint_mask
        # Low-error keypoints are dimmed, high-error keypoints emphasised
        if np.any(low_error_mask):
            scatter_group(low_error_mask, 10, 0.3)
        if np.any(keypoint_mask):
            scatter_group(keypoint_mask, 15, 0.9)
        line_styles = [(0.9, 2) if flagged else (0.3, 1) for flagged in keypoint_mask]
    else:
        scatter_group(slice(None), 10, 0.8)
        line_styles = [(0.7, 2)] * n_points

    # Per keypoint: moving -> fixed (green), fixed -> warped fixed (blue),
    # warped fixed -> moving (red)
    segments = (
        (kp_moving_2d, kp_fixed_2d, 'g-'),
        (kp_fixed_2d, kp_fixed_warped_2d, 'b-'),
        (kp_fixed_warped_2d, kp_moving_2d, 'r-'),
    )
    for i in range(n_points):
        line_alpha, line_width = line_styles[i]
        for src, dst, fmt in segments:
            ax1.plot([src[i, 1], dst[i, 1]], [src[i, 0], dst[i, 0]],
                     fmt, alpha=line_alpha, linewidth=line_width)

    ax1.tick_params(axis='both', which='both',
                    bottom=False, top=False, left=False,
                    labelbottom=False, labelleft=False)

    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0)
        print(f"Visualization saved to: {save_path}")

    if visualize:
        plt.show()
        # Print line legend
        print("Line colors:")
        print("- Green: Moving -> Fixed")
        print("- Blue: Fixed -> Warped Fixed")
        print("- Red: Warped Fixed -> Moving")

    # Close this figure (not merely the current one) to free memory
    plt.close(fig)


def overlay_images_dirlab(fixed, warped_moving, case_id, axis=0, voxel_size=None, alpha=0.7, red_title='Fixed', green_title='Warped Moving', save_path=None, visualize=True):
    """
    Overlay fixed and warped moving images with different colors using mean projection.

    Both volumes are mean-projected along ``axis``, rescaled to [0, 1] and
    written into the red (fixed) and green (warped moving) channels of an RGB
    image, so well-aligned structures appear yellow. No legend is drawn.

    Args:
        fixed: Fixed image (red channel); array or object with ``detach``
        warped_moving: Warped moving image (green channel)
        case_id: Case identifier shown in the title
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size for proper axis scaling
        alpha: Factor in [0, 1] applied to both colour channels
        red_title: Title for the red channel image
        green_title: Title for the green channel image
        save_path: Optional path to save the plot as PNG (e.g., '/path/to/save/overlay.png')
        visualize: Whether to display the plot. When False the figure is
            closed after the optional save so repeated calls do not pile up
            open figures.
    """
    def to_numpy(data):
        if hasattr(data, 'detach'):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    def normalized_projection(volume):
        # Mean projection rescaled to [0, 1]; a constant projection maps to
        # zeros instead of dividing by zero.
        proj = np.mean(to_numpy(volume), axis=axis)
        lowest = proj.min()
        span = proj.max() - lowest
        if span == 0:
            return np.zeros_like(proj, dtype=float)
        return (proj - lowest) / span

    fixed_proj = normalized_projection(fixed)
    warped_proj = normalized_projection(warped_moving)

    # RGB overlay: red for fixed, green for warped moving, blue stays empty
    overlay = np.zeros((*fixed_proj.shape, 3))
    overlay[..., 0] = fixed_proj * alpha
    overlay[..., 1] = warped_proj * alpha

    axis_names = ['Z', 'Y', 'X']
    remaining_axes = [i for i in range(3) if i != axis]

    # Anatomical plane name for the title, keyed by the projected axis
    plane_type = {
        'Y': "Coronal plane",
        'X': "Sagittal plane",
    }.get(axis_names[axis], "Transverse plane")

    # Calculate aspect ratio from voxel size if provided
    if voxel_size is not None:
        voxel_remaining = np.array(voxel_size)[remaining_axes]
        aspect_ratio = voxel_remaining[0] / voxel_remaining[1]
    else:
        aspect_ratio = 1.0

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))

    ax.imshow(overlay, aspect=aspect_ratio)
    ax.set_title(f'Overlay: {red_title} (Red) + {green_title} (Green) - Case {case_id}\n({plane_type})')
    ax.tick_params(axis='both', which='both',
                   bottom=False, top=False, left=False,
                   labelbottom=False, labelleft=False)

    fig.tight_layout()

    # Save to file if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Overlay saved to: {save_path}")

    if visualize:
        plt.show()
    else:
        # Nothing will display this figure; release it.
        plt.close(fig)


def collage_images_dirlab(moving, warped_moving, fixed, case_id, axis=0, voxel_size=None, save_path=None, visualize=True):
    """
    Create a collage showing moving, warped moving, and fixed images in a row using mean projection.

    Args:
        moving: Original moving image (array or object with ``detach``)
        warped_moving: Warped moving image
        fixed: Fixed image
        case_id: Case identifier shown in the panel titles
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size for proper axis scaling
        save_path: Optional path to save the collage
        visualize: Whether to display the plot. When False the figure is
            closed after the optional save so repeated calls do not pile up
            open figures.
    """
    def to_numpy(data):
        if hasattr(data, 'detach'):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    # Project every volume before any figure is created
    panels = [
        ('Moving Image', np.mean(to_numpy(moving), axis=axis)),
        ('Warped Moving Image', np.mean(to_numpy(warped_moving), axis=axis)),
        ('Fixed Image', np.mean(to_numpy(fixed), axis=axis)),
    ]

    # Calculate aspect ratio from voxel size if provided
    axis_names = ['Z', 'Y', 'X']
    remaining_axes = [i for i in range(3) if i != axis]

    if voxel_size is not None:
        voxel_remaining = np.array(voxel_size)[remaining_axes]
        aspect_ratio = voxel_remaining[0] / voxel_remaining[1]
    else:
        aspect_ratio = 1.0

    # Three panels in a row, each with its own colorbar
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (name, projection) in zip(axes, panels):
        im = ax.imshow(projection, cmap='gray')
        ax.set_aspect(aspect_ratio)
        ax.set_title(f'{name} - Case {case_id}\n(Projection along {axis_names[axis]} axis)')
        ax.set_xlabel(f'Axis {remaining_axes[1]}')
        ax.set_ylabel(f'Axis {remaining_axes[0]}')
        fig.colorbar(im, ax=ax, shrink=0.8)

    fig.tight_layout()

    # Save to file if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Collage saved to: {save_path}")

    if visualize:
        plt.show()
    else:
        # Nothing will display this figure; release it.
        plt.close(fig)


def show_3d_projection(image_3d, title="3D Image Projection", voxel_size=None, axis=0, 
                      cmap='gray', figsize=(10, 8), save_path=None):
    """
    Show 3D image projection along specified axis with proper voxel size scaling.

    When ``voxel_size`` is given, the image is placed in physical coordinates
    through ``extent`` and drawn with equal data scaling, so the anisotropy
    is applied exactly once. Without it the image is shown in voxel
    coordinates with square pixels.

    Args:
        image_3d: 3D numpy array or tensor to visualize
        title: Title for the plot
        voxel_size: Tuple/array of (z_size, y_size, x_size) in physical units (e.g., mm)
        axis: Axis along which to compute projection (0=Z, 1=Y, 2=X)
        cmap: Colormap for the image display
        figsize: Figure size as (width, height)
        save_path: Optional path to save the plot as PNG

    Returns:
        None
    """
    # Convert tensor to numpy if needed
    if hasattr(image_3d, 'detach'):
        img = image_3d.detach().cpu().numpy()
    else:
        img = np.array(image_3d)

    # Compute mean projection along specified axis
    projection = np.mean(img, axis=axis)

    axis_names = ['Z', 'Y', 'X']
    remaining_axes = [i for i in range(3) if i != axis]

    if voxel_size is not None:
        voxel_size = np.array(voxel_size)
        voxel_remaining = voxel_size[remaining_axes]

        # Physical extent (left, right, bottom, top) with the origin at the
        # top-left corner. The extent already encodes the voxel spacing, so
        # no additional aspect factor may be applied on top of it.
        extent = [0, projection.shape[1] * voxel_remaining[1],
                  projection.shape[0] * voxel_remaining[0], 0]
        units = 'physical units'
    else:
        extent = None
        units = 'voxels'

    xlabel = f'{axis_names[remaining_axes[1]]} axis ({units})'
    ylabel = f'{axis_names[remaining_axes[0]]} axis ({units})'

    # Create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # 'equal' keeps one data unit the same length on both axes
    im = ax.imshow(projection, cmap=cmap, aspect='equal', extent=extent)

    # Set labels and title
    ax.set_title(f'{title}\n(Mean projection along {axis_names[axis]} axis)')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Add colorbar
    fig.colorbar(im, ax=ax, shrink=0.8)

    # Voxel size shown in a text box in the upper-left corner of the axes
    if voxel_size is not None:
        voxel_info = f"Voxel size: {voxel_size[0]:.2f} × {voxel_size[1]:.2f} × {voxel_size[2]:.2f}"
        ax.text(0.02, 0.98, voxel_info, transform=ax.transAxes, 
               bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
               verticalalignment='top', fontsize=9)

    fig.tight_layout()

    # Save to file if path provided
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Projection saved to: {save_path}")

    plt.show()


def plot_accuracy_histogram(all_accuracies, case_id, save_path=None, visualize=True):
    """
    Draw a 20-bin histogram of per-landmark errors for one case.

    Dashed vertical lines mark the mean (red) and mean +/- one standard
    deviation (orange); their values appear in the legend.

    Parameters:
    all_accuracies (numpy.ndarray): Array of accuracy values for each landmark
    case_id (int): Case identifier for the plot title
    save_path (str, optional): Path to save the histogram plot
    visualize (bool): Whether to call ``plt.show()`` at the end
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.hist(all_accuracies, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set_xlabel('Target Registration Error (mm)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Histogram of Landmark Accuracies - Case {case_id}')
    ax.grid(True, alpha=0.3)

    # Summary statistics rendered as dashed vertical markers
    centre = np.mean(all_accuracies)
    spread = np.std(all_accuracies)
    upper = centre + spread
    lower = centre - spread
    markers = (
        (centre, 'red', 2, f'Mean: {centre:.2f} mm'),
        (upper, 'orange', 1, f'Mean + Std: {upper:.2f} mm'),
        (lower, 'orange', 1, f'Mean - Std: {lower:.2f} mm'),
    )
    for position, colour, width, text in markers:
        ax.axvline(position, color=colour, linestyle='--', linewidth=width, label=text)
    ax.legend()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Histogram saved to: {save_path}")

    if visualize:
        plt.show()


def plot_accuracy_histogram_all(all_accuracies, case_id, save_path=None, visualize=True):
    """
    Draw a 20-bin histogram of landmark errors with a free-form title suffix.

    Same plot as the per-case histogram, except that ``case_id`` is inserted
    into the title verbatim (without a "Case" prefix). Dashed vertical lines
    mark the mean (red) and mean +/- one standard deviation (orange).

    Parameters:
    all_accuracies (numpy.ndarray): Array of accuracy values for each landmark
    case_id: Label appended to the plot title
    save_path (str, optional): Path to save the histogram plot
    visualize (bool): Whether to call ``plt.show()`` at the end
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.hist(all_accuracies, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set_xlabel('Target Registration Error (mm)')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Histogram of Landmark Accuracies - {case_id}')
    ax.grid(True, alpha=0.3)

    # Summary statistics rendered as dashed vertical markers
    centre = np.mean(all_accuracies)
    spread = np.std(all_accuracies)
    upper = centre + spread
    lower = centre - spread
    markers = (
        (centre, 'red', 2, f'Mean: {centre:.2f} mm'),
        (upper, 'orange', 1, f'Mean + Std: {upper:.2f} mm'),
        (lower, 'orange', 1, f'Mean - Std: {lower:.2f} mm'),
    )
    for position, colour, width, text in markers:
        ax.axvline(position, color=colour, linestyle='--', linewidth=width, label=text)
    ax.legend()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Histogram saved to: {save_path}")

    if visualize:
        plt.show()

def visualize_dirlab_warp_fixed_with_mask(
    warped_moving, fixed, mask, case_id,
    kp_fixed, kp_fixed_warped, kp_moving,
    axis=0, voxel_size=None, visualize=True, save_path=None
):
    """
    Visualize keypoints overlaid on warped moving image with average projection and mask contour.

    Moving (cyan), warped-fixed (yellow) and fixed (magenta) keypoints are
    drawn on the mean projection of the warped moving image and connected by
    green (moving -> fixed), blue (fixed -> warped fixed) and red (warped
    fixed -> moving) segments. The outline of the mask's maximum projection,
    with interior holes filled, is drawn in red.

    Args:
        warped_moving: Warped moving image (array or object with ``detach``)
        fixed: Fixed image; not drawn, kept for interface compatibility
        mask: Binary mask array or tensor; any value > 0 counts as inside
        case_id: Case identifier shown in the title
        kp_fixed: fixed keypoints, shape (N, 3)
        kp_fixed_warped: fixed keypoints warped with the df, shape (N, 3)
        kp_moving: moving keypoints, shape (N, 3)
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size for proper axis scaling
        visualize: Whether to display the plot
        save_path: Optional path to save the plot

    Raises:
        KeyError: If ``axis`` is not 0, 1 or 2.
    """
    def to_np(x):
        if hasattr(x, 'detach'):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    warped_img = to_np(warped_moving)
    mask_img = to_np(mask)
    kp_fixed = to_np(kp_fixed)
    kp_fixed_warped = to_np(kp_fixed_warped)
    kp_moving = to_np(kp_moving)

    # Compute average projection along specified axis
    warped_proj = np.mean(warped_img, axis=axis)

    # Project keypoints by removing the projection axis coordinate
    remaining_axes = [i for i in range(3) if i != axis]
    kp_moving_2d = kp_moving[:, remaining_axes]
    kp_fixed_warped_2d = kp_fixed_warped[:, remaining_axes]
    kp_fixed_2d = kp_fixed[:, remaining_axes]

    # Calculate aspect ratio from voxel size if provided
    if voxel_size is not None:
        voxel_remaining = np.array(voxel_size)[remaining_axes]
        aspect_ratio = voxel_remaining[0] / voxel_remaining[1]
    else:
        aspect_ratio = 1.0

    # Determine plane type for title
    plane_type = {
        0: 'Transverse plane',
        1: 'Coronal plane',
        2: 'Sagittal plane'
    }[axis]

    fig, ax1 = plt.subplots(1, 1, figsize=(6, 6))

    # Plot warped moving image with keypoints (x = column, y = row)
    ax1.imshow(warped_proj, cmap='gray', alpha=0.7)
    ax1.set_aspect(aspect_ratio)
    ax1.set_title(f'Warped Moving - Case {case_id} ({plane_type})')
    ax1.scatter(kp_moving_2d[:, 1], kp_moving_2d[:, 0], c='cyan', marker='o', s=10, alpha=0.8)
    ax1.scatter(kp_fixed_warped_2d[:, 1], kp_fixed_warped_2d[:, 0], c='yellow', marker='s', s=10, alpha=0.8)
    ax1.scatter(kp_fixed_2d[:, 1], kp_fixed_2d[:, 0], c='magenta', marker='^', s=10, alpha=0.8)

    segments = (
        (kp_moving_2d, kp_fixed_2d, 'g-'),
        (kp_fixed_2d, kp_fixed_warped_2d, 'b-'),
        (kp_fixed_warped_2d, kp_moving_2d, 'r-'),
    )
    for i in range(len(kp_moving_2d)):
        for src, dst, fmt in segments:
            ax1.plot([src[i, 1], dst[i, 1]], [src[i, 0], dst[i, 0]],
                     fmt, alpha=0.6, linewidth=2)

    ax1.axis('off')

    # Mask outline: binarise the maximum projection with `> 0` (so fractional
    # mask values are not truncated away), fill interior holes, then trace
    # the 0.5 iso-contour of the resulting boolean image.
    mask_proj = binary_fill_holes(np.max(mask_img, axis=axis) > 0)
    for cnt in measure.find_contours(mask_proj, level=0.5):
        ax1.plot(cnt[:, 1], cnt[:, 0], linewidth=1, color='red')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Visualization saved to: {save_path}")
    if visualize:
        plt.show()
        print("Line colors: Moving->Fixed (green), Fixed->Warped Fixed (blue), Warped Fixed->Moving (red)")

    # Close this figure (not merely the current one) to free memory
    plt.close(fig)

def apply_deformation_field(image, deformation_field):
    """
    Warp a 2D image with a dense displacement field using linear interpolation.

    Every output pixel ``(r, c)`` is sampled from ``image`` at the position
    ``(r + field[..., r, c, 0], c + field[..., r, c, 1])``; i.e. the field is
    a "pull" (backward) displacement expressed in pixel units.

    Args:
        image (np.ndarray): 2D input image of shape (H, W).
        deformation_field (np.ndarray): Displacements of shape (H, W, 2), or
            (B, H, W, 2) for a stack of fields applied to the same image.
            Channel 0 is added to the axis-0 (row) coordinates and channel 1
            to the axis-1 (column) coordinates.

    Returns:
        np.ndarray: The deformed image, shape (H, W) for a 3D field or
        (B, H, W) for a 4D field.

    Raises:
        ValueError: If the field is not 3D/4D or its last dimension is not 2.
    """
    from scipy.ndimage import map_coordinates

    # Ensure image and deformation field are numpy arrays
    image = np.asarray(image)
    deformation_field = np.asarray(deformation_field)

    # A field matching a 2D image is (H, W, 2); a leading batch axis is also
    # accepted because it broadcasts against the (H, W) coordinate grid.
    if deformation_field.ndim not in (3, 4) or deformation_field.shape[-1] != 2:
        raise ValueError(
            "Deformation field must be a 3D (H, W, 2) or 4D (B, H, W, 2) array "
            "with the last dimension of size 2."
        )

    # Identity sampling grid in (row, col) index order
    row_grid, col_grid = np.meshgrid(
        np.arange(image.shape[0]),
        np.arange(image.shape[1]),
        indexing='ij'
    )

    # Displace the grid: channel 0 moves along rows, channel 1 along columns
    displaced_coords = [
        row_grid + deformation_field[..., 0],
        col_grid + deformation_field[..., 1],
    ]

    # order=1 -> (bi)linear interpolation; samples falling outside the image
    # are filled by reflecting the image at its border.
    deformed_image = map_coordinates(image, displaced_coords, order=1, mode='reflect')

    return deformed_image

def visualize_dirlab_warp_fixed_with_accuracy_mask(warped_moving, fixed, case_id, kp_fixed, kp_fixed_warped, kp_moving, keypoint_mask, axis=0, voxel_size=None, visualize=True, save_path=None):
    """
    Visualize keypoints overlaid on warped moving image with pre-computed accuracy mask.

    The warped moving volume is mean-projected along ``axis`` and shown in
    grayscale. For every keypoint index the three corresponding points
    (moving, warped fixed, fixed) are drawn and joined by a triangle of line
    segments. When ``keypoint_mask`` is given, masked keypoints are drawn
    larger and nearly opaque while the others are faded out.

    Args:
        warped_moving: Warped moving image (tensor or array-like), 3D
        fixed: Fixed image; kept for interface compatibility, not used in the plot
        case_id: Case identifier; kept for interface compatibility (no title is drawn)
        kp_fixed: fixed keypoints, shape (N, 3)
        kp_fixed_warped: fixed keypoints warped with the df, shape (N, 3)
        kp_moving: moving keypoints, shape (N, 3)
        keypoint_mask: Mask of length N indicating which keypoints to highlight
            (True = high error). May be a tensor, array or sequence; it is
            converted to a boolean array. Pass None to draw all keypoints alike.
        axis: Axis along which to compute average projection (0, 1, or 2)
        voxel_size: Voxel size for proper axis scaling
        visualize: Whether to display the plot
        save_path: Path to save the plot
    """
    def _as_numpy(value):
        # Torch-like tensors expose ``detach``; everything else goes through NumPy.
        if hasattr(value, 'detach'):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # Compute average projection along specified axis
    warped_proj = np.mean(_as_numpy(warped_moving), axis=axis)

    # Project keypoints by removing the projection axis coordinate
    remaining_axes = [i for i in range(3) if i != axis]
    kp_moving_2d = _as_numpy(kp_moving)[:, remaining_axes]
    kp_fixed_warped_2d = _as_numpy(kp_fixed_warped)[:, remaining_axes]
    kp_fixed_2d = _as_numpy(kp_fixed)[:, remaining_axes]

    # Normalise the mask to a flat boolean array so that ``~mask`` and boolean
    # indexing behave correctly for lists, integer 0/1 masks and tensors.
    if keypoint_mask is not None:
        keypoint_mask = _as_numpy(keypoint_mask).astype(bool).ravel()

    # Calculate aspect ratio from voxel size if provided
    if voxel_size is not None:
        voxel_remaining = np.array(voxel_size)[remaining_axes]
        aspect_ratio = voxel_remaining[0] / voxel_remaining[1]
    else:
        aspect_ratio = 1.0

    fig, ax1 = plt.subplots(1, 1, figsize=(6, 6))

    # Plot warped moving image
    ax1.imshow(warped_proj, cmap='gray', alpha=0.7)
    ax1.set_aspect(aspect_ratio)

    # (points, colour, marker) for each keypoint family, in draw order
    point_sets = (
        (kp_moving_2d, 'cyan', 'o'),
        (kp_fixed_warped_2d, 'yellow', 's'),
        (kp_fixed_2d, 'magenta', '^'),
    )

    def _scatter_all_sets(selector, size, alpha):
        # Column 1 of the projected keypoints is the image x axis, column 0 the y axis.
        for pts, colour, marker in point_sets:
            ax1.scatter(pts[selector, 1], pts[selector, 0],
                        c=colour, marker=marker, s=size, alpha=alpha)

    def _draw_triangle(idx, alpha, width):
        mov, fix, wrp = kp_moving_2d[idx], kp_fixed_2d[idx], kp_fixed_warped_2d[idx]
        # Moving to fixed (green)
        ax1.plot([mov[1], fix[1]], [mov[0], fix[0]], 'g-', alpha=alpha, linewidth=width)
        # Fixed to warped fixed (blue)
        ax1.plot([fix[1], wrp[1]], [fix[0], wrp[0]], 'b-', alpha=alpha, linewidth=width)
        # Warped fixed to moving (red)
        ax1.plot([wrp[1], mov[1]], [wrp[0], mov[0]], 'r-', alpha=alpha, linewidth=width)

    if keypoint_mask is not None:
        low_error_mask = ~keypoint_mask

        nomask_alpha = 0.2
        mask_alpha = 0.9

        # Faded markers for low-error keypoints first, highlighted ones on top
        if np.any(low_error_mask):
            _scatter_all_sets(low_error_mask, 10, nomask_alpha)
        if np.any(keypoint_mask):
            _scatter_all_sets(keypoint_mask, 15, mask_alpha)

        # Connecting lines follow the same emphasis as the markers
        for i in range(len(kp_moving_2d)):
            if keypoint_mask[i]:
                _draw_triangle(i, mask_alpha, 2)
            else:
                _draw_triangle(i, nomask_alpha, 1)
    else:
        # No mask: every keypoint gets the same, normal emphasis
        _scatter_all_sets(slice(None), 10, 0.8)
        for i in range(len(kp_moving_2d)):
            _draw_triangle(i, 0.7, 2)

    ax1.tick_params(axis='both', which='both',
                    bottom=False, top=False, left=False,
                    labelbottom=False, labelleft=False)

    fig.tight_layout(pad=0)

    if save_path is not None:
        # Save through the figure handle so an unrelated "current figure" can never be written
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0)
        print(f"Visualization saved to: {save_path}")

    if visualize:
        plt.show()
        # Print line legend
        print("Line colors:")
        print("- Green: Moving -> Fixed")
        print("- Blue: Fixed -> Warped Fixed")
        print("- Red: Warped Fixed -> Moving")

    # Close this specific figure to free memory
    plt.close(fig)
