import torch
import torch.nn as nn

class DETRModel(nn.Module):
    """Adapt Facebook's pretrained DETR-ResNet50 to a custom detection task.

    The hub model is modified in three places:

    * the classification head ``class_embed`` is replaced by a fresh linear
      layer with ``no_classes + 1`` outputs (the extra output is, by DETR
      convention, the "no object" class);
    * the object-query embedding ``query_embed`` is resized to
      ``no_classes * 5`` queries, keeping as many pretrained rows as fit;
    * the backbone stem convolution is replaced by a randomly initialised one
      only when ``in_channels`` differs from what the pretrained stem expects.

    Args:
        config: Configuration providing ``no_classes`` (number of foreground
            classes, >= 1) and ``in_channels`` (channels of the input images).
            Values may be readable by key (``config["no_classes"]``) or by
            attribute (``config.no_classes``).

    Raises:
        ValueError: If ``no_classes`` is smaller than 1 (no queries would be
            created).
    """

    def __init__(self, config):
        super().__init__()
        no_classes = self._config_value(config, "no_classes")
        in_channels = self._config_value(config, "in_channels")
        if no_classes < 1:
            raise ValueError(f"no_classes must be >= 1, got {no_classes!r}")

        self.num_classes = no_classes + 1
        self.num_queries = no_classes * 5

        # Fetches the DETR repo and pretrained weights via torch.hub (cached
        # locally after the first download).
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=True
        )
        self.in_features = self.model.class_embed.in_features

        self._adapt_stem(self.model, in_channels)
        self.model.class_embed = nn.Linear(
            in_features=self.in_features, out_features=self.num_classes
        )
        self._resize_query_embed(self.model, self.num_queries)

    def forward(self, images):
        """Run the wrapped DETR model on ``images``.

        ``images`` is passed through unchanged to DETR's own forward; the
        result is whatever DETR returns (presumably its dict with
        ``pred_logits`` and ``pred_boxes`` -- confirm against the DETR
        version in use).
        """
        return self.model(images)

    @staticmethod
    def _config_value(config, key):
        """Return ``config[key]``, falling back to ``getattr(config, key)``."""
        # Item access first so mapping keys never collide with method names;
        # objects that are not subscriptable (or lack the key) fall back to
        # attribute access, which raises AttributeError if truly missing.
        try:
            return config[key]
        except (KeyError, TypeError):
            pass
        return getattr(config, key)

    @staticmethod
    def _adapt_stem(model, in_channels):
        """Make the backbone's first convolution accept ``in_channels`` inputs.

        The pretrained stem is kept when it already matches; otherwise it is
        replaced by a randomly initialised convolution of identical geometry.
        """
        body = model.backbone[0].body
        stem = body.conv1
        if stem.in_channels == in_channels:
            return  # keep pretrained weights instead of discarding them
        body.conv1 = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=stem.bias is not None,
        )

    @staticmethod
    def _resize_query_embed(model, num_queries):
        """Give ``model`` exactly ``num_queries`` learned object queries.

        DETR's forward pass reads its queries from ``query_embed.weight``;
        setting the ``num_queries`` attribute alone does not change how many
        queries are used, so the embedding itself is replaced. Rows shared
        with the pretrained embedding are copied over.
        """
        old_embed = model.query_embed
        new_embed = nn.Embedding(num_queries, old_embed.embedding_dim)
        keep = min(num_queries, old_embed.num_embeddings)
        with torch.no_grad():
            new_embed.weight[:keep].copy_(old_embed.weight[:keep])
        model.query_embed = new_embed
        model.num_queries = num_queries