"""
Our code is based on the following code.
https://docs.monai.io/en/stable/_modules/monai/networks/nets/swin_unetr.html#SwinUNETR
https://github.com/Transconnectome/SwiFT/blob/main/project/module/models/swin4d_transformer_ver7.py
"""
import time
import itertools
from typing import Optional, Sequence, Tuple, Type
from scipy import ndimage
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
import random

from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
from mamba_ssm import Mamba, Mamba2
from .patchembedding import PatchEmbed
# from .redundant_dropout import redundant_dropout

import nibabel as nib  # to load .nii atlas files
import matplotlib.pyplot as plt
import os
from scipy import ndimage
import random
rearrange, _ = optional_import("einops", name="rearrange")

# Public names re-exported by `from <this module> import *`. Other definitions in
# this module (e.g. PatchExpanding, BasicLayerUp, BasicLayer_FullAttention,
# PositionalEmbedding, compute_mask, window_partition_with_b) remain importable by
# name but are not listed here.
__all__ = [
    "window_partition",
    "window_reverse",
    "WindowAttention4D",
    "SwinTransformerBlock4D",
    "PatchMergingV2",
    "MERGING_MODE",
    "BasicLayer",
    "NeuroSTORM",
    "NeuroSTORMMAE",
]


def window_partition(x, window_size):
    """Split a channel-last 4D volume into non-overlapping windows.

    Follows the window partition of "Liu et al., Swin Transformer: Hierarchical
    Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>"
    (https://github.com/microsoft/Swin-Transformer), extended to four axes.

    Args:
        x: tensor of shape (B, D, H, W, T, C). Each of D, H, W, T must be a
            multiple of the matching window extent, otherwise ``view`` fails.
        window_size: per-axis window extent (wd, wh, ww, wt).

    Returns:
        Tensor of shape (B * num_windows, wd * wh * ww * wt, C). Windows are
        ordered batch-major, then by their (D, H, W, T) window index with T
        varying fastest; tokens inside a window use the same axis order.
    """
    batch, depth, height, width, time, channels = x.size()
    wd, wh, ww, wt = window_size[0], window_size[1], window_size[2], window_size[3]

    # Split every axis into (number of windows, offset inside the window).
    blocked = x.view(
        batch,
        depth // wd,
        wd,
        height // wh,
        wh,
        width // ww,
        ww,
        time // wt,
        wt,
        channels,
    )
    # Bring the four window-index axes forward, then the four in-window offsets.
    blocked = blocked.permute(0, 1, 3, 5, 7, 2, 4, 6, 8, 9).contiguous()
    tokens_per_window = wd * wh * ww * wt
    return blocked.view(-1, tokens_per_window, channels)


def window_partition_with_b(x, window_size):
    """Split a channel-last 4D volume into windows, keeping the batch axis separate.

    Same window layout as ``window_partition``, but the result is not folded
    into the batch dimension.

    Args:
        x: tensor of shape (B, D, H, W, T, C). Each of D, H, W, T must be a
            multiple of the matching window extent, otherwise ``view`` fails.
        window_size: per-axis window extent (wd, wh, ww, wt).

    Returns:
        Tensor of shape (B, num_windows, wd * wh * ww * wt, C), windows ordered
        by their (D, H, W, T) window index with T varying fastest.
    """
    batch, depth, height, width, time, channels = x.size()
    wd, wh, ww, wt = window_size[0], window_size[1], window_size[2], window_size[3]

    # Split every axis into (number of windows, offset inside the window).
    blocked = x.view(
        batch,
        depth // wd,
        wd,
        height // wh,
        wh,
        width // ww,
        ww,
        time // wt,
        wt,
        channels,
    )
    # Window-index axes first, in-window offsets last.
    blocked = blocked.permute(0, 1, 3, 5, 7, 2, 4, 6, 8, 9).contiguous()
    tokens_per_window = wd * wh * ww * wt
    return blocked.view(batch, -1, tokens_per_window, channels)


def window_reverse(windows, window_size, dims):
    """Stitch windows back into a channel-last 4D volume.

    Inverse of the window partition from "Liu et al., Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>" (https://github.com/microsoft/Swin-Transformer).

    Args:
        windows: tensor holding B * num_windows windows of wd * wh * ww * wt
            tokens each, e.g. (B * num_windows, wd, wh, ww, wt, C).
        window_size: per-axis window extent (wd, wh, ww, wt).
        dims: (B, D, H, W, T) of the volume to rebuild.

    Returns:
        Tensor of shape (B, D, H, W, T, C).
    """
    batch, depth, height, width, time = dims
    wd, wh, ww, wt = window_size[0], window_size[1], window_size[2], window_size[3]

    # Number of windows along each axis (floor division via torch.div).
    n_d = torch.div(depth, wd, rounding_mode="floor")
    n_h = torch.div(height, wh, rounding_mode="floor")
    n_w = torch.div(width, ww, rounding_mode="floor")
    n_t = torch.div(time, wt, rounding_mode="floor")

    grid = windows.view(batch, n_d, n_h, n_w, n_t, wd, wh, ww, wt, -1)
    # Interleave each window-count axis with its in-window offset axis.
    grid = grid.permute(0, 1, 5, 2, 6, 3, 7, 4, 8, 9).contiguous()
    return grid.view(batch, depth, height, width, time, -1)


def get_window_size(x_size, window_size, shift_size=None):
    """Clip window (and shift) sizes to the input extent.

    Based on "Liu et al., Swin Transformer: Hierarchical Vision Transformer using
    Shifted Windows <https://arxiv.org/abs/2103.14030>"
    (https://github.com/microsoft/Swin-Transformer).

    For every axis whose input extent is not larger than the window, the window
    shrinks to that extent and (if given) the shift on that axis becomes 0.

    Args:
        x_size: input extent per axis.
        window_size: requested window size per axis.
        shift_size: requested shift per axis, or None.

    Returns:
        ``tuple(window)`` when ``shift_size`` is None, otherwise
        ``(tuple(window), tuple(shift))``.
    """
    clipped_window = list(window_size)
    clipped_shift = None if shift_size is None else list(shift_size)

    for axis, extent in enumerate(x_size):
        if extent > window_size[axis]:
            continue
        clipped_window[axis] = extent
        if clipped_shift is not None:
            clipped_shift[axis] = 0

    if clipped_shift is None:
        return tuple(clipped_window)
    return tuple(clipped_window), tuple(clipped_shift)


class WindowAttention4D(nn.Module):
    """
    Window based multi-head self attention module based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Unlike the reference Swin attention, no learned relative position bias is
    added here; the only additive term is the optional ``mask`` passed to forward.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads; must divide ``dim``.
            window_size: local window size (stored, not used in the computation).
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.

        Raises:
            ValueError: if ``dim`` is not divisible by ``num_heads``.
        """

        super().__init__()
        if dim % num_heads != 0:
            # forward() reshapes C into (num_heads, C // num_heads); fail early.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads}).")
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C).
            mask: (0/-inf) additive mask with shape of (num_windows, N, N) or None.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b_, n, c = x.shape
        qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if mask is not None:
            # The mask is shared across the batch: broadcast it over B and the heads.
            nw = mask.shape[0]
            attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.to(attn.dtype).unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
        attn = self.attn_drop(self.softmax(attn))
        x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock4D(nn.Module):
    """
    Swin Transformer block based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    In this variant the windowed self-attention is replaced by a Mamba token
    mixer. The (optionally cyclically shifted) input is split into 4D windows.
    Each window is flattened into a token sequence, processed by ``Mamba``, and
    the windows are stitched back together. The pre-norm residual + MLP
    structure of the Swin block is kept.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        shift_size: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: str = "GELU",
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels (also the Mamba ``d_model``).
            num_heads: stored as ``self.num_heads``; not used by the Mamba mixer.
            window_size: local window size.
            shift_size: window shift size.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: unused; kept for signature compatibility with attention blocks.
            drop: dropout rate of the MLP.
            attn_drop: unused; kept for signature compatibility with attention blocks.
            drop_path: stochastic depth rate.
            act_layer: activation layer name passed to the MLP.
            norm_layer: normalization layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_checkpoint = use_checkpoint

        self.norm1 = norm_layer(dim)

        self.mamba = Mamba(
            d_model=dim,  # Model dimension d_model
            d_state=16,  # SSM state expansion factor
            d_conv=4,  # Local convolution width
            expand=2,  # Block expansion factor
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")

    def forward_part1(self, x, mask_matrix):
        """Token-mixing branch: norm -> pad -> (roll) -> per-window Mamba -> (unroll) -> crop.

        Args:
            x: channel-last tensor of shape (B, D, H, W, T, C).
            mask_matrix: accepted for interface compatibility with the attention-based
                Swin block, but not used by the Mamba mixer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, d, h, w, t, c = x.shape
        window_size, shift_size = get_window_size((d, h, w, t), self.window_size, self.shift_size)
        x = self.norm1(x)

        # Right-pad D, H, W, T up to multiples of the (possibly clipped) window size.
        pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
        pad_h1 = (window_size[1] - h % window_size[1]) % window_size[1]
        pad_w1 = (window_size[2] - w % window_size[2]) % window_size[2]
        pad_t1 = (window_size[3] - t % window_size[3]) % window_size[3]
        # F.pad lists (before, after) pairs from the last axis backwards: C, T, W, H, D.
        x = F.pad(x, (0, 0, 0, pad_t1, 0, pad_w1, 0, pad_h1, 0, pad_d1))
        _, dp, hp, wp, tp, _ = x.shape
        dims = [b, dp, hp, wp, tp]

        # NOTE(review): attention-based Swin masks token pairs that the cyclic roll
        # brings together from opposite edges; no mask reaches Mamba here, so a
        # shifted window may mix such wrapped-around tokens. Confirm this is intended.
        is_shifted = any(i > 0 for i in shift_size)
        if is_shifted:
            x = torch.roll(
                x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2], -shift_size[3]), dims=(1, 2, 3, 4)
            )

        x_windows = window_partition(x, window_size)  # (B*num_windows, N, C)
        mixed_windows = self.mamba(x_windows)
        mixed_windows = mixed_windows.view(-1, *(window_size + (c,)))
        x = window_reverse(mixed_windows, window_size, dims)

        if is_shifted:
            x = torch.roll(
                x, shifts=(shift_size[0], shift_size[1], shift_size[2], shift_size[3]), dims=(1, 2, 3, 4)
            )

        if pad_d1 > 0 or pad_h1 > 0 or pad_w1 > 0 or pad_t1 > 0:
            x = x[:, :d, :h, :w, :t, :].contiguous()
        return x

    def forward_part2(self, x):
        """MLP branch: norm2 -> MLP -> stochastic depth."""
        x = self.drop_path(self.mlp(self.norm2(x)))
        return x

    def forward(self, x, mask_matrix):
        """Apply both residual branches.

        Args:
            x: channel-last tensor of shape (B, D, H, W, T, C).
            mask_matrix: forwarded to ``forward_part1`` (which ignores it).

        Returns:
            Tensor with the same shape as ``x``.
        """
        shortcut = x
        if self.use_checkpoint:
            x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
        else:
            x = self.forward_part1(x, mask_matrix)
        x = shortcut + self.drop_path(x)
        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x

class PatchMergingV2(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Merges every 2x2x2 neighbourhood along (D, H, W) into one token. It
    concatenates the 8 feature vectors, normalises them, and projects them to
    ``c_multiplier * dim`` channels. The temporal axis T is not merged.
    """

    def __init__(
        self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3, c_multiplier: int = 2
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims; unused, kept for API compatibility.
            c_multiplier: output channel count is ``c_multiplier * dim``.
        """

        super().__init__()
        self.dim = dim

        # 8 = 2 * 2 * 2 neighbours merged along D, H and W; no reduction along T.
        self.reduction = nn.Linear(8 * dim, c_multiplier * dim, bias=False)
        self.norm = norm_layer(8 * dim)

    def forward(self, x):
        """
        Args:
            x: channel-last tensor of shape (B, D, H, W, T, C).

        Returns:
            Tensor of shape (B, ceil(D/2), ceil(H/2), ceil(W/2), T, c_multiplier * C).
            An odd D, H or W is zero-padded by one at the end before merging, so
            all eight strided slices have the same size.
        """
        _, d, h, w, _, _ = x.shape
        if d % 2 or h % 2 or w % 2:
            # F.pad lists (before, after) pairs from the last axis backwards: C, T, W, H, D.
            x = F.pad(x, (0, 0, 0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
        x = torch.cat([x[:, i::2, j::2, k::2, :, :] for i, j, k in itertools.product(range(2), range(2), range(2))], -1)

        x = self.norm(x)
        x = self.reduction(x)

        return x


class PatchExpanding(nn.Module):
    """
    Patch expanding layer: the spatial counterpart of PatchMergingV2.

    Doubles D, H and W (T is left unchanged) and divides the channel count by
    ``c_multiplier``. Expects channel-last input of shape (B, D, H, W, T, C).
    """

    def __init__(
        self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3, c_multiplier: int = 2
    ) -> None:
        """
        Args:
            dim: number of input feature channels.
            norm_layer: normalization layer applied to the expanded output.
            spatial_dims: unused; kept for API compatibility with PatchMergingV2.
            c_multiplier: channel reduction factor; output has ``dim // c_multiplier`` channels.
        """
        super().__init__()
        self.dim = dim

        out_dim = dim // c_multiplier
        # 8 = 2 * 2 * 2 sub-voxels per input voxel, each receiving `out_dim` channels,
        # so the rearranged output matches `self.norm` for any c_multiplier.
        self.expand = nn.Linear(dim, 8 * out_dim, bias=False)
        self.norm = norm_layer(out_dim)

    def forward(self, x):
        """
        Args:
            x: channel-last tensor of shape (B, D, H, W, T, C).

        Returns:
            Tensor of shape (B, 2D, 2H, 2W, T, C // c_multiplier).
        """
        x = self.expand(x)
        x = rearrange(x, 'B D H W T (P1 P2 P3 C) -> B (D P1) (H P2) (W P3) T C', P1=2, P2=2, P3=2)
        x = self.norm(x)

        return x


# Lookup table used by NeuroSTORM (via monai's look_up_option) to turn a
# `downsample` string into a patch-merging module class.
MERGING_MODE = {"mergingv2": PatchMergingV2}


def compute_mask(dims, window_size, shift_size, device):
    """Build the shifted-window region mask, following "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>" (https://github.com/microsoft/Swin-Transformer).

    Each axis of the padded volume is cut into three regions:
    - everything before the last window;
    - the last window minus the shift;
    - the trailing ``shift`` strip.
    Every combination of regions gets its own integer label. Inside each window,
    token pairs whose labels differ get -100 and all other pairs get 0.

    Args:
        dims: padded (D, H, W, T) sizes.
        window_size: per-axis window extent (wd, wh, ww, wt).
        shift_size: per-axis shift (sd, sh, sw, st).
        device: device on which the mask is allocated.

    Returns:
        Tensor of shape (num_windows, N, N) with N = wd * wh * ww * wt.
    """
    depth, height, width, time = dims
    label_volume = torch.zeros((1, depth, height, width, time, 1), device=device)

    regions_per_axis = [
        (
            slice(-window_size[axis]),
            slice(-window_size[axis], -shift_size[axis]),
            slice(-shift_size[axis], None),
        )
        for axis in range(4)
    ]
    # itertools.product visits the regions in the same order as four nested loops.
    for label, (sd, sh, sw, st) in enumerate(itertools.product(*regions_per_axis)):
        label_volume[:, sd, sh, sw, st, :] = label

    window_labels = window_partition(label_volume, window_size).squeeze(-1)
    diff = window_labels.unsqueeze(1) - window_labels.unsqueeze(2)
    return diff.masked_fill(diff != 0, float(-100.0)).masked_fill(diff == 0, float(0.0))


class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    A stack of ``depth`` SwinTransformerBlock4D blocks. Blocks alternate between
    no shift (even index) and a half-window shift (odd index). An optional
    ``downsample`` module runs after the blocks.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        c_multiplier: int = 2,
        downsample: Optional[nn.Module] = None,
        use_checkpoint: bool = False
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate; a per-block list, or a single float for all blocks.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            c_multiplier: forwarded to the ``downsample`` constructor.
            downsample: optional downsampling layer at the end of the layer. It can be
                a module class or factory, called as ``downsample(dim=..., norm_layer=...,
                spatial_dims=..., c_multiplier=...)``, or an already-constructed
                ``nn.Module``, which is used as-is.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock4D(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint
                )
                for i in range(depth)
            ]
        )
        # nn.Module instances are callable too, so only instantiate classes/factories;
        # calling a built module here would run its forward with constructor kwargs.
        if callable(downsample) and not isinstance(downsample, nn.Module):
            downsample = downsample(
                dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size), c_multiplier=c_multiplier
            )
        self.downsample = downsample

    def forward(self, x):
        """
        Args:
            x: channel-first tensor of shape (B, C, D, H, W, T).

        Returns:
            Channel-first tensor; its shape differs from ``x`` only if ``downsample`` is set.
        """
        b, c, d, h, w, t = x.size()
        window_size, shift_size = get_window_size((d, h, w, t), self.window_size, self.shift_size)
        x = rearrange(x, "b c d h w t -> b d h w t c")
        # The mask is built for the padded volume (each axis rounded up to a window multiple).
        dp = int(np.ceil(d / window_size[0])) * window_size[0]
        hp = int(np.ceil(h / window_size[1])) * window_size[1]
        wp = int(np.ceil(w / window_size[2])) * window_size[2]
        tp = int(np.ceil(t / window_size[3])) * window_size[3]
        # NOTE(review): the Mamba-based SwinTransformerBlock4D in this file does not use
        # the mask it receives; confirm whether computing it is still needed.
        attn_mask = compute_mask([dp, hp, wp, tp], window_size, shift_size, x.device)
        # attn_mask = redundant_dropout(attn_mask, 0.1)
        for blk in self.blocks:
            x = blk(x, attn_mask)
        x = x.view(b, d, h, w, t, -1)
        if self.downsample is not None:
            x = self.downsample(x)
        x = rearrange(x, "b d h w t c -> b c d h w t")

        return x
    

class BasicLayerUp(nn.Module):
    """
    Decoder-side counterpart of BasicLayer: a stack of ``depth``
    SwinTransformerBlock4D blocks that alternate between no shift (even index)
    and a half-window shift (odd index). An optional ``upsample`` module runs
    after the blocks.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        c_multiplier: int = 2,
        upsample: Optional[nn.Module] = None,
        use_checkpoint: bool = False
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of blocks in this stage.
            num_heads: number of attention heads (forwarded to the blocks).
            window_size: local window size.
            drop_path: stochastic depth rate; a per-block list, or a single float for all blocks.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: forwarded to the blocks.
            drop: dropout rate.
            attn_drop: forwarded to the blocks.
            norm_layer: normalization layer.
            c_multiplier: forwarded to the ``upsample`` constructor.
            upsample: optional upsampling layer at the end of the layer (e.g. PatchExpanding).
                It can be a module class or factory, called as ``upsample(dim=...,
                norm_layer=..., spatial_dims=..., c_multiplier=...)``, or an
                already-constructed ``nn.Module``, which is used as-is.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock4D(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint
                )
                for i in range(depth)
            ]
        )
        # nn.Module instances are callable too, so only instantiate classes/factories.
        if callable(upsample) and not isinstance(upsample, nn.Module):
            upsample = upsample(
                dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size), c_multiplier=c_multiplier
            )
        self.upsample = upsample

    def forward(self, x):
        """
        Args:
            x: channel-first tensor of shape (B, C, D, H, W, T).

        Returns:
            Channel-first tensor; its shape differs from ``x`` only if ``upsample`` is set.
        """
        b, c, d, h, w, t = x.size()
        window_size, shift_size = get_window_size((d, h, w, t), self.window_size, self.shift_size)
        x = rearrange(x, "b c d h w t -> b d h w t c")
        # The mask is built for the padded volume (each axis rounded up to a window multiple).
        dp = int(np.ceil(d / window_size[0])) * window_size[0]
        hp = int(np.ceil(h / window_size[1])) * window_size[1]
        wp = int(np.ceil(w / window_size[2])) * window_size[2]
        tp = int(np.ceil(t / window_size[3])) * window_size[3]
        attn_mask = compute_mask([dp, hp, wp, tp], window_size, shift_size, x.device)
        for blk in self.blocks:
            x = blk(x, attn_mask)
        x = x.view(b, d, h, w, t, -1)
        if self.upsample is not None:
            x = self.upsample(x)
        x = rearrange(x, "b d h w t c -> b c d h w t")

        return x


# Basic layer for full attention: like BasicLayer, except that no block uses a
# shifted window and no region mask is passed to the blocks.
class BasicLayer_FullAttention(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Every block uses ``shift_size = 0``, and the blocks receive ``None`` as the mask.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        c_multiplier: int = 2,
        downsample: Optional[nn.Module] = None,
        use_checkpoint: bool = False
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size (NeuroSTORM passes the whole spatial grid here).
            drop_path: stochastic depth rate; a per-block list, or a single float for all blocks.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            c_multiplier: forwarded to the ``downsample`` constructor.
            downsample: optional downsampling layer at the end of the layer. It can be
                a module class or factory, called as ``downsample(dim=..., norm_layer=...,
                spatial_dims=..., c_multiplier=...)``, or an already-constructed
                ``nn.Module``, which is used as-is.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """

        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.no_shift = tuple(0 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock4D(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=self.no_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint
                )
                for i in range(depth)
            ]
        )
        # nn.Module instances are callable too, so only instantiate classes/factories.
        if callable(downsample) and not isinstance(downsample, nn.Module):
            downsample = downsample(
                dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size), c_multiplier=c_multiplier
            )
        self.downsample = downsample

    def forward(self, x):
        """
        Args:
            x: channel-first tensor of shape (B, C, D, H, W, T).

        Returns:
            Channel-first tensor; its shape differs from ``x`` only if ``downsample`` is set.
        """
        b, c, d, h, w, t = x.size()
        x = rearrange(x, "b c d h w t -> b d h w t c")
        # No shifted windows in this layer, so no region mask is needed.
        for blk in self.blocks:
            x = blk(x, None)
        x = x.view(b, d, h, w, t, -1)
        if self.downsample is not None:
            x = self.downsample(x)
        x = rearrange(x, "b d h w t c -> b c d h w t")

        return x


class PositionalEmbedding(nn.Module):
    """
    Learnable absolute positional embedding for channel-first 4D feature maps.

    A spatial embedding of shape (1, dim, D, H, W, 1) and a temporal embedding
    of shape (1, dim, 1, 1, 1, T) are added to the input by broadcasting.
    """

    def __init__(
        self, dim: int, patch_dim: tuple
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            patch_dim: token-grid size (D, H, W, T). D, H, W fix the spatial
                embedding size; T is the maximum number of time frames supported.
        """

        super().__init__()
        self.dim = dim
        self.patch_dim = patch_dim
        d, h, w, t = patch_dim
        self.pos_embed = nn.Parameter(torch.zeros(1, dim, d, h, w, 1))
        self.time_embed = nn.Parameter(torch.zeros(1, dim, 1, 1, 1, t))

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.time_embed, std=0.02)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, dim, D, H, W, t) with t <= T.

        Returns:
            ``x`` plus the spatial embedding and the first ``t`` temporal embeddings.

        Raises:
            ValueError: if ``x`` is not 6-dimensional.
        """
        if x.dim() != 6:
            raise ValueError(f"expected a 6-D input (B, C, D, H, W, T), got shape {tuple(x.shape)}")
        t = x.shape[-1]

        x = x + self.pos_embed
        # only add time_embed up to the time frame of the input in case the input size changes
        x = x + self.time_embed[..., :t]

        return x

class NeuroSTORM(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Hierarchical 4D (D, H, W, T) encoder: a patch embedding followed by
    ``len(depths)`` stages. Each stage adds a learnable PositionalEmbedding and
    runs a BasicLayer. When ``last_layer_full_MSA`` is set and there is more
    than one stage, the final stage uses BasicLayer_FullAttention instead.
    Every stage except the last ends with the ``downsample`` module.
    """

    def __init__(
        self,
        img_size: Tuple,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        first_window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 4,
        c_multiplier: int = 2,
        last_layer_full_MSA: bool = False,
        downsample="mergingv2",
        num_classes=2,
        **kwargs,
    ) -> None:
        """
        Args:
            img_size: input size (D, H, W, T); expanded with ``ensure_tuple_rep`` to
                ``spatial_dims`` entries.
            in_chans: dimension of input channels.
            embed_dim: number of linear projection output channels.
            window_size: local window size used by every stage after the first.
            first_window_size: local window size of the first stage.
            patch_size: patch size.
            depths: number of layers in each stage; ``len(depths)`` is the number of stages.
            num_heads: number of attention heads per stage.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: stochastic depth rate, increased linearly from 0 over all blocks.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of input dimensions (4 for D, H, W, T).
            c_multiplier: multiplier for the feature length after patch merging.
            last_layer_full_MSA: if True (and there is more than one stage), the last
                stage uses BasicLayer_FullAttention with a window spanning that stage's
                whole spatial grid.
            downsample: a key of ``MERGING_MODE`` (currently only ``"mergingv2"``, the
                default) or a module class/factory following the PatchMergingV2
                constructor signature.
            num_classes: accepted for configuration compatibility; unused here.
            kwargs: additional keyword arguments are accepted and ignored.
        """

        super().__init__()
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.first_window_size = first_window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            flatten=False,
            spatial_dims=spatial_dims,
        )
        grid_size = self.patch_embed.grid_size
        self.grid_size = grid_size
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Token-grid size (D, H, W, T) after patch embedding.
        patch_dim = ((img_size[0]//patch_size[0]), (img_size[1]//patch_size[1]), (img_size[2]//patch_size[2]), (img_size[3]//patch_size[3]))

        #print img, patch size, patch dim
        print("img_size: ", img_size)
        print("patch_size: ", patch_size)
        print("patch_dim: ", patch_dim)
        # One positional embedding per stage: channels grow by c_multiplier and
        # D, H, W halve from stage to stage, while T is unchanged.
        self.pos_embeds = nn.ModuleList()
        pos_embed_dim = embed_dim
        for _ in range(self.num_layers):
            self.pos_embeds.append(PositionalEmbedding(pos_embed_dim, patch_dim))
            pos_embed_dim = pos_embed_dim * c_multiplier
            patch_dim = (patch_dim[0]//2, patch_dim[1]//2, patch_dim[2]//2, patch_dim[3])

        # build layer
        self.layers = nn.ModuleList()
        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
        last = self.num_layers - 1

        # First stage uses its own window size; it only downsamples if more stages follow.
        layer = BasicLayer(
            dim=int(embed_dim),
            depth=depths[0],
            num_heads=num_heads[0],
            window_size=self.first_window_size,
            drop_path=dpr[sum(depths[:0]) : sum(depths[: 0 + 1])],
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            norm_layer=norm_layer,
            c_multiplier=c_multiplier,
            downsample=down_sample_mod if last > 0 else None,
            use_checkpoint=use_checkpoint
        )
        self.layers.append(layer)

        # Middle stages (neither first nor last) always downsample.
        for i_layer in range(1, last):
            layer = BasicLayer(
                dim=int(embed_dim * (c_multiplier**i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                c_multiplier=c_multiplier,
                downsample=down_sample_mod,
                use_checkpoint=use_checkpoint
            )
            self.layers.append(layer)

        # Last stage (no downsampling). With a single stage, the first stage above is
        # already the last one; building another here would add an unused layer.
        if last > 0:
            if not last_layer_full_MSA:
                layer = BasicLayer(
                    dim=int(embed_dim * c_multiplier ** last),
                    depth=depths[last],
                    num_heads=num_heads[last],
                    window_size=self.window_size,
                    drop_path=dpr[sum(depths[:last]) : sum(depths[: last + 1])],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    c_multiplier=c_multiplier,
                    downsample=None,
                    use_checkpoint=use_checkpoint
                )
                self.layers.append(layer)

            else:
                #################Full MSA for last layer#####################

                self.last_window_size = (
                    self.grid_size[0] // int(2 ** last),
                    self.grid_size[1] // int(2 ** last),
                    self.grid_size[2] // int(2 ** last),
                    self.window_size[3],
                )

                layer = BasicLayer_FullAttention(
                    dim=int(embed_dim * c_multiplier ** last),
                    depth=depths[last],
                    num_heads=num_heads[last],
                    # change the window size to the entire grid size
                    window_size=self.last_window_size,
                    drop_path=dpr[sum(depths[:last]) : sum(depths[: last + 1])],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    c_multiplier=c_multiplier,
                    downsample=None,
                    use_checkpoint=use_checkpoint
                )
                self.layers.append(layer)

    def forward(self, x):
        """
        Args:
            x: input tensor of shape (B, in_chans, D, H, W, T); cast to float32.

        Returns:
            Channel-first feature map produced by the last stage.
        """
        x = x.float()
        # e.g. torch.Size([16, 1, 96, 96, 96, 20])
        x = self.patch_embed(x)
        # e.g. torch.Size([16, 36, 16, 16, 16, 20])
        x = self.pos_drop(x)

        for pos_embed, layer in zip(self.pos_embeds, self.layers):
            x = pos_embed(x)
            x = layer(x.contiguous())
            # e.g. torch.Size([16, 72, 8, 8, 8, 20])
            #      torch.Size([16, 144, 4, 4, 4, 20])
            #      torch.Size([16, 288, 2, 2, 2, 20])
            #      torch.Size([16, 288, 2, 2, 2, 20])

        return x


class NeuroSTORMMAE(nn.Module):
    """Masked-autoencoder variant of NeuroSTORM for self-supervised pre-training.

    The encoder mirrors NeuroSTORM:
        * patch embedding,
        * per-stage positional embeddings,
        * windowed-attention stages.
    Before the encoder runs, a subset of patch tokens is replaced by a learned
    ``mask_token``. A decoder built from ``PatchExpanding`` / ``BasicLayerUp``
    maps the latent back to a tensor shaped like the input. The loss is the
    squared reconstruction error averaged over the masked units only.

    Masking strategies (``spatial_mask`` / ``time_mask``):
        * ``random`` / ``random``: mask individual tokens.
        * ``window`` / ``random``: mask whole ``window_size`` windows.
        * ``random`` / ``tube``: mask tubes of one spatial token by
          ``window_size[-1]`` time steps.
        * ``atlas`` / ``tube``: mask every token whose max-pooled atlas label is
          one of ``selected_rois``, at every time step.
    ``atlas``/``random`` and ``window``/``tube`` raise ``NotImplementedError``.
    Any other combination raises ``ValueError`` when masking.
    """

    # Atlas label groups selectable via the ``region_of_interest`` option
    # (comma-separated names, case-insensitive).
    _ATLAS_REGION_LABELS = {
        # Frontal regions (including prefrontal cortex)
        "frontal": (
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
            13, 14, 15, 16, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
        ),
        # Occipital regions (visual cortex)
        "occipital": (47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60),
        # Parietal regions
        "parietal": (61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74),
        # Limbic regions (cingulate, amygdala, hippocampus, parahippocampal)
        "limbic": (35, 36, 37, 38, 39, 40, 41, 42, 45, 46, 151, 152, 153, 154, 155, 156),
        # Temporal regions
        "temporal": (17, 18, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94),
        # Subcortical regions (basal ganglia, thalamus and subcortical nuclei)
        "subcortical": (
            75, 76, 77, 78, 79, 80, 81, 82,
            157, 158, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
            131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
            141, 142, 143, 144, 145, 146, 147, 148, 149, 150,
        ),
        # Cerebellar regions
        "cerebellum": (
            95, 96, 97, 98, 99, 100, 101, 102,
            103, 104, 105, 106, 107, 108, 109, 110,
            111, 112, 113, 114, 115, 116, 117, 118,
            119, 120,
        ),
        # Brainstem regions (midbrain and related nuclei)
        "brainstem": (159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170),
    }

    def __init__(
        self,
        img_size: Tuple,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        first_window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: Type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 4,
        c_multiplier: int = 2,
        last_layer_full_MSA: bool = False,
        downsample="mergingv2",
        mask_ratio: float = 0.1,
        spatial_mask="random",
        time_mask="random",
        **kwargs,
    ) -> None:
        """Build the encoder, the decoder and (for atlas masking) the ROI selection.

        Extra ``kwargs`` read here (atlas/tube masking only):
            atlas_path: NIfTI label volume
                (default ``./datasets/atlas/aal3_1mm.nii.gz``).
            region_of_interest: ``"random"`` (default) or comma-separated
                names from ``_ATLAS_REGION_LABELS``.
            n_mask_rois: number of random ROIs to mask (default 50).
        """
        super().__init__()
        img_size = ensure_tuple_rep(img_size, spatial_dims)
        self.mask_ratio = mask_ratio
        self.spatial_mask = spatial_mask
        self.time_mask = time_mask
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        # A "tube" is one spatial token by window_size[-1] time steps.
        self.tube_window_size = [1, 1, 1, window_size[-1]]
        self.first_window_size = first_window_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            flatten=False,
            spatial_dims=spatial_dims,
        )
        if self.spatial_mask == 'atlas' and self.time_mask == 'tube':
            self._init_atlas_masking(img_size, kwargs)

        # Learned embedding substituted for every masked token.
        self.mask_token = nn.Parameter(torch.zeros([1, embed_dim], dtype=torch.float32))

        grid_size = self.patch_embed.grid_size
        self.grid_size = grid_size
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Stochastic-depth rates grow linearly over all encoder blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        patch_dim = tuple(img_size[i] // patch_size[i] for i in range(4))
        print("img_size: ", img_size)
        print("patch_size: ", patch_size)
        print("patch_dim: ", patch_dim)

        # One positional embedding per stage. Channels grow by c_multiplier and the
        # spatial grid halves after each stage; the time axis is left unchanged.
        self.pos_embeds = nn.ModuleList()
        pos_embed_dim = embed_dim
        for _ in range(self.num_layers):
            self.pos_embeds.append(PositionalEmbedding(pos_embed_dim, patch_dim))
            pos_embed_dim = pos_embed_dim * c_multiplier
            patch_dim = (patch_dim[0] // 2, patch_dim[1] // 2, patch_dim[2] // 2, patch_dim[3])

        def stage_kwargs(i):
            # Arguments shared by the encoder/decoder stage at depth index i.
            return dict(
                dim=int(embed_dim * c_multiplier ** i),
                depth=depths[i],
                num_heads=num_heads[i],
                drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                c_multiplier=c_multiplier,
                use_checkpoint=use_checkpoint,
            )

        # ---- encoder stages ----
        self.layers = nn.ModuleList()
        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample

        # The first stage has its own window size; it downsamples unless it is the only stage.
        self.layers.append(
            BasicLayer(
                window_size=self.first_window_size,
                downsample=down_sample_mod if self.num_layers > 1 else None,
                **stage_kwargs(0),
            )
        )
        # Intermediate stages (the last stage is built separately below).
        for i_layer in range(1, self.num_layers - 1):
            self.layers.append(
                BasicLayer(window_size=self.window_size, downsample=down_sample_mod, **stage_kwargs(i_layer))
            )

        last = self.num_layers - 1
        if not last_layer_full_MSA:
            self.layers.append(BasicLayer(window_size=self.window_size, downsample=None, **stage_kwargs(last)))
        else:
            # Full MSA for the last stage: the window spans the whole downsampled spatial grid.
            self.last_window_size = (
                self.grid_size[0] // int(2 ** last),
                self.grid_size[1] // int(2 ** last),
                self.grid_size[2] // int(2 ** last),
                self.window_size[3],
            )
            self.layers.append(
                BasicLayer_FullAttention(window_size=self.last_window_size, downsample=None, **stage_kwargs(last))
            )

        # ---- decoder ----
        self.first_patch_expanding = PatchExpanding(dim=embed_dim * 2 ** (len(depths) - 1), norm_layer=norm_layer)
        self.layers_up = nn.ModuleList()
        # Climb back through stage indices num_layers-2 ... 0.
        for i_layer in reversed(range(self.num_layers - 1)):
            self.layers_up.append(
                BasicLayerUp(
                    window_size=self.window_size,
                    upsample=PatchExpanding if i_layer > 0 else None,
                    **stage_kwargs(i_layer),
                )
            )
        self.norm_up = norm_layer(embed_dim)
        # norm_up normalises embed_dim channels, so the head consumes embed_dim features.
        # It emits one value per voxel of a spatial patch, per input channel, matching
        # the '(P1 P2 P3 C)' split in forward_decoder.
        self.decoder_pred = nn.Linear(
            embed_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_chans, bias=True
        )

    def _init_atlas_masking(self, img_size, options):
        """Load the atlas and choose the ROI labels hidden by atlas/tube masking.

        Sets ``self.atlas_patch``, a long tensor holding the atlas labels resized
        (nearest-neighbour) to the spatial input size ``img_size[:3]``.
        Also registers the ``selected_rois`` buffer:
            * persistent when ROIs are drawn at random, so the draw is checkpointed;
            * non-persistent when derived from region names, so checkpoint keys
              are unchanged.

        Raises:
            ValueError: if ``region_of_interest`` names no known region.
        """
        print("Loading atlas.....")
        atlas_path = options.get("atlas_path", "./datasets/atlas/aal3_1mm.nii.gz")  # allow override via args
        atlas_data = nib.load(atlas_path).get_fdata().astype(np.int32)
        atlas_tensor = torch.from_numpy(atlas_data).unsqueeze(0).unsqueeze(0).float()  # [1, 1, D, H, W]

        # Nearest-neighbour keeps the integer labels intact.
        spatial_size = tuple(int(s) for s in img_size[:3])
        resized_atlas = F.interpolate(atlas_tensor, size=spatial_size, mode='nearest')
        self.atlas_patch = resized_atlas[0, 0].long()

        region_arg = (options.get("region_of_interest") or "random").lower()
        if region_arg == "random":
            unique_rois = torch.unique(self.atlas_patch)
            unique_rois = unique_rois[unique_rois != 0]  # label 0 is not an ROI
            n_rois_to_select = options.get("n_mask_rois", 50)
            selected_rois = unique_rois[torch.randperm(len(unique_rois))[:n_rois_to_select]]
            self.register_buffer("selected_rois", selected_rois, persistent=True)
            print("Number of ROI's Masked:", n_rois_to_select)
        else:
            selected_names = [r.strip() for r in region_arg.split(",")]
            print("Masked Regions:", selected_names)
            rois = []
            for name in selected_names:
                labels = self._ATLAS_REGION_LABELS.get(name)
                if labels is None:
                    print(f"Unknown region name: {name} — skipping.")
                    continue
                rois.extend(labels)
            if not rois:
                raise ValueError("No valid region names found in --region_of_interest.")
            # As a buffer this follows the module across .to()/.cuda() instead of
            # being pinned to CUDA at construction time.
            self.register_buffer("selected_rois", torch.tensor(sorted(set(rois))), persistent=False)

        print("Masked ROIs:", self.selected_rois.tolist())

    def _sample_mask(self, batch_size, num_units, device):
        """Return a (batch_size, num_units) bool mask on ``device``.

        Each row has exactly ``int(num_units * mask_ratio)`` True entries, placed
        with ``np.random.shuffle``.
        """
        num_mask = int(num_units * self.mask_ratio)
        template = np.hstack([np.zeros(num_units - num_mask), np.ones(num_mask)])
        overall_mask = np.zeros([batch_size, num_units])
        for i in range(batch_size):
            row = template.copy()
            np.random.shuffle(row)
            overall_mask[i, :] = row
        return torch.from_numpy(overall_mask).to(device=device, dtype=torch.bool)

    def _mask_tokens(self, sequence):
        """``random``/``random``: mask individual tokens of a (B, C, D, H, W, T) input."""
        tokens = rearrange(sequence, 'B C D H W T -> B D H W T C')
        B, D, H, W, T, C = tokens.shape
        tokens = rearrange(tokens, 'B D H W T C -> (B D H W T) C')
        mask = self._sample_mask(B, D * H * W * T, tokens.device)
        tokens[rearrange(mask, 'B N -> (B N)')] = self.mask_token
        tokens = tokens.reshape(B, D, H, W, T, C)
        return rearrange(tokens, 'B D H W T C -> B C D H W T'), mask

    def _mask_windows(self, sequence, window_size):
        """Mask whole windows of ``window_size`` tokens; the mask has one entry per window."""
        sequence = rearrange(sequence, 'B C D H W T -> B D H W T C')
        B, D, H, W, T, _ = sequence.shape
        dims = (B, D, H, W, T)
        windows = window_partition_with_b(sequence, window_size)
        num_windows = windows.shape[1]
        mask = self._sample_mask(B, num_windows, windows.device)
        windows = rearrange(windows, 'B N W C -> (B N) W C')
        windows[rearrange(mask, 'B N -> (B N)')] = self.mask_token
        new_sequence = window_reverse(windows, window_size, dims)
        return rearrange(new_sequence, 'B D H W T C -> B C D H W T'), mask

    def _mask_atlas_tubes(self, sequence):
        """``atlas``/``tube``: mask every token inside the selected ROIs, at all time steps."""
        B, C, D_p, H_p, W_p, T = sequence.shape
        x = rearrange(sequence, 'B C D H W T -> B D H W T C')

        # Bring the label volume down to the token grid. Max-pooling keeps the
        # largest label each patch overlaps.
        atlas = self.atlas_patch.unsqueeze(0).unsqueeze(0).float().to(x.device)
        atlas_down = F.adaptive_max_pool3d(atlas, output_size=(D_p, H_p, W_p))[0, 0].long()
        roi_mask_3d = torch.isin(atlas_down, self.selected_rois.to(x.device))  # (D_p, H_p, W_p)

        # Broadcast the spatial mask over batch, time and channels.
        mask_token = self.mask_token.view(1, 1, 1, 1, 1, -1)
        x = torch.where(roi_mask_3d[None, :, :, :, None, None], mask_token, x)

        # The spatial mask is constant in time. Emit one entry per tube of
        # tube_window_size[-1] steps, which is the unit forward_loss scores.
        num_tubes = T // self.tube_window_size[-1]
        mask = roi_mask_3d[None, :, :, :, None].expand(B, D_p, H_p, W_p, num_tubes).reshape(B, -1)
        return rearrange(x, 'B D H W T C -> B C D H W T'), mask

    def random_masking(self, sequence):
        """Replace a subset of tokens with ``self.mask_token``.

        Args:
            sequence: patch-embedded tokens laid out as (B, C, D, H, W, T).

        Returns:
            ``(masked_sequence, mask)``:
                * ``masked_sequence`` keeps the input layout.
                * ``mask`` is a (B, N) bool tensor on the input's device, True for
                  masked units. N counts tokens, windows or tubes depending on the
                  strategy, matching the units scored by ``forward_loss``.

        Raises:
            NotImplementedError: for ``atlas``/``random`` and ``window``/``tube``.
            ValueError: for any other unrecognised combination.
        """
        strategy = (self.spatial_mask, self.time_mask)
        if strategy == ('random', 'random'):
            return self._mask_tokens(sequence)
        if strategy == ('window', 'random'):
            return self._mask_windows(sequence, self.window_size)
        if strategy == ('random', 'tube'):
            return self._mask_windows(sequence, self.tube_window_size)
        if strategy == ('atlas', 'tube'):
            return self._mask_atlas_tubes(sequence)
        if strategy in (('atlas', 'random'), ('window', 'tube')):
            raise NotImplementedError
        raise ValueError(
            f"Invalid mask type: spatial_mask={self.spatial_mask!r}, time_mask={self.time_mask!r}"
        )

    def forward_encoder(self, x):
        """Patch-embed ``x``, mask it, and run every encoder stage.

        Returns ``(latent, mask)``.
        """
        x = self.patch_embed(x)
        x, mask = self.random_masking(x)
        for i in range(self.num_layers):
            x = self.pos_embeds[i](x)
            x = self.layers[i](x.contiguous())
        return x, mask

    def forward_decoder(self, x):
        """Decode a (B, C, D, H, W, T) latent into voxel space.

        The result is shaped (B, in_chans, D*P1, H*P2, W*P3, T), where
        P1..P3 are ``patch_size[:3]``.
        """
        x = rearrange(x, 'B C D H W T -> B D H W T C')
        x = self.first_patch_expanding(x)
        x = rearrange(x, 'B D H W T C -> B C D H W T')

        for layer in self.layers_up:
            x = layer(x)

        x = rearrange(x, 'B C D H W T -> B D H W T C')
        x = self.norm_up(x)
        x = self.decoder_pred(x)
        x = rearrange(
            x,
            'B D H W T (P1 P2 P3 C) -> B C (D P1) (H P2) (W P3) T',
            C=self.in_chans,
            P1=self.patch_size[0],
            P2=self.patch_size[1],
            P3=self.patch_size[2],
        )
        return x

    def patchify(self, x):
        """Split a (B, C, D1, D2, D3, T) volume into non-overlapping patches.

        The result has shape (B, g0, g1, g2, T // sT, s0*s1*s2*sT*C), where:
            * ``(g0, g1, g2) = self.grid_size``;
            * ``(s0, s1, s2, sT) = self.patch_size``;
            * the last axis is one flattened patch with channels innermost.
        """
        B, C, _, _, _, _ = x.shape
        g0, g1, g2 = self.grid_size
        s0, s1, s2, st = self.patch_size
        x = x.reshape(B, C, g0, s0, g1, s1, g2, s2, -1, st)
        x = x.permute(0, 2, 4, 6, 8, 3, 5, 7, 9, 1)
        # Include C in the patch vector so multi-channel inputs keep their time axis intact.
        return x.reshape(B, g0, g1, g2, -1, s0 * s1 * s2 * st * C).contiguous()

    def _windowed_error(self, x, pred, window_size):
        """Per-window mean squared error between patchified ``x`` and ``pred``, shape (B, N)."""
        x_windows = window_partition_with_b(self.patchify(x), window_size)
        pred_windows = window_partition_with_b(self.patchify(pred), window_size)
        return ((x_windows - pred_windows) ** 2).mean(dim=-1).mean(dim=-1)

    def forward_loss(self, x, pred, mask):
        """Mean squared reconstruction error averaged over the masked units only.

        Args:
            x: original input, (B, C, D, H, W, T).
            pred: decoder output with the same shape as ``x``.
            mask: (B, N) mask from ``random_masking``; True marks units that count.

        Returns:
            A scalar loss tensor.

        Raises:
            NotImplementedError: if the masking strategy has no loss defined.
        """
        strategy = (self.spatial_mask, self.time_mask)
        if strategy == ('random', 'random'):
            x_patch = self.patchify(x)
            pred_patch = self.patchify(pred)
            B, C = x_patch.shape[0], x_patch.shape[-1]
            per_unit = ((x_patch - pred_patch) ** 2).reshape(B, -1, C).mean(dim=-1)
        elif strategy == ('window', 'random'):
            per_unit = self._windowed_error(x, pred, self.window_size)
        elif strategy in (('random', 'tube'), ('atlas', 'tube')):
            per_unit = self._windowed_error(x, pred, self.tube_window_size)
        else:
            raise NotImplementedError(
                f"No reconstruction loss for spatial_mask={self.spatial_mask!r}, time_mask={self.time_mask!r}"
            )
        return (per_unit * mask).sum() / mask.sum()

    def forward(self, x):
        """Run masking, encoding, decoding and the loss.

        Returns ``([pred, mask], loss)``. A list input is unwrapped to its first
        element.
        """
        if isinstance(x, list):
            x = x[0]

        x = x.float()

        latent, mask = self.forward_encoder(x)
        pred = self.forward_decoder(latent)
        loss = self.forward_loss(x, pred, mask)
        return [pred, mask], loss




 