import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as T


class LowResolution:
    """Simulate a low-quality, low-resolution version of an image.

    The input is JPEG-compressed, then downscaled by ``scale`` and upscaled
    back to its original spatial size. Both resampling steps use antialiased
    bicubic interpolation.

    Args:
        scale: Downscaling factor applied to height and width. Must be > 0.
        quality: JPEG quality forwarded to ``torchvision.transforms.v2.JPEG``.
    """

    def __init__(self, scale=8, quality=10):
        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale!r}")
        self.scale = scale
        self.jpeg = T.JPEG(quality=quality)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Degrade ``img`` and return it at its original resolution.

        Args:
            img: ``(C, H, W)`` or ``(B, C, H, W)`` tensor. Floating-point input
                is expected in ``[0, 1]``; ``uint8`` input in ``[0, 255]`` is
                used as-is for the JPEG step.

        Returns:
            A ``float32`` tensor in ``[0, 1]`` with exactly the same shape as
            ``img``.

        Raises:
            ValueError: if ``img`` is not 3- or 4-dimensional.
        """
        if img.ndim not in (3, 4):
            raise ValueError(
                "expected a (C, H, W) or (B, C, H, W) tensor, "
                f"got shape {tuple(img.shape)}"
            )

        # Remember whether we added the batch dim ourselves, so a genuine
        # (1, C, H, W) batch is returned 4-D rather than squeezed to 3-D.
        unbatched = img.ndim == 3
        if unbatched:
            img = img.unsqueeze(0)
        height, width = img.shape[-2:]

        # JPEG encoding needs uint8. Round instead of truncating so values
        # are not biased downward (e.g. 0.999 -> 255, not 254). uint8 input
        # is used directly; multiplying it by 255 would wrap around.
        if img.dtype != torch.uint8:
            img = (img * 255).round().clamp(0, 255).to(torch.uint8)
        img = self.jpeg(img)
        img = img.to(torch.float32) / 255.0

        # Downscale to an explicit integer size, never below one pixel, so
        # images smaller than `scale` do not produce a zero-sized tensor.
        low_size = (
            max(1, int(height // self.scale)),
            max(1, int(width // self.scale)),
        )
        img = F.interpolate(img, size=low_size, mode="bicubic", antialias=True)

        # Upscale to exactly the original size. Using scale_factor here would
        # lose pixels whenever H or W is not divisible by `scale`.
        img = F.interpolate(
            img, size=(height, width), mode="bicubic", antialias=True
        )

        # Bicubic interpolation can overshoot; keep the result a valid image.
        img = img.clamp(0.0, 1.0)

        return img.squeeze(0) if unbatched else img
