# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
# either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

"""
Compute 2.5D FID using distributed GPU processing.

SHELL Usage Example:
-------------------
    #!/bin/bash

    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
    NUM_GPUS=7

    torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
        --model_name "radimagenet_resnet50" \
        --real_dataset_root "path/to/datasetA" \
        --real_filelist "path/to/filelistA.txt" \
        --real_features_dir "datasetA" \
        --synth_dataset_root "path/to/datasetB" \
        --synth_filelist "path/to/filelistB.txt" \
        --synth_features_dir "datasetB" \
        --enable_center_slices_ratio 0.4 \
        --enable_padding True \
        --enable_center_cropping True \
        --enable_resampling_spacing "1.0x1.0x1.0" \
        --ignore_existing True \
        --num_images 100 \
        --output_root "./features/features-512x512x512" \
        --target_shape "512x512x512"

This script loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
and extracts feature maps via a 2.5D approach. It then computes the Frechet
Inception Distance (FID) across three orthogonal planes. Data parallelism
is implemented using torch.distributed with an NCCL backend.

Function Arguments (main):
--------------------------
    real_dataset_root (str):
        Root folder for the real dataset.

    real_filelist (str):
        Text file listing 3D images for the real dataset.

    real_features_dir (str):
        Subdirectory (under `output_root`) in which to store feature files
        extracted from the real dataset.

    synth_dataset_root (str):
        Root folder for the synthetic dataset.

    synth_filelist (str):
        Text file listing 3D images for the synthetic dataset.

    synth_features_dir (str):
        Subdirectory (under `output_root`) in which to store feature files
        extracted from the synthetic dataset.

    enable_center_slices_ratio (float or None):
        - If not None, only the central fraction of slices along each sliced axis
          is used, where this value is that fraction (e.g. 0.4 keeps the middle 40%
          of slices); analogous to "enable_center_slices=True" with that ratio.
        - If None, no center-slice selection is performed
          (analogous to "enable_center_slices=False").

    enable_padding (bool):
        Whether to pad images to `target_shape`.

    enable_center_cropping (bool):
        Whether to center-crop images to `target_shape`.

    enable_resampling_spacing (str or None):
        - If not None, resample images to the specified voxel spacing (e.g. "1.0x1.0x1.0")
          (analogous to "enable_resampling=True" with that spacing).
        - If None, resampling is skipped
          (analogous to "enable_resampling=False").

    ignore_existing (bool):
        If True, ignore any existing .pt feature files and force re-extraction.

    model_name (str):
        Model identifier. Typically "radimagenet_resnet50" or "squeezenet1_1".

    num_images (int):
        Max number of images to process from each dataset (truncate if more are present).

    output_root (str):
        Folder where extracted .pt feature files, logs, and results are saved.

    target_shape (str):
        Target shape as "XxYxZ" for padding, cropping, or resampling operations.
"""


from __future__ import annotations

import os
import sys
import torch
import monai
import re
import torch.distributed as dist
import torch.nn.functional as F

from datetime import timedelta
from pathlib import Path
from monai.metrics.fid import FIDMetric
from monai.transforms import Compose

import logging

# ------------------------------------------------------------------------------
# Create logger
# ------------------------------------------------------------------------------
logger = logging.getLogger("fid_2-5d_ct")
# NOTE: the guard below looks at this *named* logger's handlers, but basicConfig()
# attaches its stdout handler to the *root* logger; records from `logger` reach it
# through propagation. basicConfig() is itself a no-op when the root logger already
# has handlers, so repeated imports do not stack handlers.
if len(logger.handlers) == 0:
    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger.setLevel(logging.INFO)


def drop_empty_slice(slices, empty_threshold: float):
    """
    Decide which 2D slices to keep by comparing each slice's maximum intensity
    with a threshold.

    Args:
        slices (iterable of Tensors, or a Tensor iterated along dim 0): the
            slices to inspect; each element is reduced to its maximum value.
            Any iterable is accepted, including generators.
        empty_threshold (float): a slice whose maximum value is strictly below
            this threshold is considered "empty" and marked for dropping.

    Returns:
        list[bool]: one flag per slice, True to keep it and False to drop it.
        An empty input yields an empty list.
    """
    # The max over the raw tensor equals the max over its unique values, without
    # the sort that `torch.unique` performs. `not (m < t)` (rather than `m >= t`)
    # keeps slices whose maximum is NaN, exactly as the original comparison did.
    keep = [not bool(s.max() < empty_threshold) for s in slices]

    n_total = len(keep)
    n_drop = n_total - sum(keep)
    # Guard the rate computation so an empty input does not divide by zero.
    drop_rate = round(n_drop / n_total * 100, 1) if n_total else 0.0
    logger.info("Empty slice drop rate %s%%", drop_rate)
    return keep


def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """
    Subtract fixed per-channel means from channels 0, 1 and 2 of `x`, in place.

    The constants (0.406, 0.456, 0.485) are the ImageNet RGB channel means
    listed in reverse order, i.e. they assume (B, G, R) channel ordering.
    Channels sit on axis 1: (B, C, H, W) or (B, C, H, W, D).

    Note: `x` itself is modified, and the same tensor object is returned.
    """
    channel_means = (0.406, 0.456, 0.485)
    for channel, value in enumerate(channel_means):
        x[:, channel, ...] -= value
    return x


def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """
    Collapse the spatial axes of a feature map by taking their mean.

    Every axis after the first two (batch, channel) is treated as spatial.
    Only 3-, 4- and 5-dimensional inputs are reduced; any other rank
    (including an already-flat (B, C) tensor) is returned unchanged.

    Args:
        x (torch.Tensor): feature tensor, e.g. (B, C, L), (B, C, H, W) or (B, C, H, W, D).
        keepdim (bool): if True, the reduced axes remain as size-1 dimensions.

    Returns:
        torch.Tensor: the averaged tensor, or `x` itself when no reduction applies.
    """
    rank = x.dim()
    if rank not in (3, 4, 5):
        return x
    spatial_axes = list(range(2, rank))
    return x.mean(spatial_axes, keepdim=keepdim)


def medicalnet_intensity_normalisation(volume: torch.Tensor) -> torch.Tensor:
    """
    Standardise intensities per sample and channel, MedicalNet style.

    Computes (volume - mean) / (std + 1e-5), where mean and std are taken over
    all spatial axes (every axis after batch and channel). `torch.std` is used
    with its default (unbiased) estimator.

    Only (B, C, H, W) or (B, C, H, W, D) inputs are normalised; tensors of any
    other rank are returned unchanged.
    """
    rank = volume.dim()
    if rank not in (4, 5):
        return volume
    axes = list(range(2, rank))
    centre = volume.mean(axes, keepdim=True)
    spread = volume.std(axes, keepdim=True)
    return (volume - centre) / (spread + 1e-5)


def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = False) -> torch.Tensor:
    """
    Min-max scale intensities to roughly [0, 1], then subtract the channel means.

    Args:
        volume (torch.Tensor): input of shape (B, C, H, W) or (B, C, H, W, D).
        norm2d (bool): only honoured for 4D input. If True, each (H, W) plane
            (per sample and channel) is scaled by its own min/max; otherwise a
            single global min/max over the whole tensor is used.

    Returns:
        torch.Tensor: the scaled tensor after `subtract_mean`, or `volume`
        unchanged when its rank is neither 4 nor 5.
    """
    logger.info(f"norm2d: {norm2d}")
    rank = volume.dim()
    if rank not in (4, 5):
        return volume

    if rank == 4 and norm2d:
        # Per-(sample, channel) extremes over the two spatial axes.
        upper = volume.amax(dim=(2, 3), keepdim=True)
        lower = volume.amin(dim=(2, 3), keepdim=True)
    else:
        # 4D without norm2d, and all 5D input: one global min/max.
        upper = volume.max()
        lower = volume.min()

    # The epsilon avoids division by zero for constant inputs.
    scaled = (volume - lower) / (upper - lower + 1e-10)
    return subtract_mean(scaled)


def _center_slice_bounds(size: int, center_slices: bool, center_slices_ratio: float) -> tuple[int, int]:
    """Return the half-open [start, end) slice-index range to use along an axis of length `size`."""
    if not center_slices:
        return 0, size
    start = int((1.0 - center_slices_ratio) / 2.0 * size)
    end = int((1.0 + center_slices_ratio) / 2.0 * size)
    return start, end


def _plane_features_2p5d(
    image: torch.Tensor,
    mask: torch.Tensor | None,
    axis: int,
    feature_network: torch.nn.Module,
    *,
    center_slices: bool,
    center_slices_ratio: float,
    sample_every_k: int,
    drop_empty: bool,
    empty_threshold: float,
    use_min_max_normalization: bool,
) -> torch.Tensor:
    """Extract spatially-averaged features from the selected 2D slices of `image` taken along `axis`."""
    start, end = _center_slice_bounds(image.shape[axis], center_slices, center_slices_ratio)
    # Restrict only `axis` to the (strided) slice range; other axes are kept whole.
    index = [slice(None)] * image.dim()
    index[axis] = slice(start, end, sample_every_k)
    selector = tuple(index)

    # Unbinding along `axis` and concatenating on dim 0 stacks the 2D slices
    # into the batch dimension in slice-major order.
    images_2d = torch.cat(torch.unbind(image[selector], dim=axis), dim=0)

    if drop_empty:
        # The mask is sliced with the same selector so its flags line up with images_2d.
        masks = torch.cat(torch.unbind(mask[selector], dim=axis), dim=0)
        keep = drop_empty_slice(masks, empty_threshold)
    else:
        keep = slice(None)  # Keep all slices

    # Normalisation sees every selected slice (empty ones included) before the
    # empty ones are filtered out, so the min/max statistics include them.
    if use_min_max_normalization:
        images_2d = radimagenet_intensity_normalisation(images_2d)
    else:
        images_2d = subtract_mean(images_2d)
    images_2d = images_2d[keep]

    # Call the module rather than `.forward` so that registered hooks run.
    features = feature_network(images_2d)
    return spatial_average(features, keepdim=False)


def get_features_2p5d(
    image: torch.Tensor,
    mask: torch.Tensor | None = None,
    feature_network: torch.nn.Module = None,
    center_slices: bool = False,
    center_slices_ratio: float = 1.0,
    sample_every_k: int = 1,
    xy_only: bool = True,
    drop_empty: bool = False,
    empty_threshold: float = -700,
    use_min_max_normalization: bool = True,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """
    Extract 2.5D features from a 3D volume by slicing it into 2D planes.

    The volume is sliced along D (XY planes), H (YZ planes) and W (ZX planes).
    For each orientation:
      - the selected slices are stacked along the batch axis;
      - they are intensity-normalised;
      - empty slices are optionally filtered out with `drop_empty_slice`;
      - the result is passed through `feature_network` and spatially averaged
        into one feature vector per slice.

    Args:
        image (torch.Tensor): input tensor of shape (B, C, H, W, D). A
            single-channel input is replicated to 3 channels. Channels are then
            reordered with index [2, 1, 0].
        mask (torch.Tensor or None): used only when `drop_empty` is True, to
            decide which slices are empty. It is sliced with the same indices
            as `image`, so it is assumed to be 5D with the same spatial layout.
            May be None when `drop_empty` is False.
        feature_network (torch.nn.Module): module applied to each stacked batch
            of 2D slices.
        center_slices (bool): if True, keep only the central
            `center_slices_ratio` fraction of slices along each sliced axis.
        center_slices_ratio (float): fraction of slices kept when `center_slices` is True.
        sample_every_k (int): stride between kept slices along each sliced axis.
            Applied whether or not `center_slices` is set.
        xy_only (bool): if True, compute only the XY-plane features.
        drop_empty (bool): if True, drop slices whose mask slice maximum is
            below `empty_threshold`.
        empty_threshold (float): threshold forwarded to `drop_empty_slice`.
        use_min_max_normalization (bool): if True, normalise with
            `radimagenet_intensity_normalisation`; otherwise only subtract the
            per-channel means.

    Returns:
        tuple: (XY_features, YZ_features, ZX_features). The last two are None
        when `xy_only` is True.

    Raises:
        ValueError: if `image` is not 5D, or if `drop_empty` is True but `mask` is None.
    """
    logger.info(f"center_slices: {center_slices}, ratio: {center_slices_ratio}")

    if image.dim() != 5:
        raise ValueError(f"image must be 5D (B, C, H, W, D), got shape {tuple(image.shape)}")
    if drop_empty and mask is None:
        raise ValueError("mask is required when drop_empty is True")

    # If there's only 1 channel, replicate to 3 channels
    if image.shape[1] == 1:
        image = image.repeat(1, 3, 1, 1, 1)

    # Reverse the channel order (R, G, B) -> (B, G, R)
    image = image[:, [2, 1, 0], ...]

    plane_options = dict(
        center_slices=center_slices,
        center_slices_ratio=center_slices_ratio,
        sample_every_k=sample_every_k,
        drop_empty=drop_empty,
        empty_threshold=empty_threshold,
        use_min_max_normalization=use_min_max_normalization,
    )

    with torch.no_grad():
        # XY planes: slice along D (axis 4).
        feature_image_xy = _plane_features_2p5d(image, mask, 4, feature_network, **plane_options)
        if xy_only:
            return feature_image_xy, None, None
        # YZ planes: slice along H (axis 2).
        feature_image_yz = _plane_features_2p5d(image, mask, 2, feature_network, **plane_options)
        # ZX planes: slice along W (axis 3).
        feature_image_zx = _plane_features_2p5d(image, mask, 3, feature_network, **plane_options)

    return feature_image_xy, feature_image_yz, feature_image_zx


def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float = 0.0) -> torch.Tensor:
    """
    Pad a tensor at the end of its first dimension so that dim 0 has length `max_size`.

    Args:
        tensor (torch.Tensor): tensor to pad (at least 1D).
        max_size (int): desired length along dim 0.
        padding_value (float): constant written into the new positions
            (0.0 by default).

    Returns:
        torch.Tensor: padded tensor; all other dimensions are untouched.

    Note:
        If `tensor` is already longer than `max_size`, the pad amount is
        negative and is handed to `F.pad` as is. It is not rejected here.
    """
    # F.pad reads (left, right) pairs starting from the LAST dimension, so each
    # trailing dimension gets (0, 0) and only dim 0 receives right-side padding.
    trailing_dims = tensor.dim() - 1
    extra_rows = max_size - tensor.shape[0]
    padding = [0, 0] * trailing_dims + [0, extra_rows]
    return F.pad(tensor, padding, mode="constant", value=padding_value)
