import torch
import torch.nn.functional as F
from torchvision import transforms

class DifferentiableTransform:
    """Resize, center-crop and normalize an image tensor with tensor ops only.

    Every step is an ordinary differentiable torch operation, so gradients
    can flow from the output back to the input image.

    The pipeline mirrors the usual "resize shorter edge -> center crop ->
    normalize" preprocessing: the shorter spatial edge is bicubically resized
    to ``n_px`` (aspect ratio preserved), the central ``n_px x n_px`` window
    is cropped, and each channel is normalized with ``mean`` / ``std``.
    The normalization constants are the per-channel values used by OpenAI
    CLIP's image preprocessing.

    Args:
        n_px: Side length, in pixels, of the square output image.
    """

    def __init__(self, n_px: int):
        self.n_px = n_px
        # Shaped (3, 1, 1) so they broadcast over both (3, H, W) and
        # (N, 3, H, W) inputs.
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

    def __call__(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Apply the transform; equivalent to :meth:`forward`."""
        return self.forward(image_tensor)

    def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Return the resized, center-cropped and normalized image.

        Args:
            image_tensor: A ``(C, H, W)`` image or an ``(N, C, H, W)`` batch.
                The channel dimension must broadcast against the 3-channel
                mean/std (i.e. ``C == 3`` in normal use).

        Returns:
            A tensor of shape ``(C, n_px, n_px)`` for unbatched input or
            ``(N, C, n_px, n_px)`` for batched input.

        Raises:
            TypeError: If ``image_tensor`` is not a ``torch.Tensor``.
            ValueError: If the input is not 3-D/4-D or has an empty spatial
                dimension.
        """
        # Explicit raise instead of `assert`, which disappears under `-O`.
        if not isinstance(image_tensor, torch.Tensor):
            raise TypeError("Input must be a PyTorch tensor")
        if image_tensor.dim() not in (3, 4):
            raise ValueError(
                "Expected a (C, H, W) or (N, C, H, W) tensor, got shape "
                f"{tuple(image_tensor.shape)}"
            )

        # Bicubic interpolation needs a 4-D (N, C, H, W) input; add a
        # temporary batch dimension for single images.
        unbatched = image_tensor.dim() == 3
        if unbatched:
            image_tensor = image_tensor.unsqueeze(0)

        height, width = image_tensor.shape[-2:]
        if height == 0 or width == 0:
            raise ValueError("Input image has an empty spatial dimension")

        # Scale the shorter edge to n_px and the longer edge proportionally,
        # so the aspect ratio is kept and the crop below removes only the
        # excess along the longer edge. Square inputs map to (n_px, n_px).
        if height <= width:
            new_size = (self.n_px, self.n_px * width // height)
        else:
            new_size = (self.n_px * height // width, self.n_px)

        image_tensor = F.interpolate(
            image_tensor, size=new_size, mode='bicubic', align_corners=False
        )
        image_tensor = self.center_crop(image_tensor)

        device = image_tensor.device
        image_tensor = (image_tensor - self.mean.to(device)) / self.std.to(device)

        return image_tensor.squeeze(0) if unbatched else image_tensor

    def center_crop(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Crop the central ``n_px x n_px`` window from the last two dims.

        Works for any tensor whose trailing dimensions are ``(H, W)``
        (e.g. ``(C, H, W)`` or ``(N, C, H, W)``). A spatial dimension that is
        already smaller than ``n_px`` is returned in full along that axis.
        """
        height, width = image_tensor.shape[-2:]
        # Clamp at 0: a negative offset would wrap around as a negative
        # index and silently return an off-centre sliver.
        top = max((height - self.n_px) // 2, 0)
        left = max((width - self.n_px) // 2, 0)
        return image_tensor[..., top:top + self.n_px, left:left + self.n_px]