from defense.base_defense import BaseDefense
from compressai.zoo import bmshj2018_hyperprior,mbt2018_mean
from torchvision import transforms
import torch

class HyperpriorDefense(BaseDefense):
    """
    Defense that round-trips images through CompressAI's mean-scale
    hyperprior codec (``mbt2018_mean``).

    Images are resized to ``input_size`` and passed through the model's
    forward pass ``iterations`` times, keeping the reconstruction ``x_hat``
    each time. The result is then resized to ``output_size``.
    """
    def __init__(self, quality=1, device=None, iterations=1,
                 input_size=(256, 256), output_size=(224, 224)):
        """
        Initialize the hyperprior defense.

        Args:
            quality (int): CompressAI quality level passed to
                ``mbt2018_mean``. Any value accepted by ``int()`` works,
                e.g. ``"3"`` (default: 1).
            device (torch.device or str): Device to move the model to;
                forwarded to ``BaseDefense`` (default: None).
            iterations (int): Number of compress/decompress passes applied
                in ``_defense``; forwarded to ``BaseDefense`` (default: 1).
            input_size (tuple[int, int]): (H, W) that images are resized to
                before compression (default: (256, 256)). NOTE(review):
                presumably must be compatible with the model's
                downsampling (multiples of 64 for mbt2018_mean); confirm.
            output_size (tuple[int, int]): (H, W) that the reconstruction is
                resized to (default: (224, 224)).

        Raises:
            ValueError / TypeError: If ``quality`` cannot be converted to an
                int. The original exception type is preserved.
        """
        super().__init__(device, iterations)
        try:
            quality = int(quality)
        except (TypeError, ValueError) as err:
            # Keep the original exception type so existing handlers still
            # match, but give a message that names the bad argument.
            raise type(err)(
                f"Quality must be an integer, got {quality} instead."
            ) from err

        self.model = mbt2018_mean(quality=quality, pretrained=True)
        self.model = self.model.to(self.device)
        # NOTE(review): train mode looks deliberate, since the prints below
        # verify it. In CompressAI it makes the entropy models use additive
        # uniform noise instead of hard rounding. Confirm that this
        # stochastic behaviour is intended for the defense.
        self.model.train()
        print(f'hyperprior training mode is {self.model.training}')
        for label, module in (
            ('gaussian conditional', self.model.gaussian_conditional),
            ('entropy_bottleneck', self.model.entropy_bottleneck),
            ('g_a', self.model.g_a),
            ('g_s', self.model.g_s),
            ('h_a', self.model.h_a),
            ('h_s', self.model.h_s),
        ):
            print(f'self {label} training mode is {module.training}')
        self.device_attributes.append('model')

        # Build the resize transforms once instead of on every call.
        self._resize_in = transforms.Resize(input_size)
        self._resize_out = transforms.Resize(output_size)

    def _defense(self, x):
        """
        Compress and decompress ``x`` ``self.iterations`` times.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).
                NOTE(review): presumably images in [0, 1], the range
                CompressAI's pretrained models expect; confirm with callers.

        Returns:
            torch.Tensor: Reconstruction after ``self.iterations`` passes,
            resized to ``output_size`` (224x224 by default).
        """
        x = self._resize_in(x)
        for _ in range(self.iterations):
            # The model's forward returns a dict; 'x_hat' is the decoded image.
            x = self.model(x)['x_hat']
        return self._resize_out(x)