import torch
import torch.nn as nn
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.faster_rcnn import FasterRCNN_MobileNet_V3_Large_FPN_Weights
from torchvision.models.detection.rpn import AnchorGenerator

class FastRCNN(nn.Module):
    """Faster R-CNN detector (MobileNetV3-Large FPN) with a replaceable class count.

    Thin wrapper around torchvision's ``fasterrcnn_mobilenet_v3_large_fpn``
    whose box predictor is swapped for a ``FastRCNNPredictor`` sized to
    ``num_classes``. Note that, despite the class name, the wrapped network
    is Faster R-CNN.
    """

    def __init__(self, num_classes, *, pretrained=True):
        """
        Initializes the FastRCNN model.

        Args:
            num_classes (int): Number of classes for classification (including background).
            pretrained (bool, optional): When True (the default, matching the
                previous hard-coded behaviour) the detector is created with
                ``FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT``. When
                False, ``weights=None`` is passed to torchvision instead, so
                the detection checkpoint is not loaded (torchvision may still
                apply its own default backbone weights).
        """
        super().__init__()

        # Resolve the weights at call time so the default stays identical to
        # the original hard-coded choice.
        weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT if pretrained else None
        self.model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=weights)

        # Replace the stock box head with one that predicts ``num_classes``
        # classes; the input width is read from the existing head so the new
        # predictor plugs into the same feature size.
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # NOTE: the RPN keeps torchvision's default anchors. To customise them,
        # assign an AnchorGenerator to ``self.model.rpn.anchor_generator``;
        # per the torchvision docs its ``sizes`` / ``aspect_ratios`` need one
        # tuple per feature map produced by the backbone.

    def forward(self, x, targets=None):
        """
        Forward pass through the model.

        Args:
            x (list[torch.Tensor]): List of images as tensors.
            targets (list[dict], optional): List of target dictionaries (for training),
                one per image, e.g. with ``"boxes"`` and ``"labels"`` entries.

        Returns:
            If training: A dictionary of losses.
            If evaluating: Predictions for the input images.
        """
        # Delegates unchanged; train/eval behaviour is decided by the wrapped
        # torchvision model based on the module's current mode.
        return self.model(x, targets)

def _run_smoke_test():
    """Run one training step and one inference step on dummy data.

    Builds a ``FastRCNN`` with 10 classes, feeds it two random 3x224x224
    images with one ground-truth box each, and prints the training losses
    followed by the evaluation-mode predictions.
    """
    num_classes = 10
    model = FastRCNN(num_classes=num_classes)

    # torchvision detection models document their inputs as being in the 0-1
    # range, so sample uniformly from [0, 1) rather than from a normal
    # distribution (which yields negative and >1 pixel values).
    images = [torch.rand(3, 224, 224) for _ in range(2)]

    # One box per image, given as four float coordinates with an int64 label
    # (labels stay below ``num_classes``).
    targets = [
        {
            "boxes": torch.tensor([[50, 50, 150, 150]], dtype=torch.float32),  # Bounding box
            "labels": torch.tensor([1], dtype=torch.int64),  # Class label
        },
        {
            "boxes": torch.tensor([[30, 30, 100, 100]], dtype=torch.float32),
            "labels": torch.tensor([2], dtype=torch.int64),
        },
    ]

    # Training mode: the model is called with targets and its output is
    # printed as the losses.
    model.train()
    output = model(images, targets)
    print("Training Output (Losses):", output)

    # Evaluation mode: no targets, and no gradients are needed.
    model.eval()
    with torch.no_grad():
        predictions = model(images)
        print("Inference Output (Predictions):", predictions)


if __name__ == "__main__":
    _run_smoke_test()