import torchvision
import torch
import torch.nn as nn
from torch.nn import init

def weight_init_kaiming(m):
    """Kaiming-normal initialize the weight of a conv or linear module in place.

    Intended to be passed to ``nn.Module.apply``, which calls it on every
    submodule. A module is selected when its class name contains ``"Conv"``
    or ``"Linear"``. Its ``weight`` is then re-drawn with
    ``init.kaiming_normal_(a=0, mode='fan_in')``. Biases and all other
    modules are left untouched.

    Modules whose name matches but that carry no ``weight`` (e.g. a custom
    container named ``ConvBlock``, or a layer whose weight is ``None``) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: The module to (possibly) initialize.
    """
    class_name = type(m).__name__
    if 'Conv' not in class_name and 'Linear' not in class_name:
        return
    weight = getattr(m, 'weight', None)
    if weight is None:
        # Name matched, but there is no parameter to initialize.
        return
    # kaiming_normal_ fills the tensor under torch.no_grad(), so passing the
    # Parameter directly is safe and avoids the deprecated `.data` access.
    # a=0 / mode='fan_in' are the defaults; spelled out for clarity. The
    # original Linear branch relied on those same defaults.
    init.kaiming_normal_(weight, a=0, mode='fan_in')

class ResNet(nn.Module):
    """ResNet-50 classifier whose final fully-connected layer is replaced.

    The torchvision ResNet-50 model is kept as ``self.base``. Its ``fc`` head
    is swapped for a new ``nn.Linear`` that is Kaiming-initialized via
    ``weight_init_kaiming``.

    Args:
        pretrained: Forwarded unchanged to ``torchvision.models.resnet50``.
        num_classes: Output size of the new classification head. Defaults to
            200, the value previously hard-coded here.
    """

    def __init__(self, pretrained=True, *, num_classes=200) -> None:
        super().__init__()
        # NOTE(review): newer torchvision deprecates ``pretrained=`` in favour
        # of ``weights=``; kept as-is so older torchvision releases still work.
        self.base = torchvision.models.resnet50(pretrained=pretrained)
        # Global average pooling down to a 1x1 spatial map per channel.
        self.base.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Size the head from the backbone's existing fc rather than the
        # magic constant 512 * 4, so it always matches the pooled features.
        in_features = self.base.fc.in_features
        self.base.fc = nn.Linear(in_features, num_classes)
        # Only the new head is re-initialized; backbone weights are untouched.
        self.base.fc.apply(weight_init_kaiming)

    def forward(self, x):
        """Run the full model and return the head's raw outputs (no softmax)."""
        return self.base(x)

class ResNetExtractor(ResNet):
    """``ResNet`` variant whose forward pass stops before the ``fc`` head.

    It runs the same stem, residual stages and pooling as the wrapped
    torchvision model. It returns the flattened pooled features instead of
    class scores.
    """

    def __init__(self, pretrained=True) -> None:
        super().__init__(pretrained)

    def forward(self, x):
        """Return the flattened, globally pooled backbone features for ``x``.

        NOTE(review): assumes ``x`` is an image batch shaped (N, 3, H, W).
        The result is presumably (N, 2048) for ResNet-50 — confirm.
        """
        net = self.base
        # Stem: conv -> batch-norm -> ReLU -> max-pool.
        feats = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        # The four residual stages, applied in order.
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            feats = stage(feats)
        # Pool, then collapse every dimension after the batch dimension.
        return net.avgpool(feats).flatten(1)