import torch
import torch.nn as nn
from torchvision.models import vgg16
import torchvision.transforms as TF

from .encoder import Encoder

# from ..resizer import pil_resize

class SmallCNN(nn.Module):
    """Small convolutional classifier whose forward pass yields features.

    The full stack (two 3x3 conv + ReLU stages, a 2x2 max-pool, and a
    two-layer MLP head ending in 10 outputs) is stored in ``self.net`` so
    that checkpoints keyed ``net.<index>.*`` load unchanged. The first
    linear layer takes ``64 * 16 * 16`` inputs, which corresponds to
    3-channel 32x32 images.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Convolutional stem; padding=1 keeps the spatial size (32x32).
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Halve the spatial resolution: 32x32 -> 16x16.
            nn.MaxPool2d(2),
            nn.Flatten(),
            # Head: 128-d hidden layer followed by the 10-way output layer.
            nn.Linear(64 * 16 * 16, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 128-d output of the first linear layer.

        Only ``self.net[0]`` through ``self.net[6]`` are applied; the ReLU
        and the 10-way linear layer that follow are deliberately skipped so
        the result is a feature vector rather than class logits.
        """
        num_feature_layers = 7
        for index in range(num_feature_layers):
            x = self.net[index](x)
        return x

class MNISTEncoder(Encoder):
    """Encoder that wraps ``SmallCNN`` loaded from a saved checkpoint.

    ``setup`` builds the network and loads its weights; ``transform``
    converts a single image into a normalized tensor for the network.
    """

    def setup(self):
        """Create the ``SmallCNN``, load its weights and switch to eval mode.

        The checkpoint path is relative, so it is resolved against the
        current working directory of the process.
        """
        self.model = SmallCNN()
        # map_location="cpu" lets a checkpoint saved from a CUDA device be
        # loaded on a CPU-only host; load_state_dict then copies the values
        # into the freshly built model's own parameters.
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted checkpoints.
        state_dict = torch.load(
            "./cas_eval/mnist_classifier.pth", map_location="cpu"
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def transform(self, image):
        """Convert ``image`` to a normalized 3-channel tensor.

        Args:
            image: Object providing ``convert``, ``width`` and ``height``
                (presumably a ``PIL.Image.Image`` -- confirm with callers).

        Returns:
            The ``ToTensor`` output of the RGB image, normalized per channel
            with mean 0.5 and std 0.5, i.e. ``(x - 0.5) / 0.5``.
        """
        # Convert RGB
        image = image.convert("RGB")

        # Pad if necessary (before ToTensor). Padding is centred; when the
        # missing amount is odd, the extra pixel goes to the right/bottom.
        # NOTE(review): images larger than 32 in a dimension are neither
        # cropped nor resized here, while SmallCNN's first linear layer is
        # sized for 32x32 inputs -- presumably resizing happens upstream.
        if image.width < 32 or image.height < 32:
            pad_left = max((32 - image.width) // 2, 0)
            pad_top = max((32 - image.height) // 2, 0)
            pad_right = max(32 - image.width - pad_left, 0)
            pad_bottom = max(32 - image.height - pad_top, 0)
            image = TF.functional.pad(image, (pad_left, pad_top, pad_right, pad_bottom), fill=0)

        # Convert to tensor and normalize for RGB
        image = TF.ToTensor()(image)
        image = TF.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)

        return image