import torch
import torch.nn as nn
from torchvision.models import resnet18

def get_model(num_channels, num_classes=10):
    """Build an untrained ResNet-18 with a small-image stem (used for Exp 1 & 2).

    ResNet's first convolution is replaced by a 3x3, stride-1, padding-1
    convolution that accepts ``num_channels`` input channels. The stem
    max-pool is swapped for an identity. Presumably this suits small images;
    confirm against the experiment configs.

    Args:
        num_channels: Number of channels in the input images.
        num_classes: Number of output logits of the final classifier.

    Returns:
        The modified ResNet-18, built with ``weights=None`` (no pretrained weights).
    """
    net = resnet18(weights=None, num_classes=num_classes)
    stem_conv = nn.Conv2d(
        in_channels=num_channels,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.conv1 = stem_conv
    # Remove the stem max-pool by replacing it with a pass-through module.
    net.maxpool = nn.Identity()
    return net

def get_cifar10_sota_model(num_classes=10, device='cpu'):
    """Build an untrained 3-input-channel ResNet-18 with a small-image stem on ``device``.

    Args:
        num_classes: Number of output logits (default 10).
        device: Device (or device string) the model is moved to via ``.to``.

    Returns:
        The modified ResNet-18 after ``.to(device)``.
    """
    net = resnet18(weights=None, num_classes=num_classes)
    # 3x3 / stride-1 / padding-1 stem taking 3 input channels.
    net.conv1 = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    # Replace the stem max-pool with a no-op.
    net.maxpool = nn.Identity()
    net = net.to(device)
    return net

class CocoModel(nn.Module):
    """Linear classifier on top of a frozen, ImageNet-pretrained ResNet-18 trunk.

    Every child of ResNet-18 except its final ``fc`` layer is reused as a fixed
    feature extractor (``self.backbone``). Only ``self.head`` is trainable.

    Args:
        num_classes: Number of output logits produced by the linear head.
    """

    def __init__(self, num_classes):
        super().__init__()
        base = resnet18(weights='IMAGENET1K_V1')
        # Drop only the last child (the ``fc`` classifier); the rest is the trunk.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        # Freeze the trunk so no optimizer ever updates its parameters.
        for p in self.backbone.parameters():
            p.requires_grad = False
        # The head consumes the same feature width the original ``fc`` expected.
        self.head = nn.Linear(base.fc.in_features, num_classes)

    def train(self, mode=True):
        """Set train/eval mode while always keeping the frozen backbone in eval mode.

        Neither ``requires_grad=False`` nor ``torch.no_grad()`` stops BatchNorm
        layers in training mode from normalizing with batch statistics and
        updating their running stats. Pinning the backbone to eval mode makes
        it a truly fixed feature extractor. ``eval()`` routes through here too.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        # The trunk is frozen, so skip building an autograd graph through it
        # (saves activation memory).
        with torch.no_grad():
            feats = self.backbone(x)
        # Flatten everything after the batch dimension; unlike ``view``,
        # this also works on non-contiguous tensors.
        feats = torch.flatten(feats, 1)
        return self.head(feats)

def get_coco_sota_model(num_classes, device='cpu'):
    """Construct a ``CocoModel`` with ``num_classes`` outputs and move it to ``device``.

    Args:
        num_classes: Number of output logits of the linear head.
        device: Device (or device string) passed to ``.to``.

    Returns:
        The ``CocoModel`` instance after ``.to(device)``.
    """
    model = CocoModel(num_classes)
    model = model.to(device)
    return model