from torch import nn
from torchvision.models import inception_v3
import torch.nn.functional as F


class Inception(nn.Module):
    """Pretrained torchvision Inception-v3 wrapped as an image feature extractor.

    The ImageNet classifier (``fc``) is replaced with ``nn.Identity`` and the
    auxiliary head is removed, so ``forward`` returns the pooled features that
    would otherwise be fed to the classifier (2048-d for torchvision's
    ``Inception3`` -- TODO confirm against the installed torchvision version).

    ``transform_input=True`` makes torchvision re-map its input from ImageNet
    mean/std normalization to the scaling the original weights expect, so
    inputs are presumably 3-channel RGB normalized with ImageNet mean/std --
    confirm with callers.

    Note: like any ``nn.Module`` this starts in training mode; Inception-v3
    contains dropout and batch-norm layers, so call ``.eval()`` when
    deterministic features are needed.
    """

    # Spatial resolution the pretrained Inception-v3 weights were trained at.
    _INPUT_SIZE = (299, 299)

    def __init__(self):
        super().__init__()
        # Load with torchvision's default aux_logits and strip the auxiliary
        # head afterwards. Some torchvision releases raise ValueError when
        # aux_logits=False is passed together with pretrained weights (the
        # checkpoint contains the aux head). These two assignments mirror what
        # torchvision itself does to disable the head after loading weights.
        self.inception_v3 = inception_v3(
            weights='Inception_V3_Weights.DEFAULT', transform_input=True
        )
        self.inception_v3.aux_logits = False
        self.inception_v3.AuxLogits = None
        # Drop the ImageNet classifier so the backbone returns pooled features.
        self.inception_v3.fc = nn.Identity()

    def forward(self, x):
        """Resize ``x`` to 299x299 and return the backbone's features.

        Args:
            x: Batched image tensor. Bilinear ``F.interpolate`` requires 4-D
                ``(N, C, H, W)`` input; any ``H``/``W`` is accepted and resized.

        Returns:
            The output of the wrapped Inception-v3 with ``fc`` replaced by
            ``nn.Identity``, i.e. one feature vector per batch element.
        """
        # Bilinear resize so inputs of any spatial size match the resolution
        # the pretrained weights expect.
        x = F.interpolate(x, self._INPUT_SIZE, mode='bilinear', align_corners=False)
        return self.inception_v3(x)