from timm import create_model 
from .ema import ExponentialMovingAverage
import torch 
import torchvision
from torchvision import transforms 

class RN50Classifier(torch.nn.Module):
    """ResNet-50 classifier whose weights are restored from a checkpoint file.

    The network is a torchvision ``resnet50`` wrapped in
    ``torch.nn.DataParallel``. ``state_dict`` and ``load_state_dict`` are
    delegated to that wrapped model, so the keys they expose are the wrapped
    model's own keys (no ``clf.`` prefix is added for this class's attribute).

    Args:
        checkpoint_path: File read with ``torch.load``; the loaded object must
            contain a ``'state_dict'`` entry that is loaded into the model.
    """

    def __init__(self, checkpoint_path: str = "workdirs/checkpoint.pth.tar"):
        super().__init__()
        model = torch.nn.DataParallel(torchvision.models.resnet50(pretrained=False))
        # Deserialize onto the CPU so a checkpoint written on a GPU machine can
        # still be read on a CPU-only host; load_state_dict copies the values
        # into the model's parameters on whatever device they live on.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # NOTE(review): presumably the checkpoint was saved from a DataParallel
        # model (keys prefixed with "module.") -- confirm against the training
        # script.
        model.load_state_dict(checkpoint['state_dict'])
        # load_state_dict() returns a missing/unexpected-keys report, not the
        # module, so the model itself must be what gets stored here.
        self.clf = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's output for the input batch ``x``."""
        return self.clf(x)

    def state_dict(self, *args, **kwargs):
        """Return the wrapped model's state dict.

        All arguments are forwarded unchanged, so the ``destination`` /
        ``prefix`` / ``keep_vars`` arguments that ``torch.nn.Module`` passes
        when this module is nested inside another one are accepted.
        """
        return self.clf.state_dict(*args, **kwargs)

    def load_state_dict(self, sd, strict: bool = True, **kwargs):
        """Load ``sd`` into the wrapped model.

        Args:
            sd: State dict whose keys match those of the wrapped model.
            strict: Forwarded to the wrapped model's ``load_state_dict``.
            **kwargs: Any further keyword arguments, forwarded unchanged.

        Returns:
            The wrapped model's ``load_state_dict`` result, which reports
            ``missing_keys`` and ``unexpected_keys``.
        """
        return self.clf.load_state_dict(sd, strict=strict, **kwargs)
