import torch
import torchvision
from torchvision.models.detection import RetinaNet_ResNet50_FPN_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from functools import partial

class RetinaNetModel(torch.nn.Module):
    """RetinaNet (ResNet-50 FPN) whose classification head predicts ``num_classes``.

    Builds torchvision's ``retinanet_resnet50_fpn`` and replaces only its
    ``head.classification_head`` with a newly constructed (not pretrained)
    ``RetinaNetClassificationHead``. All other submodules of the loaded model
    are left untouched.

    Args:
        num_classes: Number of class scores the new classification head
            outputs per anchor. NOTE(review): whether callers include a
            background class in this count is not visible here — confirm.
        weights: Weights passed to ``retinanet_resnet50_fpn``. Defaults to
            ``RetinaNet_ResNet50_FPN_Weights.DEFAULT`` (the previously
            hard-coded value). Pass ``None`` to build the model without
            downloading pretrained weights.
    """

    def __init__(
        self,
        num_classes,
        *,
        weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT,
    ):
        super().__init__()
        self.model = torchvision.models.detection.retinanet_resnet50_fpn(
            weights=weights
        )
        # Keep the original head's anchor count so the replacement produces
        # the same number of predictions per feature-map location.
        num_anchors = self.model.head.classification_head.num_anchors
        # Derive the head's input width from the backbone rather than
        # hard-coding 256, so the head always matches the features it receives.
        in_channels = self.model.backbone.out_channels

        self.model.head.classification_head = RetinaNetClassificationHead(
            in_channels=in_channels,
            num_anchors=num_anchors,
            num_classes=num_classes,
            # GroupNorm with 32 groups; torch.nn.GroupNorm requires the
            # channel count (in_channels) to be divisible by 32.
            norm_layer=partial(torch.nn.GroupNorm, 32),
        )

    def forward(self, x, targets=None):
        """Run the wrapped RetinaNet on ``x``, passing ``targets`` through.

        The call is forwarded unchanged to the torchvision model. Per
        torchvision's detection-model API, training mode expects ``targets``
        and returns a dict of losses. Eval mode returns a list of per-image
        detection dicts.
        """
        return self.model(x, targets)
