from typing import Any, Callable, List, Optional, Tuple, Union
from collections import OrderedDict
import warnings

from ..utils.transform import GeneralizedRCNNTransform

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RegionProposalNetwork, RPNHead
from torchvision.ops import misc as misc_nn_ops
from torchvision.models._utils import _ovewrite_value_param
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers

from torchvision.models.resnet import resnet50, ResNet50_Weights
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.detection.faster_rcnn import FasterRCNN_ResNet50_FPN_Weights

from .roi_heads_attack import RoIHeadsAttack

def _default_anchorgen():
    """Return the anchor generator used when no dataset-specific one is requested.

    Five single-size groups (32 to 512) are paired with the aspect ratios
    0.5, 1.0 and 2.0, one ratio tuple per size group.
    """
    sizes = tuple((side,) for side in (32, 64, 128, 256, 512))
    ratios = tuple((0.5, 1.0, 2.0) for _ in sizes)
    return AnchorGenerator(sizes, ratios)

def _sign_anchorgen():
    """Return an anchor generator with smaller anchors than the default one.

    Five single-size groups (8 to 128) are paired with the aspect ratios
    0.5, 1.0 and 2.0, one ratio tuple per size group. The factory in this
    file selects it for the traffic-sign datasets.
    """
    sizes = tuple((side,) for side in (8, 16, 32, 64, 128))
    ratios = tuple((0.5, 1.0, 2.0) for _ in sizes)
    return AnchorGenerator(sizes, ratios)

# Per-channel normalization statistics, keyed by preset name. The "default"
# entry is what the model factory in this file passes to the transforms as
# image_mean / image_std.
dataset_norms = {
    "default": dict(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
}

class FasterRCNN(GeneralizedRCNN):

    """
    Implements Faster R-CNN with separate train/test transforms and an
    attack-aware RoI head.

    Modified version of FasterRCNN from torchvision.models.detection.faster_rcnn.py.
    Changes relative to the torchvision implementation are indicated with comments:

    * ``train_transform`` and ``test_transform`` are stored on the module and
      ``forward`` picks one of them based on ``self.training``.
    * The RoI heads are built as ``RoIHeadsAttack`` and receive ``lambda_attack``.
    """

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        train_transform,
        test_transform,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        lambda_attack: int = 1, # <--- Added lambda_attack (forwarded to RoIHeadsAttack)
        **kwargs,
    ):
        """
        Assemble the RPN and RoI heads around ``backbone``.

        Args:
            backbone: feature extractor; must expose an ``out_channels`` attribute.
            num_classes: number of output classes for the default box predictor.
                Must be None when ``box_predictor`` is given, and must be set
                when it is not.
            train_transform: transform applied to (images, targets) by ``forward``
                in training mode; it must also provide ``postprocess``.
            test_transform: same role as ``train_transform``, used in evaluation mode.
            rpn_anchor_generator: AnchorGenerator or None; None selects
                ``_default_anchorgen()``.
            rpn_head: RPN head module or None; None builds an ``RPNHead`` from
                ``backbone.out_channels`` and the first per-location anchor count.
            rpn_*: remaining RPN settings, forwarded to ``RegionProposalNetwork``.
            box_roi_pool: MultiScaleRoIAlign or None; None pools feature maps
                "0".."3" to 7x7 with ``sampling_ratio=2``.
            box_head: box feature head or None; None builds a ``TwoMLPHead``
                with representation size 1024.
            box_predictor: box predictor or None; None builds a
                ``FastRCNNPredictor(1024, num_classes)``.
            box_*, bbox_reg_weights: remaining box settings, forwarded to
                ``RoIHeadsAttack``.
            lambda_attack: forwarded unchanged to ``RoIHeadsAttack``.
            **kwargs: accepted but not read anywhere in this constructor.

        Raises:
            ValueError: if ``backbone`` has no ``out_channels`` attribute, or if
                ``num_classes`` and ``box_predictor`` are both set or both None.
            TypeError: if ``rpn_anchor_generator`` or ``box_roi_pool`` has an
                unexpected type.
        """

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
            )
        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
            )

        # Exactly one of num_classes / box_predictor must be provided.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor is specified")
        else:
            if box_predictor is None:
                raise ValueError("num_classes should not be None when box_predictor is not specified")

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            # Only the first per-location anchor count is used to size the head.
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(out_channels * resolution**2, representation_size)

        if box_predictor is None:
            representation_size = 1024
            box_predictor = FastRCNNPredictor(representation_size, num_classes)

        # ---- Modified ----
        # RoIHeadsAttack is used for the RoI heads and additionally
        # receives lambda_attack.
        roi_heads = RoIHeadsAttack(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            lambda_attack=lambda_attack,  # <--- Added lambda_attack
        )

        # The parent constructor gets None as its fourth argument (the single
        # `transform` in torchvision's GeneralizedRCNN -- confirm against the
        # installed version); the transforms are kept on this class instead.
        super().__init__(backbone, rpn, roi_heads, None)

        # Both transforms are required constructor arguments and are stored as
        # given. A caller may still pass None or reassign these attributes
        # afterwards; forward() asserts that the transform matching the current
        # mode is not None.
        # NOTE(review): **kwargs is never read in this constructor, so unknown
        # keyword arguments are silently ignored -- confirm this is intended.
        self.train_transform = train_transform
        self.test_transform = test_transform

    # This overrides the forward method of GeneralizedRCNN
    def forward(
        self,
        images: list[torch.Tensor],
        targets: Optional[list[dict[str, torch.Tensor]]] = None,
    ) -> tuple[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]:
        """
        Run the detector on a batch of images.

        Args:
            images (list[Tensor]): images to be processed
            targets (list[dict[str, tensor]]): ground-truth boxes present in the image (optional).
                Required in training mode, where every target must hold a
                "boxes" tensor of shape [N, 4].

        Returns:
            When scripting, always the tuple ``(losses, detections)``, where
            ``losses`` merges the RoI-head and RPN loss dicts. Otherwise the
            result of the inherited ``eager_outputs(losses, detections)``
            (in torchvision this is the loss dict in training mode and the
            detections otherwise).

        Raises:
            AssertionError: via ``torch._assert`` when the transform for the
                current mode is None, when targets are missing or malformed in
                training mode, or when a target box has non-positive width or
                height after the transform.
        """
        if self.training:
            # -- Modified ----
            # In training mode, we expect train_transform to be set
            if self.train_transform is None:
                torch._assert(
                    False,
                    "train_transform should not be None when in training mode. "
                    "Please provide a valid train_transform.",
                )
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(
                            False,
                            f"Expected target boxes to be of type Tensor, got {type(boxes)}.",
                        )
        else:
            # -- Modified ----
            # In evaluation mode, we expect test_transform to be set
            if self.test_transform is None:
                torch._assert(
                    False,
                    "test_transform should not be None when in evaluation mode. "
                    "Please provide a valid test_transform.",
                )

        # Record each image's (H, W) before the transform so that postprocess
        # receives both the transformed and the original sizes.
        original_image_sizes: list[tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # ---- Modified ----
        # Allows training and testing to use different transforms
        if self.training:
            images, targets = self.train_transform(images, targets)
        else:
            images, targets = self.test_transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                # A box is degenerate when columns 2: are not strictly greater
                # than columns :2 (presumably x2 <= x1 or y2 <= y1).
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: list[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # A backbone returning a single tensor is wrapped as feature map "0".
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

        # ---- Modified ----
        # Postprocess with the same transform that preprocessed this batch.
        if self.training:
            detections = self.train_transform.postprocess(
                detections, images.image_sizes, original_image_sizes
            )
        else:
            detections = self.test_transform.postprocess(
                detections, images.image_sizes, original_image_sizes
            )

        # Merge the loss dicts; RPN entries are applied last.
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)


class TwoMLPHead(nn.Module):
    """
    Two fully connected layers, each followed by a ReLU, applied to the
    flattened input. Standard head for FPN-based models.

    Args:
        in_channels (int): number of input features after flattening
        representation_size (int): width of both fully connected layers
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        # Attribute names are part of the state-dict layout; keep them stable.
        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        # Collapse everything except the leading (batch) dimension.
        flat = torch.flatten(x, start_dim=1)
        hidden = self.fc6(flat).relu()
        return self.fc7(hidden).relu()


class FastRCNNConvFCHead(nn.Sequential):
    """Sequential head: conv blocks, a flatten, then Linear+ReLU pairs."""

    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        channels, height, width = input_size

        # Consecutive entries give the (in, out) channels of each conv block.
        conv_dims = [channels, *conv_layers]
        stages = [
            misc_nn_ops.Conv2dNormActivation(c_in, c_out, norm_layer=norm_layer)
            for c_in, c_out in zip(conv_dims, conv_dims[1:])
        ]
        stages.append(nn.Flatten())

        # After flattening, the feature count is last-conv-channels * H * W.
        fc_dims = [conv_dims[-1] * height * width, *fc_layers]
        for f_in, f_out in zip(fc_dims, fc_dims[1:]):
            stages.extend((nn.Linear(f_in, f_out), nn.ReLU(inplace=True)))

        super().__init__(*stages)

        # Re-initialize every convolution: Kaiming-normal weights, zero bias.
        conv_modules = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_modules:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)


class FastRCNNPredictor(nn.Module):
    """
    Classification and bounding-box regression output layers for Fast R-CNN.

    Args:
        in_channels (int): number of input features
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # One score per class and four box-regression values per class.
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)

    def forward(self, x):
        # A 4-D input is only accepted when its trailing two dims are 1x1.
        if x.dim() == 4:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )
        flat = torch.flatten(x, start_dim=1)
        return self.cls_score(flat), self.bbox_pred(flat)

def _build_detection_transform(min_size, max_size):
    """Build a GeneralizedRCNNTransform that uses the "default" normalization statistics."""
    norm = dataset_norms["default"]
    return GeneralizedRCNNTransform(
        min_size=min_size,
        max_size=max_size,
        image_mean=norm["mean"],
        image_std=norm["std"],
    )


def fasterrcnn_resnet50_fpn_attack(
    *,
    dataset: str = "coco",
    weights_path: Optional[str] = None,
    change_head: bool = False,
    new_head_weights: bool = False,
    progress: bool = True,
    num_classes: Optional[int] = None,
    lambda_attack: int = 1,  # <--- Added lambda_attack
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:

    """
    Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.

    Args:
        dataset: selects the anchor generator and the resize limits of the
            train/test transforms. Supported values: "coco" (default),
            "mtsd", "ptsd", "mtsd_meta" and "gtsdb".
        weights_path: optional path to a checkpoint holding a state dict for
            the whole detector. When given, it is loaded with ``torch.load``
            and ``weights_backbone`` is ignored.
        change_head, new_head_weights: when BOTH are true, the four
            ``roi_heads.box_predictor`` tensors are dropped from the checkpoint
            and the rest is loaded non-strictly, leaving the box predictor at
            its fresh initialization. Otherwise the checkpoint is loaded
            strictly. Only relevant when ``weights_path`` is given.
        progress: forwarded to ``resnet50``.
        num_classes: number of output classes for your dataset; required.
        lambda_attack: forwarded to ``FasterRCNN``.
        weights_backbone: ResNet-50 weights for the backbone; only used when
            ``weights_path`` is None.
        trainable_backbone_layers: passed through ``_validate_trainable_layers``
            (max 5, default 3) and then to ``_resnet_fpn_extractor``.
        **kwargs: forwarded to the ``FasterRCNN`` constructor.

    Returns:
        The constructed ``FasterRCNN`` with checkpoint weights loaded when
        ``weights_path`` was given.

    Raises:
        ValueError: if ``dataset`` is not supported or ``num_classes`` is None.
    """

    # (anchor generator factory, min_size, max_size) per supported dataset.
    # Train and test transforms use the same limits for every preset.
    sign_2048 = (_sign_anchorgen, (2048,), 2048)
    dataset_presets = {
        "mtsd": sign_2048,
        "ptsd": sign_2048,
        "mtsd_meta": sign_2048,
        "coco": (_default_anchorgen, 800, 1333),
        "gtsdb": (_sign_anchorgen, (1360,), 1360),
    }

    # Validate cheap arguments first so a bad call fails before any checkpoint
    # is read or backbone weights are fetched.
    if dataset not in dataset_presets:
        supported = ", ".join(repr(name) for name in dataset_presets)
        raise ValueError(f"Unsupported dataset: {dataset}. Supported datasets are {supported}.")

    if num_classes is None:
        raise ValueError(
            "num_classes must be specified. "
            "It should match the number of classes in your dataset."
        )

    # A full-detector checkpoint supersedes the pretrained backbone weights.
    if weights_path is not None:
        # map_location="cpu" lets checkpoints saved on an accelerator load on
        # hosts without one; load_state_dict copies the values into the
        # model's own parameters afterwards.
        weights = torch.load(weights_path, map_location="cpu", weights_only=True)
        weights_backbone = None
    else:
        weights = None
        weights_backbone = ResNet50_Weights.verify(weights_backbone)

    # Build backbone + FPN
    is_trained = (weights is not None) or (weights_backbone is not None)
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    # Batch-norm layers are frozen whenever pretrained weights are involved.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)

    make_anchor_generator, min_size, max_size = dataset_presets[dataset]
    rpn_anchor_generator = make_anchor_generator()

    # FasterRCNN takes the train and test transforms as separate required
    # arguments, so two independent instances are built.
    train_transform = _build_detection_transform(min_size, max_size)
    test_transform = _build_detection_transform(min_size, max_size)

    model = FasterRCNN(
        backbone,
        num_classes,
        train_transform,
        test_transform,
        lambda_attack=lambda_attack,
        rpn_anchor_generator=rpn_anchor_generator,
        **kwargs,
    )

    if weights is not None:
        print(f'Change head: {change_head}, new head weights: {new_head_weights}')
        if change_head and new_head_weights:
            # Drop the checkpoint's box-predictor tensors so the freshly
            # initialized head (sized for num_classes) is kept.
            print("Loading weights with changed head.")
            head_keys = (
                "roi_heads.box_predictor.cls_score.weight",
                "roi_heads.box_predictor.cls_score.bias",
                "roi_heads.box_predictor.bbox_pred.weight",
                "roi_heads.box_predictor.bbox_pred.bias",
            )
            filtered = {k: v for k, v in weights.items() if k not in head_keys}
            model.load_state_dict(filtered, strict=False)
        else:
            print("Loading weights without changing the head.")
            model.load_state_dict(weights, strict=True)

    return model
