import numpy as np
from collections import OrderedDict
import torchvision
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet

class Encoder(nn.Module):
    """ResNet feature extractor: a torchvision ResNet with its classifier head removed.

    The backbone's final global average-pool and fully connected layers are
    dropped, so ``forward`` returns the feature map of the last residual
    stage (optionally global-average-pooled down to 1x1).

    Args:
        size: ResNet depth; one of 18, 34, 50, 101 or 152.
        pretrained: Forwarded to the torchvision constructor as the
            ``pretrained`` keyword.
        pool_features: If True, append ``nn.AdaptiveAvgPool2d(1)`` so each
            sample's output is ``(C, 1, 1)``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Attributes:
        output_shape: ``(channels, height, width)`` of one sample's output.
            The unpooled ``(C, 7, 7)`` assumes 224x224 inputs (ResNet's total
            stride is 32); other input resolutions give a different spatial
            size.
        layers: The ``nn.Sequential`` applied in ``forward``.

    Raises:
        KeyError: If ``size`` is not one of the supported depths.
    """

    def __init__(self,
                 size: int = 50,
                 pretrained: bool = False,
                 pool_features: bool = False,
                 **kwargs,
                 ) -> None:
        super().__init__()
        constructors = {18: resnet.resnet18,
                        34: resnet.resnet34,
                        50: resnet.resnet50,
                        101: resnet.resnet101,
                        152: resnet.resnet152,
                        }
        if size not in constructors:
            # KeyError keeps compatibility with the original bare dict lookup,
            # but now tells the caller which depths are valid.
            raise KeyError(f"unsupported ResNet size {size!r}; "
                           f"expected one of {sorted(constructors)}")

        # Pass by keyword: since torchvision 0.13 the first positional
        # parameter is ``weights`` and positional use is deprecated.
        full_model = constructors[size](pretrained=pretrained)

        # Channels produced by the last residual stage (512 for 18/34, 2048
        # for the deeper variants), read from the classifier we discard
        # instead of being hard-coded per depth.
        resnet_cout = full_model.fc.in_features

        # Keep every child except the global pool and the linear classifier;
        # selecting by name makes the intent explicit instead of relying on
        # their position at the end of the child list.
        modules = [module for name, module in full_model.named_children()
                   if name not in ("avgpool", "fc")]

        if pool_features:
            modules.append(nn.AdaptiveAvgPool2d(1))
            self.output_shape = (resnet_cout, 1, 1)
        else:
            # Spatial size for 224x224 inputs; see the class docstring.
            self.output_shape = (resnet_cout, 7, 7)

        self.layers = nn.Sequential(*modules)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Run ``img`` through the truncated ResNet and return its features."""
        return self.layers(img)

