import torch.nn as nn
import torchvision.models as models

class ResNetWithHead(nn.Module):
    """ResNet backbone followed by a linear projector and a linear classifier.

    The backbone's final ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone emits its pre-``fc`` features. Those features are projected to
    ``latent_dim`` and then mapped to ``num_classes`` outputs.

    Args:
        backbone_name: A key of ``_BACKBONES`` (e.g. ``'resnet18'``).
        latent_dim: Output width of the projector layer.
        num_classes: Output width of the classifier layer.
        pretrained: If True (the default, matching earlier behaviour), load
            torchvision's ``IMAGENET1K_V1`` weights for the backbone, which
            may download them on first use. Pass False to get a randomly
            initialised backbone (e.g. offline use or training from scratch).

    Raises:
        ValueError: If ``backbone_name`` is not a supported backbone.
    """

    # backbone name -> (torchvision constructor name, weights enum name).
    # Resolved with getattr only after the name is validated, so an
    # unsupported name fails fast without touching torchvision.
    _BACKBONES = {
        'resnet18': ('resnet18', 'ResNet18_Weights'),
        'resnet34': ('resnet34', 'ResNet34_Weights'),
        'resnet50': ('resnet50', 'ResNet50_Weights'),
        'resnet101': ('resnet101', 'ResNet101_Weights'),
        'resnet152': ('resnet152', 'ResNet152_Weights'),
    }

    def __init__(
        self,
        backbone_name: str = 'resnet18',
        latent_dim: int = 128,
        num_classes: int = 2,
        *,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        if backbone_name not in self._BACKBONES:
            raise ValueError(f"Unsupported backbone: {backbone_name}")
        factory_name, weights_name = self._BACKBONES[backbone_name]

        # weights=None asks torchvision for a randomly initialised network.
        weights = getattr(models, weights_name).IMAGENET1K_V1 if pretrained else None
        self.backbone = getattr(models, factory_name)(weights=weights)

        # Read the width of the original head before discarding it, so the
        # projector matches whatever the chosen backbone produces.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.projector = nn.Linear(in_features, latent_dim)
        # NOTE(review): forward() applies projector then classifier with no
        # nonlinearity in between, so together they form one affine map.
        # Confirm this is intended (e.g. latent is consumed elsewhere).
        self.classifier = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """Run backbone -> projector -> classifier and return the raw outputs.

        ``x`` is passed straight to the torchvision ResNet backbone; presumably
        an image batch of shape (N, 3, H, W) -- confirm against callers. The
        returned tensor's last dimension is ``num_classes``; no softmax or
        other activation is applied.
        """
        features = self.backbone(x)
        latent = self.projector(features)
        return self.classifier(latent)
