from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import RoIAlign
from PIL import Image
from torchvision import transforms
from augment.transforms import paste_to_batch
from torchvision.ops import box_iou
from detectron2.config import configurable
from detectron2.structures import Boxes, ImageList, pairwise_iou
from typing import Dict, List, Optional, Tuple


@META_ARCH_REGISTRY.register()
class GuassianGeneralizedRCNN(GeneralizedRCNN):
    """GeneralizedRCNN subclass whose training-time forward pass is chosen by ``branch``.

    In eval mode ``forward`` delegates to ``self.inference``. In training mode
    it runs one of three paths ("supervised", "unsup_data_weak",
    "unsupervised") and always returns a 4-tuple.
    """

    def forward(self, batched_inputs, branch="supervised", danchor=False, norm=False, tau=None):
        """Run the detector on a batch.

        Args:
            batched_inputs: sequence of per-image dicts. Each dict must have an
                "image" entry; "instances" and "image_labels" are read when
                present on the first element.
            branch: which training path to run. One of "supervised",
                "unsup_data_weak" or "unsupervised". Ignored in eval mode.
            danchor: forwarded as-is to the proposal generator in the
                "supervised" and "unsupervised" branches.
            norm: when true and the inputs carry "instances", images are
                prepared with ``preprocess_image_norm`` instead of
                ``preprocess_image``.
            tau: forwarded as-is to the proposal generator and the ROI heads
                in the "unsupervised" branch.

        Returns:
            In eval mode, whatever ``self.inference(batched_inputs)`` returns.
            In training mode, a 4-tuple:

            * "supervised" / "unsupervised": ``(losses, [], [], None)`` where
              ``losses`` merges the ROI-head losses and the proposal losses.
            * "unsup_data_weak":
              ``({}, proposals_rpn, proposals_roih, ROI_predictions)``.

        Raises:
            ValueError: in training mode, if ``branch`` is not one of the
                three supported values.
        """
        if not self.training:
            return self.inference(batched_inputs)

        # The first element decides for the whole batch whether ground truth
        # is available.
        if "instances" in batched_inputs[0]:
            if norm:
                images = self.preprocess_image_norm(batched_inputs)
            else:
                images = self.preprocess_image(batched_inputs)
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            # NOTE(review): `norm` is not honoured on this path -- confirm
            # that is intentional for inputs without "instances".
            images = self.preprocess_image(batched_inputs)
            gt_instances = None

        features = self.backbone(images.tensor)

        if branch == "supervised":
            # Region proposal network
            proposals_rpn, proposal_losses = self.proposal_generator(
                images, features, gt_instances, danchor=danchor
            )

            # ROI heads; only the losses are used here.
            _, detector_losses = self.roi_heads(
                images, features, proposals_rpn, gt_instances, branch=branch
            )

            # Proposal losses are merged last, so they win on a key clash.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses, [], [], None

        elif branch == "unsup_data_weak":
            # Region proposal network, called without targets and with loss
            # computation switched off.
            proposals_rpn, _ = self.proposal_generator(
                images, features, None, compute_loss=False
            )

            if "image_labels" in batched_inputs[0]:
                image_labels = [x["image_labels"].to(self.device) for x in batched_inputs]
            else:
                image_labels = None

            proposals_roih, ROI_predictions = self.roi_heads(
                images,
                features,
                proposals_rpn,
                targets=None,
                compute_loss=False,
                branch=branch,
                image_labels=image_labels,
            )

            return {}, proposals_rpn, proposals_roih, ROI_predictions

        elif branch == "unsupervised":
            proposals_rpn, proposal_losses = self.proposal_generator(
                images, features, gt_instances, branch=branch, danchor=danchor, tau=tau
            )

            _, detector_losses = self.roi_heads(
                images, features, proposals_rpn, gt_instances, branch=branch, tau=tau
            )

            # Proposal losses are merged last, so they win on a key clash.
            losses = {}
            losses.update(detector_losses)
            losses.update(proposal_losses)
            return losses, [], [], None

        else:
            # Previously an unrecognised branch fell through and returned
            # None; fail loudly with the offending value instead.
            raise ValueError(
                f"Unknown branch {branch!r}; expected 'supervised', "
                "'unsup_data_weak' or 'unsupervised'."
            )

    def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Normalize, pad and batch the input images.

        Each input's "image" is moved to ``self.device`` and normalized with
        ``self.pixel_mean`` / ``self.pixel_std``. When the sum of
        ``self.pixel_mean`` is below 3, the image is divided by 255 before
        normalization (presumably the statistics are then expressed on a
        [0, 1] scale -- confirm against the config).

        Returns:
            The result of ``ImageList.from_tensors`` on the normalized images,
            using ``self.backbone.size_divisibility``.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        if torch.sum(self.pixel_mean) < 3:
            images = [((x / 255.0) - self.pixel_mean) / self.pixel_std for x in images]
        else:
            images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        return images