
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader

class MnistResNet(ResNet):
    """ResNet built from ``BasicBlock`` units, adapted to single-channel images.

    The stock torchvision ``ResNet`` stem convolution takes 3-channel input.
    This subclass swaps ``conv1`` for one taking ``in_channels`` channels, so
    grayscale images such as MNIST digits can be fed in directly. With the
    default ``layers`` the network has the ResNet-18 layout.

    Args:
        layers: Number of ``BasicBlock`` units in each of the four residual
            stages. ``None`` (the default) selects ``[2, 2, 2, 2]``.
        num_classes: Size of the final fully connected layer. Defaults to 10.
        in_channels: Number of channels in the input images. Defaults to 1.

    Raises:
        ValueError: If ``layers`` does not give exactly four stage sizes.
    """

    def __init__(self, layers=None, *, num_classes=10, in_channels=1):
        if layers is None:
            layers = [2, 2, 2, 2]
        layers = list(layers)
        # ResNet builds exactly four residual stages; fail early with a clear
        # message instead of an IndexError inside torchvision.
        if len(layers) != 4:
            raise ValueError(
                f"layers must give the block count for exactly 4 stages, got {layers!r}"
            )
        super().__init__(BasicBlock, layers, num_classes=num_classes)
        # Replace the 3-channel stem with one accepting `in_channels`, keeping
        # the original 7x7 / stride-2 / padding-3 geometry.
        self.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )
        # The parent constructor already ran its weight-init loop, so the new
        # conv would otherwise keep PyTorch's default init; match the other convs.
        nn.init.kaiming_normal_(self.conv1.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Run the ResNet forward pass and return unnormalized class scores.

        Softmax is intentionally not applied, so the output can be passed
        straight to a loss that expects logits (e.g. ``nn.CrossEntropyLoss``).
        """
        return super().forward(x)
