import time, torch, torch.nn as nn, kornia as K

# ---------------------------------------------------------------------------
def jpeg_single(img: torch.Tensor, q: float) -> torch.Tensor:
    """Run kornia's differentiable JPEG codec on a single un-batched image.

    A leading batch axis of size one is added before the codec call and
    removed again from its result.

    Args:
        img: One image tensor without a batch axis
            (presumably ``(C, H, W)`` -- confirm against callers).
        q: JPEG quality passed to the codec as a one-element float32 tensor.

    Returns:
        The codec output with the leading batch axis squeezed out.
    """
    batched = img.unsqueeze(0)
    # The quality tensor lives on the image's device; shape = [1].
    quality = torch.tensor([q], dtype=torch.float32, device=img.device)
    encoded = K.enhance.jpeg_codec_differentiable(batched, quality)
    return encoded.squeeze(0)

# ---------------------------------------------------------------------------
class JPEGSlow(nn.Module):
    """Reference JPEG augmentation that compresses the batch one sample at a time.

    Every sample goes through ``jpeg_single`` inside a Python loop, all with
    the same quality.

    Args:
        qmin: Lower bound (inclusive) of the randomly drawn quality.
        qmax: Upper bound (inclusive) of the randomly drawn quality.
        passthrough: If true, the output carries the codec's forward value
            while the gradient flows to the clamped input as identity
            (``(y - x).detach() + x``); if false, the codec output is
            returned directly.
    """

    def __init__(self, qmin=40, qmax=80, passthrough=True):
        super().__init__()
        self.qmin, self.qmax, self.p = qmin, qmax, passthrough

    def forward(self, x, mask=None, quality=None):
        """Compress ``x`` sample by sample and return ``(compressed, mask)``.

        Args:
            x: Batch of images; values are clamped to ``[0, 1]`` first.
                The caller's tensor is never modified.
            mask: Returned untouched.
            quality: Fixed quality for the whole batch. When ``None`` a single
                integer is drawn uniformly from ``[qmin, qmax]``.

        Returns:
            Tuple of the compressed batch (a new tensor) and ``mask``.
        """
        q = float(quality) if quality is not None \
            else float(torch.randint(self.qmin, self.qmax + 1, ()).item())

        # Out-of-place clamp: the in-place variant rewrites the caller's data
        # and raises on leaf tensors that require grad.
        x = x.clamp(0, 1)
        if x.size(0) == 0:
            # torch.stack rejects an empty list; an empty batch has nothing
            # to compress anyway.
            return x, mask

        outputs = []
        for sample in x:                                 # Python loop ← slow
            y = jpeg_single(sample, q)
            # Results are collected instead of written back into ``x`` so
            # tensors saved for the backward pass are not overwritten.
            outputs.append((y - sample).detach() + sample if self.p else y)
        return torch.stack(outputs), mask

# ---------------------------------------------------------------------------
class JPEGFast(nn.Module):
    """JPEG augmentation that compresses the whole batch in one codec call.

    Args:
        qmin: Lower bound (inclusive) of the randomly drawn quality.
        qmax: Upper bound (inclusive) of the randomly drawn quality.
        passthrough: If true, the output carries the codec's forward value
            while the gradient flows to the clamped input as identity
            (``(y - x).detach() + x``); if false, the codec output is
            returned directly.
    """

    def __init__(self, qmin=40, qmax=80, passthrough=True):
        super().__init__()
        self.qmin, self.qmax, self.passthrough = qmin, qmax, passthrough

    def forward(self, x, mask=None, quality=None):
        """Compress ``x`` in a single batched call and return ``(compressed, mask)``.

        Args:
            x: Batch of images; values are clamped to ``[0, 1]`` first.
                The caller's tensor is never modified.
            mask: Returned untouched.
            quality: Fixed quality for the whole batch. When ``None`` a single
                integer is drawn uniformly from ``[qmin, qmax]`` and shared by
                every sample.

        Returns:
            Tuple of the compressed batch (a new tensor) and ``mask``.
        """
        # Out-of-place clamp: the in-place variant rewrites the caller's data
        # and raises on leaf tensors that require grad.
        x = x.clamp(0, 1)
        B = x.size(0)
        if B == 0:
            # Nothing to compress; mirrors JPEGSlow, whose loop is a no-op here.
            return x, mask

        # quality tensor on the image device, one entry per sample
        if quality is None:       # random per-batch scalar
            q = torch.randint(self.qmin, self.qmax + 1, (1,),
                              device=x.device, dtype=torch.float32).repeat(B)
        else:                     # fixed quality
            q = torch.full((B,), float(quality), device=x.device)

        y = K.enhance.jpeg_codec_differentiable(x, q)

        if self.passthrough:
            # Straight-through: forward value of y, identity gradient to x.
            y = (y - x).detach() + x
        return y, mask

# ---------------------------------------------------------------------------
def bench(mod, img):
    """Time one call of ``mod`` on a private copy of ``img``.

    Args:
        mod: Callable taking a tensor and returning a 2-tuple whose first
            element is the output.
        img: Input tensor. It is cloned before the call, so the original is
            left untouched even if ``mod`` works in place.

    Returns:
        Tuple ``(elapsed_seconds, output)``.
    """
    # Only CUDA work is asynchronous; synchronising without a CUDA device
    # raises, so skip it for CPU tensors.
    use_cuda = img.is_cuda
    sample = img.clone()          # copied outside the timed region
    if use_cuda:
        torch.cuda.synchronize(img.device)
    start = time.perf_counter()   # monotonic, high-resolution clock
    out, _ = mod(sample)
    if use_cuda:
        torch.cuda.synchronize(img.device)
    return time.perf_counter() - start, out

if __name__ == "__main__":
    # Benchmark the looped and the batched implementation on the same random
    # batch and report timings plus the mean absolute difference of outputs.
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    B, C, H, W = 128, 3, 256, 256
    img = torch.rand(B, C, H, W, device=dev)
    slow, fast = JPEGSlow().to(dev), JPEGFast().to(dev)

    # Both modules otherwise draw their own random quality, which would make
    # the MAE compare different compression levels. Pin one value for both.
    quality = 60

    for _ in range(3):                                   # warm-up
        # Clones keep the benchmark input pristine even if a module
        # modifies its argument in place.
        slow(img.clone(), quality=quality)
        fast(img.clone(), quality=quality)

    t_slow, o_slow = bench(lambda t: slow(t, quality=quality), img)
    t_fast, o_fast = bench(lambda t: fast(t, quality=quality), img)

    print(f'Slow  : {t_slow*1e3:.1f} ms')
    print(f'Fast  : {t_fast*1e3:.1f} ms')
    print(f'Speed-up ×{t_slow/t_fast:.1f}')
    print(f'MAE   : {(o_slow - o_fast).abs().mean().item():.6f}')

