import torch
import torch.nn as nn
from torchvision.models import vgg16
import torchvision.transforms as TF

from .encoder import Encoder

# from ..resizer import pil_resize

class VGGRandom(nn.Module):
    """Untrained VGG16 backbone projected to a small embedding vector.

    Pipeline: ``vgg.features -> vgg.avgpool -> flatten -> vgg.classifier[0]
    (Linear) -> vgg.classifier[1] (ReLU) -> Linear(4096, out_dim)``. The
    remaining classifier layers of the torchvision model are not used.

    Weights are randomly initialised under a fixed seed, so every instance
    built with the same ``seed`` (and ``out_dim``) has identical weights.

    Args:
        seed: Seed applied while the layers are initialised. Defaults to 42.
        out_dim: Size of the output embedding. Defaults to 64.
    """

    def __init__(self, seed: int = 42, out_dim: int = 64):
        super().__init__()
        # Seed inside a forked RNG scope so that building the model does not
        # reset the caller's global CPU random stream. Only the CPU generator
        # is forked (devices=[]) because the layers below are created on CPU.
        # NOTE: torch.manual_seed also seeds the CUDA generators; those are not
        # restored here.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            vgg = vgg16(weights=None)  # random initialisation, no pretrained weights
            self.features = vgg.features
            self.avgpool = vgg.avgpool
            self.fc1 = vgg.classifier[0]  # Linear(512*7*7, 4096)
            self.relu1 = vgg.classifier[1]  # ReLU
            # Created after vgg16() has drawn from the seeded stream, so its
            # weights are also fully determined by `seed`.
            self.fc2 = nn.Linear(4096, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch; presumably (N, 3, H, W) float tensor as produced by
                the encoder's transform — TODO confirm against callers.

        Returns:
            Tensor of shape (N, out_dim).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten the rest
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x

class VGGRandomEncoder(Encoder):
    """Encoder that embeds images with a randomly initialised VGG16 (VGGRandom)."""

    def setup(self):
        """Build the random-weight model and switch it to inference mode."""
        # nn.Module.eval() returns the module itself, so this is build + eval.
        self.model = VGGRandom().eval()

    def transform(self, image):
        """Convert a PIL image into a normalised tensor for the model.

        Steps: force RGB, zero-pad (roughly centred) any side shorter than
        32 px up to 32 px, convert to a tensor, then normalise each channel
        with mean 0.5 and std 0.5.
        """
        rgb = image.convert("RGB")

        # Minimum side length; presumably chosen so VGG's repeated spatial
        # downsampling never shrinks the input to nothing — TODO confirm.
        min_side = 32
        short_w = max(min_side - rgb.width, 0)
        short_h = max(min_side - rgb.height, 0)
        if short_w or short_h:
            # Pad must happen on the PIL image, i.e. before ToTensor.
            left = short_w // 2
            top = short_h // 2
            padding = (left, top, short_w - left, short_h - top)
            rgb = TF.functional.pad(rgb, padding, fill=0)

        tensor = TF.ToTensor()(rgb)
        normalize = TF.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        return normalize(tensor)
