import torch
import torch.nn as nn

__all__ = ["SphericalMask"]


class SphericalMask(nn.Module):
    """Zero out every voxel of a cubic volume that lies outside a centred sphere.

    The mask is a non-trainable buffer of shape
    ``(in_channels, latent_dim, latent_dim, latent_dim)``. It holds ``1`` where
    the voxel index ``(z, y, x)`` is within
    ``radius_fraction * (latent_dim - 1) / 2`` of the cube centre
    ``((latent_dim - 1) / 2,) * 3``, and ``0`` elsewhere. Distances are measured
    in voxel-index units. :meth:`forward` multiplies its input by this mask.

    Args:
        in_channels: Number of channels; the same spatial mask is repeated for
            every channel.
        latent_dim: Side length of the cubic volume.
        radius_fraction: Sphere radius as a fraction of the half-width
            ``(latent_dim - 1) / 2``. With ``1.0`` the sphere reaches the centres
            of the outermost voxels along each axis.
        device: Device on which the mask (and the coordinate grid used to build
            it) is created. ``None`` uses the default device.
    """

    def __init__(
        self,
        in_channels,
        latent_dim,
        radius_fraction=1.0,
        device=None
    ) -> None:
        super().__init__()

        self.device = device
        self.in_channels = in_channels
        self.latent_dim = latent_dim

        self.input_shape = (self.in_channels, self.latent_dim, self.latent_dim, self.latent_dim)

        center = (latent_dim - 1) / 2  # centre of the cube, in voxel-index units
        radius = center  # half-width of the cube
        radius_squared = (radius_fraction * radius) ** 2

        # Squared per-axis offsets from the centre. They are built on the mask's
        # device so the boolean index below never crosses devices. Broadcasting
        # three 1-D terms gives the same values as summing three meshgrid grids,
        # in the same addition order, without materialising three D**3 int tensors.
        axis_sq = (torch.arange(latent_dim, device=device) - center) ** 2
        squared_distances = (
            axis_sq[:, None, None] + axis_sq[None, :, None] + axis_sq[None, None, :]
        )

        mask = torch.ones(self.input_shape, device=device)
        # Same spatial pattern for every channel.
        mask[:, squared_distances > radius_squared] = 0

        self.register_buffer('mask', mask)

    def forward(self, volume):
        """Return ``volume`` multiplied element-wise by the spherical mask.

        Args:
            volume: Tensor broadcast-compatible with the mask's shape
                ``(in_channels, latent_dim, latent_dim, latent_dim)``, for example
                ``(N, in_channels, latent_dim, latent_dim, latent_dim)``.

        Returns:
            The product ``volume * self.mask``; voxels outside the sphere are zero.
        """
        return volume * self.mask

