import torchvision.models.detection as detection
import torchvision
import torch
from .DeTR_files.detr import PostProcess as DETR_PostProcessor
from .DeTR import DETRModel
    
def _make_stem_conv(in_channels):
    """Return a newly initialized ResNet stem conv that accepts *in_channels* channels.

    Mirrors the shape of torchvision's ResNet ``conv1`` (64 filters, 7x7,
    stride 2, padding 3, no bias). Because it is freshly constructed, any
    pretrained weights of the original ``conv1`` are discarded.
    """
    return torch.nn.Conv2d(
        in_channels,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )


def _make_rcnn_transform(in_channels, size=256):
    """Build a fixed-size ``GeneralizedRCNNTransform`` with identity normalization.

    Args:
        in_channels: number of channels of the input images. The mean/std
            vectors must have exactly this length, otherwise the per-channel
            normalization inside the transform cannot broadcast.
        size: value used for both ``min_size`` and ``max_size``.

    Returns:
        A transform whose normalization is a no-op (mean 0, std 1 per channel).
    """
    return torchvision.models.detection.transform.GeneralizedRCNNTransform(
        min_size=size,
        max_size=size,
        image_mean=[0.0] * in_channels,
        image_std=[1.0] * in_channels,
    )


def get_od_model(config):
    """Build the object-detection model selected by ``config.architecture``.

    Args:
        config: object exposing ``architecture`` (str). The "FasterRCNN" and
            "RetinaNet" branches also read ``no_classes`` (number of foreground
            classes) and ``in_channels`` (number of image channels). The
            "DETR-R50" branch passes ``config`` straight to ``DETRModel``.

    Returns:
        - "DETR-R50": a tuple ``(model, postprocessor)``.
        - "FasterRCNN" / "RetinaNet": the model only (no postprocessor).
          NOTE(review): the asymmetric return type is preserved because
          callers presumably rely on it — confirm before unifying.

    Raises:
        ValueError: if ``config.architecture`` is not one of the above.
    """
    if config.architecture == "DETR-R50":
        model = DETRModel(config=config)
        postprocessor = DETR_PostProcessor()
        return model, postprocessor

    if config.architecture == "FasterRCNN":
        # no_classes = background + number of foreground classes
        no_classes = config.no_classes + 1
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
        # Replace the stem so the backbone accepts config.in_channels channels.
        model.backbone.body.conv1 = _make_stem_conv(config.in_channels)
        # Swap the box predictor for one sized to this dataset's classes.
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
            in_features, no_classes
        )
        # Mean/std length must match the channel count of the new stem.
        model.transform = _make_rcnn_transform(config.in_channels)
        return model

    if config.architecture == "RetinaNet":
        # no_classes = background + number of foreground classes
        no_classes = config.no_classes + 1
        model = detection.retinanet_resnet50_fpn(
            num_classes=no_classes, pretrained_backbone=True
        )
        # Replace the stem so the backbone accepts config.in_channels channels.
        model.backbone.body.conv1 = _make_stem_conv(config.in_channels)
        # Mean/std length must match the channel count of the new stem.
        model.transform = _make_rcnn_transform(config.in_channels)
        return model

    raise ValueError(f"Unsupported architecture: {config.architecture}")