"""
Contains functions for loading and processing segmentations
"""

import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import convolve
import nibabel as nib
import pydicom
import pydicom_seg

import logging
log = logging.getLogger(__name__)


def load_nib_segmentation(path, target_shape):
    """Load a NIfTI segmentation as a binary float mask.

    The volume is read with nibabel, converted to a float32 tensor and
    clamped to the range [0, 1]. If its shape differs from ``target_shape``
    it is resampled with nearest-neighbour interpolation.

    Args:
        path: Location of the segmentation file, passed to ``nib.load``.
        target_shape: Desired spatial size of the returned mask. The
            resampling path feeds a 5-D tensor to ``interpolate``, so this
            must hold three sizes.

    Returns:
        Float tensor holding the mask ([H, W, D] according to the original
        author's note; the axis order is whatever the file stores).
    """
    volume = nib.load(path).get_fdata()
    mask = torch.from_numpy(volume).float().clamp(0, 1)
    if mask.shape != target_shape:
        # Add batch and channel axes for interpolate, then strip them again.
        mask = F.interpolate(
            mask[None, None],
            size=target_shape,
            mode='nearest',
        )[0, 0]
    return mask

def load_nib_multilabel_segmentation(path, target_shape):
    """Load a NIfTI segmentation while keeping its label values.

    The volume is read with nibabel and converted to a float32 tensor
    without clamping. If its shape differs from ``target_shape`` it is
    resampled with nearest-neighbour interpolation.

    Args:
        path: Location of the segmentation file, passed to ``nib.load``.
        target_shape: Desired spatial size of the returned mask. The
            resampling path feeds a 5-D tensor to ``interpolate``, so this
            must hold three sizes.

    Returns:
        Float tensor holding the label map ([H, W, D] according to the
        original author's note; the axis order is whatever the file stores).
    """
    volume = nib.load(path).get_fdata()
    labels = torch.from_numpy(volume).float()
    if labels.shape != target_shape:
        # Add batch and channel axes for interpolate, then strip them again.
        labels = F.interpolate(
            labels[None, None],
            size=target_shape,
            mode='nearest',
        )[0, 0]
    return labels


def load_dicom_segmentation(path, target_shape, clip=True):
    """Load a DICOM-SEG file as a float mask resized to ``target_shape``.

    The file is decoded with ``pydicom_seg.MultiClassReader``, its axes are
    reversed, and the volume is always resampled with nearest-neighbour
    interpolation.

    Args:
        path: Location of the DICOM-SEG file, passed to ``pydicom.dcmread``.
        target_shape: Spatial size (three values) of the returned mask.
        clip: When true, clamp the result to [0, 1]; otherwise the decoded
            values are returned unchanged.

    Returns:
        Float tensor of size ``target_shape``.
    """
    dataset = pydicom.dcmread(path)
    decoded = pydicom_seg.MultiClassReader().read(dataset)

    # The original author notes the decoded array is [D, W, H]; reversing the
    # axes yields [H, W, D].
    volume = torch.from_numpy(decoded.data).float().permute(2, 1, 0)

    # Add batch and channel axes for interpolate, then strip them again.
    resized = F.interpolate(
        volume[None, None],
        size=target_shape,
        mode='nearest',
    )[0, 0]

    return torch.clip(resized, 0, 1) if clip else resized


def make_circular_kernel(radius, dims: int = 3):
    """Build a binary ball-shaped kernel in ``dims`` dimensions.

    Args:
        radius: Integer radius of the ball, in grid cells.
        dims: Number of spatial dimensions of the kernel.

    Returns:
        Integer array with ``dims`` axes of length ``2 * radius + 1`` each.
        Cells whose Euclidean distance from the centre cell is at most
        ``radius`` are 1, all others are 0.
    """
    shape = [2 * radius + 1] * dims
    # np.indices puts the coordinate axis first: (dims, *shape).
    offsets = np.indices(shape) - radius
    # Reduce over the coordinate axis itself so the result is right for any
    # ``dims`` (a fixed axis number only works for the 3-D case).
    squared_distance = (offsets ** 2).sum(axis=0)
    return np.where(squared_distance > radius**2, 0, 1)


def enlarge_segmentation_torch(
    segmentation: torch.Tensor, radius: int, eps=1e-8
) -> torch.Tensor:
    """Dilate a 3-D segmentation with a ball-shaped kernel using torch.

    The mask is convolved with the kernel from ``make_circular_kernel`` and
    every voxel whose response exceeds ``eps`` is set in the output.

    Args:
        segmentation: Mask to enlarge. The convolution is ``F.conv3d``, so
            the tensor must have exactly three dimensions.
        radius: Radius of the ball-shaped kernel, in voxels.
        eps: Threshold the convolution response must exceed for a voxel to
            count as part of the enlarged mask.

    Returns:
        ``uint8`` tensor with the same shape and device as ``segmentation``,
        holding 1 inside the enlarged region and 0 elsewhere.
    """
    dims = segmentation.dim()
    # Create the kernel on the input's device; conv3d rejects operands that
    # live on different devices, so a CPU-only kernel breaks GPU inputs.
    kernel = torch.tensor(
        make_circular_kernel(radius, dims=dims),
        dtype=torch.float,
        device=segmentation.device,
    )
    # conv3d expects (out_channels, in_channels, *kernel_size) weights and a
    # (batch, channels, *spatial) input.
    kernel = kernel.reshape(1, 1, *kernel.shape)
    volume = segmentation.to(torch.float)
    volume = volume.reshape(1, 1, *volume.shape)
    # padding=radius keeps the output the same spatial size as the input.
    out = F.conv3d(volume, kernel, padding=radius)
    return (out.reshape(*out.shape[2:]) > eps).to(torch.uint8)


def enlarge_segmentation_numpy(
    segmentation: np.ndarray, radius: int, eps=1e-8
) -> np.ndarray:
    """Dilate a segmentation with a ball-shaped kernel using scipy.

    The mask is convolved (zero-padded at the borders) with the default 3-D
    kernel from ``make_circular_kernel`` and every voxel whose response
    exceeds ``eps`` is set in the output.

    Args:
        segmentation: Mask to enlarge.
        radius: Radius of the ball-shaped kernel, in voxels.
        eps: Threshold the convolution response must exceed for a voxel to
            count as part of the enlarged mask.

    Returns:
        ``uint8`` array holding 1 inside the enlarged region and 0 elsewhere.
    """
    ball = make_circular_kernel(radius)
    as_float = segmentation.astype(np.float32)
    response = convolve(as_float, ball, mode="constant", cval=0.0)
    enlarged = response > eps
    return enlarged.astype(np.uint8)


def find_segment_centers(
    segmentation: torch.Tensor, min_size: int = 1
) -> torch.Tensor:
    """Compute the centre of every labelled segment in a label map.

    Labels ``1 .. int(segmentation.max())`` are visited in ascending order.
    A label covering fewer than ``min_size`` elements is skipped.

    Args:
        segmentation: Label map in which 0 is background and positive
            integer values identify segments. Float tensors are accepted;
            the maximum is truncated to an int to obtain the label count.
        min_size: Minimum number of elements a segment needs to be reported.

    Returns:
        Int tensor with one row of centre coordinates per reported segment,
        or ``None`` when no segment qualifies (including an empty input).
    """
    if segmentation.numel() == 0:
        # max() is undefined on an empty tensor; there is nothing to report.
        return None
    # .item() yields a Python float for float tensors, which range() rejects,
    # so convert explicitly (same handling as map_segments).
    n_labels = int(segmentation.max().item())
    centers = []
    for label in range(1, n_labels + 1):
        mask = segmentation == label
        if mask.sum() < min_size:
            continue
        center = find_region_center(mask)
        centers.append(torch.tensor(center, dtype=torch.int))
    if not centers:
        return None
    return torch.stack(centers, dim=0)


def map_segments(
    segmentation: torch.Tensor,
    index_map: dict[int, int],
):
    """Relabel a segmentation according to ``index_map``.

    Labels ``1 .. int(segmentation.max())`` that appear as keys of
    ``index_map`` are replaced by the mapped value; every other element of
    the output is 0.

    Args:
        segmentation: Label map to relabel.
        index_map: Mapping from source label to target label.

    Returns:
        A pair ``(relabelled, n)`` where ``relabelled`` is a ``uint8`` tensor
        shaped like ``segmentation`` and ``n`` is ``len(index_map)``.
    """
    highest_label = int(segmentation.max().detach().cpu().item())
    remapped = torch.zeros_like(segmentation, dtype=torch.uint8)
    for label in range(1, highest_label + 1):
        if label not in index_map:
            # Unmapped labels stay background in the output.
            continue
        remapped[segmentation == label] = index_map[label]
    return remapped, len(index_map)



def find_region_center(binary_mask: torch.Tensor) -> tuple[int, ...]:
    """
    Locate the centre of the non-zero region of an n-D mask.

    The centre is the mean coordinate of all non-zero elements, rounded with
    ``torch.round`` and returned as a tuple with one int per mask dimension.
    """
    coordinates = binary_mask.nonzero()
    centroid = coordinates.float().mean(dim=0)
    rounded = centroid.round().int()
    return tuple(rounded.tolist())


def _find_starting_point(center: int, size: int, max_size: int) -> int:
    """Return the start index of a window of ``size`` placed around ``center``.

    The window ideally starts at ``center - size // 2``. The start is clamped
    so that the window ``[start, start + size)`` stays inside
    ``[0, max_size)``. When the window is longer than the axis, the start
    is 0.

    Args:
        center: Desired centre index along the axis.
        size: Length of the window.
        max_size: Length of the axis.

    Returns:
        Start index of the window along the axis.
    """
    # Clamp the upper bound at 0 so an over-long window never yields a
    # negative start, which would slice from the end of the axis.
    last_valid_start = max(max_size - size, 0)
    # Clamping the start (rather than comparing the centre against
    # max_size - size // 2) also keeps odd-sized windows inside the axis.
    return int(min(max(center - size // 2, 0), last_valid_start))


@torch.no_grad()
def minimal_bounding_box(mask: torch.Tensor) -> list[slice]:
    """
    Find the minimal bounding box of the non-zero region of a mask.

    Args:
        mask: Tensor with at least one dimension; every non-zero element
            counts as part of the region.

    Returns:
        A list with one slice per mask dimension, so that
        ``mask[tuple(slices)]`` is the tightest crop containing every
        non-zero element.

    Raises:
        ValueError: If a dimension of size 0 is reached.
        RuntimeError: If the mask contains no non-zero element.
    """
    assert mask.dim() > 0, "Mask must have at least one dimension"
    # Work on occupancy rather than raw values so positive and negative
    # entries cannot cancel each other out when projected onto an axis.
    occupied = mask.detach() != 0
    slices = []
    for axis in range(occupied.dim()):
        if occupied.shape[axis] == 0:
            raise ValueError(f"Mask dimension {axis} is empty, cannot compute bounding box.")
        other_axes = [dim for dim in range(occupied.dim()) if dim != axis]
        # For a 1-D mask there is nothing to project away; an empty dim list
        # must not be passed to sum(), which would reduce over every axis.
        profile = occupied.sum(dim=other_axes) if other_axes else occupied
        hits = torch.nonzero(profile, as_tuple=False)
        if hits.numel() == 0:
            raise RuntimeError("Mask has no non-zero elements, cannot compute bounding box.")
        first, last = int(hits.min().item()), int(hits.max().item())
        slices.append(slice(first, last + 1))
    return slices


def roi_center_crop_slices(binary_mask: torch.Tensor, size = 64) -> list[slice]:
    """
    Build a fixed-size crop around the centre of a binary mask.

    The region centre comes from ``find_region_center``; along every
    dimension a window of ``size`` elements is placed around it using
    ``_find_starting_point``.

    Returns one slice per mask dimension, each spanning ``size`` indices.
    """
    center = find_region_center(binary_mask)
    crop = []
    for coord, extent in zip(center, binary_mask.shape):
        start = _find_starting_point(coord, size, extent)
        crop.append(slice(start, start + size))
    return crop
