import torch
from torch import nn
# from .resnet import ResNet50, ResNet18
import pdb
from utils import utils
import torchvision.models as models

class BatchNorm1dNoBias(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` whose additive bias is excluded from training.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.BatchNorm1d``. After construction the ``bias`` parameter has
    ``requires_grad`` set to ``False``, so autograd does not compute a
    gradient for it; the ``weight`` (scale) parameter is left trainable.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # With ``affine=False`` the parent registers ``bias`` as ``None``;
        # there is nothing to freeze then, and touching it would raise
        # AttributeError.
        if self.bias is not None:
            self.bias.requires_grad = False


class Encoder(nn.Module):
    """Torchvision ResNet backbone followed by a two-layer projection head.

    The backbone's final ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone returns its features directly. The projection head is
    ``Linear -> ReLU -> Linear`` (both linear layers without bias) and maps
    those features from ``encoder_dim`` to ``proj_dim``.

    Args:
        arch: Backbone name; one of ``'resnet50'`` or ``'resnet18'``.
        proj_dim: Output width of the projection head.
        input_size: Not used by this class; kept so existing call sites
            continue to work.

    Raises:
        ValueError: If ``arch`` is not a supported backbone name.
    """

    # Feature width assigned to ``encoder_dim`` for each supported backbone.
    _ENCODER_DIMS = {'resnet50': 2048, 'resnet18': 512}

    def __init__(self, arch, proj_dim, input_size):
        # Validate first so an unsupported name fails with a clear message
        # rather than an AttributeError on the missing ``convnet`` attribute.
        if arch not in self._ENCODER_DIMS:
            supported = ', '.join(sorted(self._ENCODER_DIMS))
            raise ValueError(
                f'Unsupported arch {arch!r}; expected one of: {supported}'
            )

        super().__init__()

        # Called without arguments: torchvision builds a randomly initialised
        # network by default, which avoids the deprecated ``pretrained`` flag
        # while remaining valid on older releases.
        build_backbone = getattr(models, arch)
        self.convnet = build_backbone()
        self.encoder_dim = self._ENCODER_DIMS[arch]

        # Drop the classification layer so the backbone outputs features.
        self.convnet.fc = nn.Identity()

        num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad)

        print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')

        self.proj_dim = proj_dim

        # NOTE: the BatchNorm layers below are currently disabled and kept
        # only as commented-out alternatives.
        self.projection = nn.Sequential(
            nn.Linear(self.encoder_dim, self.encoder_dim, bias=False),
#             nn.BatchNorm1d(self.encoder_dim),
            nn.ReLU(),
            nn.Linear(self.encoder_dim, self.proj_dim, bias=False),
#             BatchNorm1dNoBias(self.proj_dim),
        )

    def forward(self, x):
        """Return the backbone features and their projection.

        Args:
            x: Input passed directly to the backbone (presumably an image
                batch shaped ``(N, 3, H, W)`` -- TODO confirm with callers).

        Returns:
            A tuple ``(h, z)`` where ``h`` is the backbone output and ``z``
            is ``h`` passed through the projection head.
        """
        h = self.convnet(x)
        z = self.projection(h)
        return h, z