from defense.base_defense import BaseDefense
from kornia.augmentation import RandomJPEG
import torch

class JPEGDefense(BaseDefense):
    """
    JPEG Defense class for applying JPEG compression to images.

    The input is passed through a kornia ``RandomJPEG`` augmentation
    (probability 1.0, quality range collapsed to a single value)
    ``iterations`` times.
    """

    def __init__(self, quality=75.0, device=None, iterations=1):
        """
        Initialize the JPEG defense with a specified quality and device.

        Args:
            quality (float or int): JPEG quality factor (default: 75.0).
                Values that cannot be converted to float, or are NaN, fall
                back to 75.0. Values outside [1, 100] are clamped to that
                range.
            device (torch.device or str): Device to move the augmentation
                to (default: None).
            iterations (int): Number of times the compression is applied
                (default: 1).
        """
        super().__init__(device, iterations)
        self.quality = self._sanitize_quality(quality)
        # RandomJPEG takes a (min, max) quality range; passing the same value
        # twice pins the quality. The dtype is explicit so the range tensor is
        # always floating point, even if callers passed ints.
        quality_range = torch.tensor(
            [self.quality, self.quality], dtype=torch.float32
        ).to(self.device)
        self.aug = RandomJPEG(jpeg_quality=quality_range, p=1.0, keepdim=True).to(self.device)
        # NOTE(review): presumably registered so BaseDefense moves `aug` along
        # with the defense when the device changes -- confirm in BaseDefense.
        self.device_attributes.append('aug')

    @staticmethod
    def _sanitize_quality(quality):
        """
        Convert a user-supplied JPEG quality into a float within [1, 100].

        Args:
            quality: Requested quality; anything accepted by ``float()``.

        Returns:
            float: The validated quality. Unconvertible or NaN inputs yield
            75.0. Out-of-range inputs are clamped to [1.0, 100.0].
        """
        import math

        if not isinstance(quality, float):
            try:
                quality = float(quality)
            except (TypeError, ValueError):
                # float("abc") raises ValueError, while float(None) and
                # float([...]) raise TypeError; both take the fallback.
                print("Can't convert input quality to float. Setting quality to standard value of 75.")
                return 75.0

        if math.isnan(quality):
            # NaN compares False against both bounds and would otherwise slip
            # through the range check unchanged.
            print("Quality factor is NaN. Setting quality to standard value of 75.")
            return 75.0

        if quality < 1 or quality > 100:
            print("Quality factor out of range. Clamping to [1, 100].")
            # Float bounds keep the result a float (int bounds would return int).
            quality = max(1.0, min(quality, 100.0))

        return quality

    def _defense(self, x):
        """
        Apply the defense to the input tensor.

        Args:
            x (torch.Tensor): Input image batch, documented as (N, C, H, W).
                NOTE(review): kornia's RandomJPEG likely expects float images
                in [0, 1] -- confirm against callers.

        Returns:
            torch.Tensor: ``x`` after ``self.iterations`` successive JPEG
            compress/decompress passes.
        """
        for _ in range(self.iterations):
            x = self.aug(x)
        return x