# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import io
from PIL import Image

import torch
import torchvision.transforms as transforms
import kornia

# Kornia-based replacement for the PIL-based `jpeg_compress` defined below.
def kornia_jpeg_compress(image: torch.Tensor, quality: int) -> torch.Tensor:
    """
    Compress a PyTorch image or a batch of images using the differentiable
    JPEG approximation from Kornia.

    The requested quality is applied deterministically to every image in the
    batch. The computation runs on whichever device `image` lives on.

    Parameters:
        image (torch.Tensor): The input image tensor. Can be a single image
                              of shape (C, H, W) or a batch of images of
                              shape (B, C, H, W). Values must be in [0, 1].
        quality (int): The JPEG quality factor (1-100).

    Returns:
        torch.Tensor: The compressed image or batch of images, with the same
                      number of dimensions as the input.

    Raises:
        AssertionError: If any pixel value lies outside [0, 1].
    """
    assert image.min() >= 0 and image.max() <= 1, \
        f'Image pixel values must be in the range [0, 1], got [{image.min()}, {image.max()}]'

    # The codec works on batches, so temporarily promote a single (C, H, W)
    # image to a batch of one and undo that before returning.
    is_single_image = image.dim() == 3
    if is_single_image:
        image = image.unsqueeze(0)

    # NOTE: kornia.augmentation.RandomJPEG is deliberately not used here. It
    # interprets `jpeg_quality` as a *sampling range*, and a scalar value is
    # turned into an interval around 50, so the quality actually applied would
    # be random rather than the one requested. The underlying codec takes an
    # explicit quality per batch element instead.
    quality_per_image = torch.full(
        (image.shape[0],), float(quality), device=image.device, dtype=image.dtype)
    compressed_image = kornia.enhance.jpeg_codec_differentiable(
        image, quality_per_image)

    if is_single_image:
        compressed_image = compressed_image.squeeze(0)

    return compressed_image

def jpeg_compress(image: torch.Tensor, quality: int) -> torch.Tensor:
    """
    Round-trip a PyTorch image through an in-memory JPEG encode/decode.

    Parameters:
        image (torch.Tensor): The input image tensor of shape 3xhxw, with
                              pixel values in [0, 1].
        quality (int): The JPEG quality factor passed to the PIL encoder.

    Returns:
        torch.Tensor: The decoded JPEG image as a PyTorch tensor.

    Raises:
        AssertionError: If any pixel value lies outside [0, 1].
    """
    assert image.min() >= 0 and image.max() <= 1, (
        f'Image pixel values must be in the range [0, 1], got [{image.min()}, {image.max()}]'
    )

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    # Encode into an in-memory buffer rather than touching the filesystem.
    encoded = io.BytesIO()
    to_pil(image).save(encoded, format='JPEG', quality=quality)

    # Rewind so the decoder reads the stream from its first byte.
    encoded.seek(0)
    return to_tensor(Image.open(encoded))


def webp_compress(image: torch.Tensor, quality: int) -> torch.Tensor:
    """
    Round-trip a PyTorch image through an in-memory WebP encode/decode.

    Unlike a range check, out-of-range pixel values are silently clamped to
    [0, 1] before encoding.

    Parameters:
        image (torch.Tensor): The input image tensor of shape 3xhxw.
        quality (int): The WebP quality factor passed to the PIL encoder.

    Returns:
        torch.Tensor: The decoded WebP image as a PyTorch tensor.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    # Saturate instead of rejecting values outside the valid pixel range.
    clamped = torch.clamp(image, 0, 1)

    # Encode into an in-memory buffer rather than touching the filesystem.
    encoded = io.BytesIO()
    to_pil(clamped).save(encoded, format='WebP', quality=quality)

    # Rewind so the decoder reads the stream from its first byte.
    encoded.seek(0)
    return to_tensor(Image.open(encoded))


def median_filter(images: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """
    Apply a median filter to a batch of images.

    Each output pixel is the median of the kernel_size x kernel_size window
    centred on it. Images are zero-padded so the output keeps the input's
    spatial size; windows overlapping the border therefore include zeros.

    Parameters:
        images (torch.Tensor): The input images tensor of shape BxCxHxW.
        kernel_size (int): The size of the median filter kernel. Must be a
                           positive odd integer.

    Returns:
        torch.Tensor: The filtered images, of shape BxCxHxW.

    Raises:
        ValueError: If `kernel_size` is not a positive odd integer.
    """
    # A negative odd value would pass the parity check below (-1 % 2 == 1)
    # and result in negative padding, so reject it explicitly.
    if kernel_size < 1:
        raise ValueError("Kernel size must be a positive odd integer.")
    if kernel_size % 2 == 0:
        raise ValueError("Kernel size must be odd.")

    # Pad by half the kernel on every side so each pixel has a full window.
    padding = kernel_size // 2
    images_padded = torch.nn.functional.pad(
        images, (padding, padding, padding, padding))

    # Slide a window along height, then width: BxCxHxWxKxK.
    blocks = images_padded.unfold(2, kernel_size, 1).unfold(
        3, kernel_size, 1)

    # Take the median over the whole window at once. A median of per-row
    # medians is NOT the median of the window in general. K*K is odd, so the
    # median is always an actual element of the window.
    windows = blocks.flatten(start_dim=-2)  # BxCxHxWx(K*K)
    return windows.median(dim=-1).values


def create_diff_img(img1, img2):
    """
    Build a visualisation of the difference between two images.

    The signed difference is min-max normalised to [0, 1] and then folded
    around 0.5, so both extremes of the observed difference range map to
    values near 1 and the midpoint of that range maps to 0. Note that two
    identical images therefore produce an all-ones result.

    Parameters:
        img1 (torch.Tensor): The first image tensor of shape 3xHxW.
        img2 (torch.Tensor): The second image tensor of shape 3xHxW.

    Returns:
        torch.Tensor: The difference image tensor of shape 3xHxW, in [0, 1].
    """
    delta = img1 - img2

    # Min-max normalisation; the epsilon avoids dividing by zero when the
    # difference is constant.
    lowest = delta.min()
    span = delta.max() - lowest
    normalised = (delta - lowest) / (span + 1e-6)

    # Fold around the midpoint and rescale the distance from it to [0, 1].
    folded = 2 * (normalised - 0.5).abs()
    return torch.clamp(folded, 0, 1)


if __name__ == '__main__':
    # Example usage: python src/utils/image.py
    # Round-trip one random image through both codecs at quality 80.
    sample = torch.rand(3, 256, 256)
    variants = (
        sample,
        jpeg_compress(sample, 80),
        webp_compress(sample, 80),
    )

    # Show the top-left 5x5 patch of the first channel for the original,
    # the JPEG result and the WebP result, in that order.
    for variant in variants:
        print(variant[0, 0:5, 0:5])
