import torch
import torch.nn.functional as F
import einops


class Resize3D:
    """Trilinearly resample the three spatial axes and leave the time axis alone.

    The output size comes either from a single ``scale_factor`` applied to
    every spatial axis or from an explicit ``target_size``. When both are
    supplied, ``scale_factor`` is the one that is used.
    """

    def __init__(self, scale_factor=None, target_size=None, align_corners=False):
        """Store the resize configuration.

        Args:
            scale_factor: Multiplier applied to each spatial extent; the
                product is truncated with ``int()``.
            target_size: Output spatial size handed to ``F.interpolate``
                when ``scale_factor`` is ``None``.
            align_corners: Forwarded unchanged to ``F.interpolate``.

        Raises:
            AssertionError: If neither ``scale_factor`` nor ``target_size``
                is given.
        """
        assert not (scale_factor is None and target_size is None), (
            "You must provide either scale_factor or target_size."
        )
        self.align_corners = align_corners
        self.target_size = target_size
        self.scale_factor = scale_factor

    def __call__(self, tensor):
        """Return ``tensor`` with its spatial axes resized.

        Args:
            tensor: Input assumed to be laid out as [N, H, W, D, T].

        Returns:
            The resized tensor, again laid out as [N, H, W, D, T].
        """
        height, width, depth = tensor.shape[1:4]

        factor = self.scale_factor
        if factor is None:
            out_size = self.target_size
        else:
            out_size = tuple(
                int(extent * factor) for extent in (height, width, depth)
            )

        # Move time into the channel slot so that interpolation only acts on
        # the three spatial axes: [N, H, W, D, T] -> [N, T, H, W, D].
        time_as_channels = tensor.permute(0, 4, 1, 2, 3)
        interpolated = F.interpolate(
            time_as_channels,
            size=out_size,
            mode="trilinear",
            align_corners=self.align_corners,
        )
        # Restore the original axis order: [N, T, H, W, D] -> [N, H, W, D, T].
        return interpolated.permute(0, 2, 3, 4, 1)


class NormalizeByRegion:
    """Region-wise normalization.

    Statistics are reduced over every dimension of the dataset tensor except
    the last two (over dimension 0 alone when the tensor has three or fewer
    dimensions). ``mean`` and ``std`` therefore keep the trailing dimensions
    and broadcast against samples that share them.

    NOTE(review): which trailing axes represent "region" is not visible from
    this class -- confirm the expected tensor layout against the caller.
    """

    def __init__(self, full_dataset_tensor):
        """Compute the normalization statistics from the full dataset.

        Args:
            full_dataset_tensor: Tensor holding the whole dataset; dimension
                0 is always reduced.
        """
        # Dimension 0 plus the leading ``ndim - 2`` dimensions, listed once
        # each. Building this as ``[0] + list(range(ndim - 2))`` repeats
        # dimension 0 whenever ndim >= 3, and torch reductions reject a dim
        # list containing duplicates.
        n_reduced = max(full_dataset_tensor.dim() - 2, 1)
        dims_to_reduce = list(range(n_reduced))
        self.mean = full_dataset_tensor.mean(dim=dims_to_reduce)
        # Floor the std so constant entries do not cause a division by zero.
        self.std = torch.clamp(
            full_dataset_tensor.std(dim=dims_to_reduce), min=1e-6
        )

    def __call__(self, sample):
        """Return ``sample`` shifted by the stored mean and scaled by the stored std."""
        return (sample - self.mean) / self.std


class NormalizeGlobal:
    """Normalize with one mean and one std taken over the whole dataset tensor."""

    def __init__(self, full_dataset_tensor):
        """Compute the scalar mean and std of ``full_dataset_tensor``."""
        overall_std = full_dataset_tensor.std()
        # Floor the std so a constant dataset does not cause division by zero.
        self.std = overall_std.clamp(min=1e-6)
        self.mean = full_dataset_tensor.mean()

    def __call__(self, sample):
        """Return ``sample`` centred on the stored mean and divided by the stored std."""
        centered = sample - self.mean
        return centered / self.std


class NormalizeByTime:
    """Normalize with statistics computed separately for each index of the last dimension."""

    def __init__(self, full_dataset_tensor):
        """Reduce ``full_dataset_tensor`` over every dimension except the last."""
        leading_dims = list(range(full_dataset_tensor.dim() - 1))
        # Floor the std so constant entries do not cause a division by zero.
        self.std = torch.clamp(
            full_dataset_tensor.std(dim=leading_dims), min=1e-6
        )
        self.mean = full_dataset_tensor.mean(dim=leading_dims)

    def __call__(self, sample):
        """Return ``sample`` centred on the stored mean and divided by the stored std."""
        centered = sample - self.mean
        return centered / self.std
