import torch
import torch.nn as nn
import torch.nn.functional as F
class ImageClassifier(nn.Module):
    """ResNet-34 backbone followed by a small MLP classification head.

    The backbone's own ``avgpool`` and ``fc`` children are bypassed in
    ``forward``. The last backbone feature map is globally average-pooled
    and fed to ``self.linear``.

    Args:
        params: Configuration object. It is stored as ``self.params`` but not
            read anywhere in this class.
        num_class: Number of output classes. The default of 3 is the value
            that was previously hard-coded.
        pretrained: Whether to load ImageNet weights into the backbone. The
            default of True matches the previous behavior.
    """

    def __init__(self, params, num_class=3, pretrained=True):
        super().__init__()
        self.params = params

        # Local import: torchvision is only needed when a model is built.
        import torchvision.models as models

        self.base_model = self._build_backbone(models, pretrained)
        self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
        self.feat_out_channels = [64, 64, 128, 256, 512]
        self.num_class = num_class

        # Channel count of the final backbone stage feeds the head.
        in_features = self.feat_out_channels[-1]
        # No softmax layer: forward() returns raw, unnormalized logits
        # (e.g. what nn.CrossEntropyLoss expects). The Sequential indices
        # (0, 2, 4) are kept stable so existing checkpoints still load.
        self.linear = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, self.num_class),
        )

    @staticmethod
    def _build_backbone(models, pretrained):
        """Return a torchvision ResNet-34, optionally with ImageNet weights."""
        weights_enum = getattr(models, 'ResNet34_Weights', None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.resnet34(pretrained=pretrained)
        # Newer torchvision deprecates ``pretrained``. IMAGENET1K_V1 is the
        # checkpoint that ``pretrained=True`` used to select.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet34(weights=weights)

    def forward(self, x):
        """Return class logits for the input batch ``x``.

        Args:
            x: Tensor passed to the backbone's first child module. It is
                presumably an (N, 3, H, W) image batch; confirm with callers.

        Returns:
            Tensor of shape (N, num_class) holding unnormalized scores.
        """
        feature = x
        for name, module in self.base_model.named_children():
            # Skip the backbone's own pooling and classifier. This class pools
            # explicitly below and uses ``self.linear`` as its head.
            if 'fc' in name or 'avgpool' in name:
                continue
            feature = module(feature)

        # Global average pool to (N, C, 1, 1), then drop the spatial dims.
        feature = F.adaptive_avg_pool2d(feature, (1, 1))
        feature = torch.flatten(feature, 1)

        return self.linear(feature)