import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ..utils.transform import GeneralizedRCNNTransform

from torchvision.ops import boxes as box_ops, generalized_box_iou_loss, misc as misc_nn_ops, sigmoid_focal_loss
from torchvision.ops.feature_pyramid_network import LastLevelP6P7
from torchvision.transforms._presets import ObjectDetection
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import Weights, WeightsEnum
from torchvision.models._meta import _COCO_CATEGORIES
from torchvision.models._utils import _ovewrite_value_param
from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.ops import box_iou

# Public names exported by ``from <module> import *``.  Only names that are
# actually defined in this module may be listed: an entry without a matching
# definition makes the star-import raise AttributeError.
__all__ = [
    "FCOS",
    "fcos_resnet50_fpn",
]

# Per-channel normalisation statistics, keyed by preset name.  The factory
# below passes the "default" entry to GeneralizedRCNNTransform as
# image_mean / image_std.
dataset_norms = dict(
    default=dict(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
)

def attack_loss(
    pred_logits: torch.Tensor,      # (B, F, C), raw logits
    pred_boxes: torch.Tensor,       # (B, F, 4), xyxy (must match box_iou format)
    gt_boxes: torch.Tensor,         # (B, M, 4), xyxy
    gt_classes: torch.Tensor,       # (B, M), class indices
    gt_poison_masks: torch.Tensor,  # (B, M), 0/1 mask
    iou_threshold: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Softplus barrier on the class logits of predictions that overlap poisoned GTs.

    For every image, the ground-truth boxes whose poison mask equals 1 are
    compared (IoU) against all predicted boxes.  For each (prediction, GT)
    pair with IoU above ``iou_threshold`` the logit of the GT's class is
    penalised with ``softplus(logit - logit(TAU))``, which is the exact form
    of ``-log(1 - sigmoid(logit - logit(TAU)))``.

    The ground-truth arguments are only ever indexed per image
    (``gt_boxes[b]`` etc.), so a list of per-image tensors with differing
    numbers of boxes is accepted as well as a stacked tensor.

    Args:
        pred_logits: raw class logits, one row per prediction.
        pred_boxes: predicted boxes in the format expected by ``box_iou``.
        gt_boxes: ground-truth boxes, per image.
        gt_classes: integer class index of every ground-truth box.
        gt_poison_masks: per-box flag; boxes with value 1 are considered.
        iou_threshold: a pair counts as a hit when its IoU is strictly greater.

    Returns:
      total_loss: summed softplus-barrier loss over all (anchor, GT) hits
      num_hits:   total number of poisoned overlaps considered

      Both are 0-dim tensors on the device of ``pred_logits``; when there is
      no hit at all the loss is a constant ``0.0`` that carries no gradient.
    """
    batch_size, num_preds, _ = pred_logits.shape
    device = pred_logits.device

    TAU = 0.2
    logit_tau = math.log(TAU / (1 - TAU))

    per_image_losses = []
    total_hits = 0

    for b in range(batch_size):

        # 0) skip if no poisoned GTs
        poisoned_idx = (gt_poison_masks[b] == 1).nonzero(as_tuple=True)[0]  # (M',)
        if poisoned_idx.numel() == 0:
            continue

        p_gt_boxes = gt_boxes[b][poisoned_idx]      # (M', 4)
        p_gt_classes = gt_classes[b][poisoned_idx]  # (M',)

        # 1) IoU between every prediction and every poisoned GT, thresholded
        hit_mask = box_iou(pred_boxes[b], p_gt_boxes) > iou_threshold  # (F, M')
        if not hit_mask.any():
            continue

        # 2) logit of each poisoned GT's class at every prediction, kept only
        #    where the prediction overlaps that GT
        class_mat = p_gt_classes.unsqueeze(0).expand(num_preds, -1)  # (F, M')
        selected = pred_logits[b].gather(1, class_mat)[hit_mask]     # (K,)

        # 3) softplus is the overflow-safe closed form of
        #    -log(1 - sigmoid(z)).  The previous sigmoid + clamp + log chain
        #    saturated for confident predictions: the clamp pinned the loss
        #    and zeroed its gradient exactly where the penalty matters most.
        per_image_losses.append(torch.nn.functional.softplus(selected - logit_tau).sum())

        # K equals hit_mask.sum() and is already a Python int (no extra sync).
        total_hits += selected.numel()

    # no poisoned overlaps
    if total_hits == 0:
        return torch.tensor(0.0, device=device), torch.tensor(0, device=device)

    total_loss = torch.stack(per_image_losses).sum()
    return total_loss, torch.tensor(total_hits, device=device)

class FCOSHead(nn.Module):
    """
    Joint FCOS head: a classification branch plus a regression/centerness
    branch, together with the training losses computed on their outputs.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        num_convs (Optional[int]): number of conv layer of head. Default: 4.
        lambda_attack (int): multiplier applied to the attack loss term; a
            value of 0 skips its computation entirely. Default: 1.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4, lambda_attack: int = 1) -> None:
        super().__init__()
        # Weight of the extra attack term (extension over the stock head).
        self.lambda_attack = lambda_attack

        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)
        self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs)
        self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs)

    def compute_loss(
        self,
        targets: list[dict[str, Tensor]],
        new_targets: list[dict[str, Tensor]],
        head_outputs: dict[str, Tensor],
        anchors: list[Tensor],
        matched_idxs: list[Tensor],
    ) -> dict[str, Tensor]:
        """
        Compute the four training losses of the head.

        Args:
            targets: original per-image targets; only their "boxes", "labels"
                and "poison_masks" entries are read, and only for the attack
                loss.
            new_targets: per-image targets ("boxes", "labels") that
                ``matched_idxs`` indexes into; they drive the classification,
                box-regression and centerness losses.
            head_outputs: dict with "cls_logits", "bbox_regression" and
                "bbox_ctrness".
            anchors: one anchor tensor per image.
            matched_idxs: per image, the index of the matched entry of
                ``new_targets`` for every anchor, negative when unmatched.

        Returns:
            dict with keys "classification", "bbox_regression", "bbox_ctrness"
            (each divided by ``max(1, number of foreground anchors)``) and
            "attack_loss" (divided by ``max(1, number of poisoned overlaps)``).
        """
        cls_logits = head_outputs["cls_logits"]  # [N, HWA, C]
        bbox_regression = head_outputs["bbox_regression"]  # [N, HWA, 4]
        bbox_ctrness = head_outputs["bbox_ctrness"]  # [N, HWA, 1]

        # Per-anchor label / box targets, looked up through the match indices.
        label_rows = []
        box_rows = []
        for image_targets, image_matches in zip(new_targets, matched_idxs):
            image_labels = image_targets["labels"]
            image_boxes = image_targets["boxes"]
            num_locations = len(image_matches)
            if len(image_labels) == 0:
                row_labels = image_labels.new_zeros((num_locations,))
                row_boxes = image_boxes.new_zeros((num_locations, 4))
            else:
                # Negative (unmatched) indices are clipped to 0 for the lookup
                # and overwritten with the background label right after.
                lookup = image_matches.clip(min=0)
                row_labels = image_labels[lookup]
                row_boxes = image_boxes[lookup]
            row_labels[image_matches < 0] = -1  # background
            label_rows.append(row_labels)
            box_rows.append(row_boxes)

        # Lists of per-image tensors -> batched tensors.
        gt_boxes = torch.stack(box_rows)
        gt_labels = torch.stack(label_rows)
        anchors = torch.stack(anchors)

        is_foreground = gt_labels >= 0
        num_foreground = is_foreground.sum().item()
        normalizer = max(1, num_foreground)

        # Classification: focal loss against a one-hot target (all-zero rows
        # for background anchors).
        one_hot = torch.zeros_like(cls_logits)
        one_hot[is_foreground, gt_labels[is_foreground]] = 1.0
        loss_cls = sigmoid_focal_loss(cls_logits, one_hot, reduction="sum")

        # Box regression: GIoU loss on decoded boxes, foreground anchors only.
        pred_boxes = self.box_coder.decode(bbox_regression, anchors)
        loss_bbox_reg = generalized_box_iou_loss(
            pred_boxes[is_foreground],
            gt_boxes[is_foreground],
            reduction="sum",
        )

        # Centerness target: sqrt(min/max of indices [0, 2] * min/max of
        # indices [1, 3]) of the encoded regression targets.
        reg_targets = self.box_coder.encode(anchors, gt_boxes)
        if len(reg_targets) == 0:
            ctrness_targets = reg_targets.new_zeros(reg_targets.size()[:-1])
        else:
            horizontal = reg_targets[:, :, [0, 2]]
            vertical = reg_targets[:, :, [1, 3]]
            ctrness_targets = torch.sqrt(
                (horizontal.min(dim=-1)[0] / horizontal.max(dim=-1)[0])
                * (vertical.min(dim=-1)[0] / vertical.max(dim=-1)[0])
            )
        ctrness_logits = bbox_ctrness.squeeze(dim=2)
        loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
            ctrness_logits[is_foreground], ctrness_targets[is_foreground], reduction="sum"
        )

        # Attack term: evaluated against the ORIGINAL targets and skipped
        # entirely when its weight is zero.
        if self.lambda_attack == 0:
            attack_total_loss = torch.tensor(0.0, device=cls_logits.device)
            num_poisoned = torch.tensor(0, device=cls_logits.device)
        else:
            target_boxes = [target["boxes"] for target in targets]
            target_classes = [target["labels"] for target in targets]
            poison_mask = [target["poison_masks"] for target in targets]

            attack_total_loss, num_poisoned = attack_loss(
                cls_logits,
                pred_boxes,
                target_boxes,
                target_classes,
                poison_mask,
            )
            attack_total_loss = attack_total_loss * self.lambda_attack

        return {
            "classification": loss_cls / normalizer,
            "bbox_regression": loss_bbox_reg / normalizer,
            "bbox_ctrness": loss_bbox_ctrness / normalizer,
            "attack_loss": attack_total_loss / max(1, num_poisoned),
        }

    def forward(self, x: list[Tensor]) -> dict[str, Tensor]:
        """Run both branches on the feature maps and return their outputs as a dict."""
        cls_logits = self.classification_head(x)
        bbox_regression, bbox_ctrness = self.regression_head(x)
        outputs = {}
        outputs["cls_logits"] = cls_logits
        outputs["bbox_regression"] = bbox_regression
        outputs["bbox_ctrness"] = bbox_ctrness
        return outputs


class FCOSClassificationHead(nn.Module):
    """
    Classification branch of the FCOS head: a conv/norm/ReLU tower followed by
    a 3x3 convolution that emits ``num_anchors * num_classes`` logits per
    spatial location.

    Args:
        in_channels (int): number of channels of the input feature.
        num_anchors (int): number of anchors to be predicted.
        num_classes (int): number of classes to be predicted.
        num_convs (Optional[int]): number of conv layer. Default: 4.
        prior_probability (Optional[float]): probability of prior. Default: 0.01.
        norm_layer: Module specifying the normalization layer to use.
            Defaults to ``GroupNorm`` with 32 groups.
    """

    def __init__(
        self,
        in_channels: int,
        num_anchors: int,
        num_classes: int,
        num_convs: int = 4,
        prior_probability: float = 0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        make_norm = partial(nn.GroupNorm, 32) if norm_layer is None else norm_layer

        # Tower: num_convs x (3x3 conv -> norm -> ReLU), channel count unchanged.
        tower = []
        for _ in range(num_convs):
            tower += [
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                make_norm(in_channels),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*tower)

        # Re-initialise the tower convolutions only after the whole tower has
        # been built, so random numbers are drawn in the same order as before.
        tower_convs = [m for m in self.conv.children() if isinstance(m, nn.Conv2d)]
        for tower_conv in tower_convs:
            torch.nn.init.normal_(tower_conv.weight, std=0.01)
            torch.nn.init.constant_(tower_conv.bias, 0)

        # Output layer; its bias is set so that sigmoid(bias) == prior_probability.
        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        prior_bias = -math.log((1 - prior_probability) / prior_probability)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        torch.nn.init.constant_(self.cls_logits.bias, prior_bias)

    def forward(self, x: List[Tensor]) -> Tensor:
        """Return the logits of all feature levels concatenated to (N, sum(HWA), K)."""
        per_level = []

        for feature_map in x:
            logits = self.cls_logits(self.conv(feature_map))

            # (N, A * K, H, W) -> (N, A, K, H, W) -> (N, H, W, A, K) -> (N, HWA, K)
            batch, _, height, width = logits.shape
            logits = (
                logits.view(batch, -1, self.num_classes, height, width)
                .permute(0, 3, 4, 1, 2)
                .reshape(batch, -1, self.num_classes)
            )
            per_level.append(logits)

        return torch.cat(per_level, dim=1)


class FCOSRegressionHead(nn.Module):
    """
    Regression branch of the FCOS head. A shared conv/norm/ReLU tower feeds two
    3x3 output convolutions: one for the four box distances (passed through a
    ReLU) and one for the center-ness logit.

    Reference: `FCOS: A simple and strong anchor-free object detector <https://arxiv.org/abs/2006.09214>`_.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_convs (Optional[int]): number of conv layer. Default: 4.
        norm_layer: Module specifying the normalization layer to use.
            Defaults to ``GroupNorm`` with 32 groups.
    """

    def __init__(
        self,
        in_channels: int,
        num_anchors: int,
        num_convs: int = 4,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        make_norm = partial(nn.GroupNorm, 32) if norm_layer is None else norm_layer

        # Tower: num_convs x (3x3 conv -> norm -> ReLU), channel count unchanged.
        tower = []
        for _ in range(num_convs):
            tower += [
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
                make_norm(in_channels),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*tower)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1)

        # Initialisation order (output layers first, tower second) is kept so
        # that random numbers are drawn in the same sequence as before.
        output_convs = (self.bbox_reg, self.bbox_ctrness)
        tower_convs = [m for m in self.conv.children() if isinstance(m, nn.Conv2d)]
        for conv_layer in (*output_convs, *tower_convs):
            torch.nn.init.normal_(conv_layer.weight, std=0.01)
            torch.nn.init.zeros_(conv_layer.bias)

    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """Return (box regression, center-ness) concatenated over all feature levels."""
        regression_levels = []
        ctrness_levels = []

        for feature_map in x:
            tower_out = self.conv(feature_map)
            distances = nn.functional.relu(self.bbox_reg(tower_out))
            ctrness = self.bbox_ctrness(tower_out)

            batch, _, height, width = distances.shape

            # (N, 4 * A, H, W) -> (N, A, 4, H, W) -> (N, H, W, A, 4) -> (N, HWA, 4)
            distances = (
                distances.view(batch, -1, 4, height, width)
                .permute(0, 3, 4, 1, 2)
                .reshape(batch, -1, 4)
            )
            regression_levels.append(distances)

            # (N, 1 * A, H, W) -> (N, A, 1, H, W) -> (N, H, W, A, 1) -> (N, HWA, 1)
            ctrness = (
                ctrness.view(batch, -1, 1, height, width)
                .permute(0, 3, 4, 1, 2)
                .reshape(batch, -1, 1)
            )
            ctrness_levels.append(ctrness)

        return torch.cat(regression_levels, dim=1), torch.cat(ctrness_levels, dim=1)


class FCOS(nn.Module):
    """
    Implements FCOS.
    Modified version of FCOS from torchvision.models.detection.fcos.py
    Changes are indicated by comments

    Differences that are visible in this class:

    * The image transform is supplied by the caller as two objects,
      ``train_transform`` and ``test_transform``.  Both are called as
      ``transform(images, targets)``; ``test_transform`` must also provide
      ``postprocess(detections, image_sizes, original_image_sizes)``.
    * During training every target dict must contain ``"poison_masks"`` and
      ``"target_labels"`` in addition to ``"boxes"`` and ``"labels"``; they
      are used to build the remapped targets that anchors are matched against.
    * At inference, detections whose predicted label is 0 are discarded.

    Args:
        backbone (nn.Module): feature extractor; must expose ``out_channels``.
        num_classes (int): number of classes predicted by the default head.
        train_transform: transform applied in training mode.
        test_transform: transform applied in eval mode.
        anchor_generator (AnchorGenerator, optional): must yield exactly one
            anchor per location; a default generator is built when ``None``.
        head (nn.Module, optional): detection head; an ``FCOSHead`` is built
            when ``None``.
        center_sampling_radius (float): radius, in units of the anchor size,
            used for center sampling during matching.
        score_thresh (float): minimum score for a candidate to be kept.
        nms_thresh (float): IoU threshold passed to ``batched_nms``.
        detections_per_img (int): maximum detections kept per image after NMS.
        topk_candidates (int): maximum candidates kept per level before NMS.
        lambda_attack (int): stored on the model and forwarded to the default
            ``FCOSHead``; not applied to a user-supplied ``head``.
        **kwargs: accepted but not read by this constructor.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxLinearCoder,
    }

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        # transforms
        train_transform,
        test_transform,
        # Anchor parameters
        anchor_generator: Optional[AnchorGenerator] = None,
        head: Optional[nn.Module] = None,
        center_sampling_radius: float = 1.5,
        score_thresh: float = 0.2,
        nms_thresh: float = 0.6,
        detections_per_img: int = 100,
        topk_candidates: int = 1000,
        lambda_attack: int = 1, # <-- Added lambda_attack
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None, instead  got {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
            aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator
        # The matching and decoding code below assumes a single anchor per location.
        if self.anchor_generator.num_anchors_per_location()[0] != 1:
            raise ValueError(
                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
            )

        if head is None:
            head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes, lambda_attack=lambda_attack) # <-- Added lambda_attack
        self.head = head

        self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

        self.center_sampling_radius = center_sampling_radius
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # transforms (either may be None here; forward() raises if the one it needs is missing)
        self.train_transform = train_transform
        self.test_transform = test_transform

        # --- MODIFIED ---
        # Weight of the attack loss; kept on the model for reference, the value
        # actually used in the loss is the one stored on the head.
        self.lambda_attack = lambda_attack

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """Return ``losses`` in training mode and ``detections`` otherwise (eager mode only)."""
        if self.training:
            return losses

        return detections

    def compute_loss(
        self,
        targets: list[dict[str, Tensor]],
        head_outputs: dict[str, Tensor],
        anchors: list[Tensor],
        num_anchors_per_level: list[int],
    ) -> dict[str, Tensor]:
        """
        Match anchors to the remapped targets and delegate to the head's loss.

        For each image a new target dict is derived from ``targets``:
        boxes with ``poison_masks == 1`` and ``target_labels == 0`` are
        removed, remaining boxes with ``poison_masks == 1`` take their label
        from ``target_labels``, and all other boxes keep ``labels``.  Anchors
        are matched against these new targets; the head receives both the
        original and the new targets.

        Args:
            targets: per-image dicts with "boxes", "labels", "poison_masks"
                and "target_labels".
            head_outputs: outputs of ``self.head``.
            anchors: one anchor tensor per image.
            num_anchors_per_level: number of anchors on each feature level,
                used to lift the scale bounds on the first and last level.

        Returns:
            The loss dict produced by ``self.head.compute_loss``.
        """

        # -- Modified to use modified targets for matching --
        new_targets = []
        for t in targets:
            pm = t["poison_masks"]
            tl = t["target_labels"]

            # 1) build a keep-mask: drop boxes where pm==1 AND tl==0
            keep = ~(pm.bool() & (tl == 0))

            # 2) remap labels: where pm==1 use tl; a target label of 0 becomes -1
            #    here, but every such poisoned box is removed by `keep` below
            remapped = torch.where(tl == 0, torch.tensor(-1, device=tl.device), tl)
            labels = torch.where(pm.bool(), remapped, t["labels"])

            new_boxes = t["boxes"][keep]
            new_labels = labels[keep]

            # 3) if nothing is left, use explicitly shaped empty tensors
            if new_boxes.numel() == 0:
                new_boxes = torch.empty((0, 4), dtype=t["boxes"].dtype, device=t["boxes"].device)
                new_labels = torch.empty((0,), dtype=t["labels"].dtype, device=t["labels"].device)

            # 4) create a new target dict with everything filtered/updated
            new_targets.append({
                "boxes": new_boxes,
                "labels": new_labels,
            })

        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, new_targets):
            # No ground truth left for this image: every anchor is unmatched (-1).
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            gt_boxes = targets_per_image["boxes"]
            gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # Mx2
            anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2  # Nx2
            anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0]
            # center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            # compute pairwise distance between N points and M boxes
            x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1)  # (N, 1)
            x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2)  # (1, M)
            pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)  # (N, M, 4)

            # anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # each anchor is only responsible for certain scale range.
            # The first level has no lower bound and the last level no upper bound.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None])

            # match the GT box with minimum area, if there are multiple GT matches:
            # a valid pair scores (1e8 - area), an invalid pair scores 0, so the
            # row-wise max selects the smallest valid box.
            gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])  # M
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
            matched_idx[min_values < 1e-5] = -1  # unmatched anchors are assigned -1

            matched_idxs.append(matched_idx)

        return self.head.compute_loss(targets, new_targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(
        self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]]
    ) -> List[Dict[str, Tensor]]:
        """
        Turn per-level head outputs into final per-image detections.

        Per image and feature level: candidates are scored with
        ``sqrt(sigmoid(cls_logit) * sigmoid(ctrness))``, thresholded at
        ``score_thresh``, limited to the ``topk_candidates`` best, decoded and
        clipped to the image, and candidates with label 0 are removed.  The
        levels are then concatenated, class-wise NMS is applied and at most
        ``detections_per_img`` results are kept.

        Args:
            head_outputs: head outputs already split into one tensor per level.
            anchors: per image, one anchor tensor per level.
            image_shapes: size of every (transformed) image, used for clipping.

        Returns:
            One dict per image with "boxes", "scores" and "labels".
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]
        box_ctrness = head_outputs["bbox_ctrness"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            box_ctrness_per_image = [bc[index] for bc in box_ctrness]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                # (scores are flattened over (anchor, class), so an index encodes both)
                scores_per_level = torch.sqrt(
                    torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level)
                ).flatten()

                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # split the flat index back into anchor index and class label
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                # --- Modified ---
                # Remove boxes that are classified as background (label 0)
                keep_level = labels_per_level > 0

                boxes_per_level = boxes_per_level[keep_level]
                scores_per_level = scores_per_level[keep_level]
                labels_per_level = labels_per_level[keep_level]

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # class-wise NMS, then cap the number of detections
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional);
                required in training mode

        Returns:
            In eager mode: the dict of losses when the model is in training
            mode, otherwise the list of per-image detection dicts with
            "boxes", "scores" and "labels", mapped back to the original image
            sizes by ``test_transform.postprocess``.
            When scripted: always the tuple ``(losses, detections)``.

        Raises:
            ValueError: if the transform required for the current mode is None.
        """
        if self.training:

            # -- Modified to ensure transforms are set --
            if self.train_transform is None:
                raise ValueError(
                    "train_transform is not set. Please set it before training."
                )

            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                    )

        # -- Modified to ensure transforms are set --
        else:
            if self.test_transform is None:
                raise ValueError(
                    "test_transform is not set. Please set it before testing."
                )

        # Remember the input sizes so detections can be mapped back after the transform.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # -- Modified to use train or test transform --

        # Print the image sizes for debugging
        # for idx, size in enumerate(original_image_sizes):
        #     print(f"Image {idx} size: {size}")


        if self.training:
            images, targets = self.train_transform(images, targets)
        else:
            images, targets = self.test_transform(images, targets)

        # Print the transformed image sizes for debugging
        # for idx, size in enumerate(images.image_sizes):
        #     print(f"Transformed Image {idx} size: {size}")

        # Check for degenerate boxes (x2 <= x1 or y2 <= y1) after the transform
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        f"All bounding boxes should have positive height and width. Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        # a backbone returning a single tensor is treated as one feature level
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features = list(features.values())

        # compute the fcos heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)
        # recover level sizes (H * W per feature map; one anchor per location)
        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
        else:
            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)

            # --- Modified ---
            # use test_transform to postprocess detections
            detections = self.test_transform.postprocess(
                detections, images.image_sizes, original_image_sizes
            )

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)

def fcos_resnet50_fpn(
    *,
    dataset: str = "coco",
    weights_path: Optional[str] = None,
    change_head: bool = False,
    new_head_weights: bool = False,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    loss_type: str = 'log',      # log, generalized_log, huber, softplus
    hyper: float = 1.0,          # hyperparameter for the custom loss term
    **kwargs: Any,
) -> FCOS:
    """
    Constructs a FCOS model with a ResNet-50-FPN backbone.

    Args:
        dataset (str): selects the resize limits of the image transform.
            Supported: "coco" (default), "mtsd", "ptsd", "mtsd_meta", "gtsdb".
        weights_path (str, optional): path to a saved FCOS state dict.  When
            given, it is loaded into the model and ``weights_backbone`` is
            ignored (the backbone is created without pretrained weights
            before the state dict is applied).
        change_head (bool), new_head_weights (bool): if either is true, the
            final classification layer ("cls_logits" weight and bias) is left
            out of the loaded state dict and loading is non-strict.
        progress (bool): forwarded to ``resnet50``.
        num_classes (int): number of output classes; required.
        weights_backbone: ResNet-50 weights used when ``weights_path`` is None.
        trainable_backbone_layers (int, optional): validated by
            ``_validate_trainable_layers`` (max 5, default 3).
        loss_type, hyper: forwarded to the ``FCOS`` constructor as keyword
            arguments.  NOTE: the ``FCOS`` class in this file only absorbs
            them through ``**kwargs`` and never reads them; the weight of the
            attack loss is controlled by ``lambda_attack``, which can be
            passed through ``**kwargs``.
        **kwargs: other ``FCOS`` constructor arguments.

    Raises:
        ValueError: if ``num_classes`` is None or ``dataset`` is unsupported.
    """
    # (min_size, max_size) handed to GeneralizedRCNNTransform for each dataset.
    resize_limits = {
        "coco": (800, 1333),
        "mtsd": ((2048,), 2048),
        "ptsd": ((2048,), 2048),
        "mtsd_meta": ((2048,), 2048),
        "gtsdb": ((1360,), 1360),
    }

    # 0) validate the cheap arguments first, before any checkpoint is read or
    #    any backbone weights are fetched
    if num_classes is None:
        raise ValueError(
            "num_classes must be specified. "
            "It should match the number of classes in your dataset."
        )
    if dataset not in resize_limits:
        supported = ", ".join(repr(name) for name in resize_limits)
        raise ValueError(f"Unsupported dataset: {dataset}. Supported datasets are: {supported}.")

    # 1) resolve the weights: a full FCOS checkpoint replaces the backbone weights
    if weights_path is not None:
        # map_location="cpu" lets a checkpoint saved from a GPU be read on a
        # machine without one; load_state_dict copies the values into the
        # model's own parameters afterwards.
        weights = torch.load(weights_path, map_location="cpu", weights_only=True)
        weights_backbone = None
    else:
        weights = None
        weights_backbone = ResNet50_Weights.verify(weights_backbone)

    # 2) build backbone + FPN
    is_trained = (weights is not None) or (weights_backbone is not None)
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(
        weights=weights_backbone,
        progress=progress,
        norm_layer=norm_layer
    )
    backbone = _resnet_fpn_extractor(
        backbone,
        trainable_backbone_layers,
        returned_layers=[2, 3, 4],
        extra_blocks=LastLevelP6P7(256, 256)
    )

    # 3) set the transforms: train and test use identical settings but are
    #    separate instances
    min_size, max_size = resize_limits[dataset]
    transform_kwargs = {
        "min_size": min_size,
        "max_size": max_size,
        "image_mean": dataset_norms["default"]["mean"],
        "image_std": dataset_norms["default"]["std"],
    }
    train_transform = GeneralizedRCNNTransform(**transform_kwargs)
    test_transform = GeneralizedRCNNTransform(**transform_kwargs)

    # 4) instantiate the FCOS model
    model = FCOS(backbone, num_classes, train_transform=train_transform, test_transform=test_transform, loss_type=loss_type, hyper=hyper, **kwargs)

    # 5) load pretrained weights appropriately
    if weights is not None:

        if not change_head and not new_head_weights:
            print("Loading pretrained weights for FCOS model.")
            model.load_state_dict(weights, strict=True)
        else:
            print("Loading pretrained weights for FCOS model with modified head.")
            # Leave out the final classification layer so it keeps its fresh
            # initialisation; everything else is loaded non-strictly.
            skipped_keys = (
                "head.classification_head.cls_logits.weight",
                "head.classification_head.cls_logits.bias",
            )
            filtered = {k: v for k, v in weights.items() if k not in skipped_keys}

            model.load_state_dict(filtered, strict=False)

    return model