import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.builder import HEADS
from mmdet.core import bbox2roi

from models.fvit_head import FViTBBoxHead, FViTRoIHead
from models.custom_attention import CustomSelfCrossAttentionLayer


@HEADS.register_module()
class MCTBBoxHead(FViTBBoxHead):
    """Box head whose class logits are similarities against ``self.all_embed``.

    The classification branch output is L2-normalised, optionally summed
    with the L2-normalised VLM box feature (``embed_level``), re-normalised,
    multiplied with ``self.all_embed`` and scaled by
    ``self.detect_temperature``. In eval mode, when VLM box features are
    supplied, the detector and VLM probabilities are combined per class with
    a weighted geometric mean (``alpha`` on ``self.base_idx`` columns,
    ``beta`` on ``self.novel_idx`` columns).

    Args:
        fixed_temperature, learned_temperature, class_embed, seen_classes,
        all_classes, vlm_temperature, alpha, beta, learn_bg: forwarded
            unchanged to ``FViTBBoxHead.__init__``.
        embed_level (bool): add the normalised VLM box feature to the
            normalised classification embedding before scoring.
        unified_heads (bool): feed the regression layer with
            ``relu(classification embedding)`` instead of running the
            separate regression branch.
        **kwargs: forwarded to ``FViTBBoxHead.__init__``.
    """

    def __init__(self,
                 fixed_temperature=0,
                 learned_temperature=50.0,
                 class_embed=None,
                 seen_classes=None,
                 all_classes=None,
                 vlm_temperature=75.0,
                 alpha=0.0,
                 beta=0.5,
                 learn_bg=False,
                 embed_level=True,
                 unified_heads=True,
                 **kwargs):
        super().__init__(
            fixed_temperature=fixed_temperature,
            learned_temperature=learned_temperature,
            class_embed=class_embed,
            seen_classes=seen_classes,
            all_classes=all_classes,
            vlm_temperature=vlm_temperature,
            alpha=alpha,
            beta=beta,
            learn_bg=learn_bg,
            **kwargs)

        self.unified_heads = unified_heads
        self.embed_level = embed_level

    def _pool_flatten(self, feats):
        """Collapse a >2-D feature to 2-D (optional avg-pool, then flatten)."""
        if feats.dim() > 2:
            if self.with_avg_pool:
                feats = self.avg_pool(feats)
            feats = feats.flatten(1)
        return feats

    def forward(self, x, vlm_box_feats=None):
        """Compute class scores and box regression outputs for RoI features.

        Args:
            x (Tensor): RoI features fed to the shared/cls/reg layers.
            vlm_box_feats (Tensor | None): per-RoI VLM feature; its last
                dimension must be compatible with the classification
                embedding and ``self.all_embed``.

        Returns:
            tuple: ``(cls_score, bbox_pred)``. ``cls_score`` holds
            temperature-scaled logits, except in eval mode with
            ``vlm_box_feats`` given, where it holds fused probabilities.
        """
        class_embed = self.all_embed.type_as(x)

        has_vlm = vlm_box_feats is not None
        vlm_unit = F.normalize(vlm_box_feats, dim=-1, p=2) if has_vlm else None

        # ---- shared trunk ----
        if self.num_shared_convs > 0:
            for layer in self.shared_convs:
                x = layer(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for layer in self.shared_fcs:
                x = self.relu(layer(x))

        # ---- classification branch ----
        cls_feat = x
        for layer in self.cls_convs:
            cls_feat = layer(cls_feat)
        cls_feat = self._pool_flatten(cls_feat)
        for layer in self.cls_fcs[:-1]:
            cls_feat = self.relu(layer(cls_feat))
        # The final fc is applied without an activation.
        cls_feat = self.cls_fcs[-1](cls_feat)

        # ---- regression branch ----
        if self.unified_heads:
            reg_feat = F.relu(cls_feat)
        else:
            reg_feat = x
            for layer in self.reg_convs:
                reg_feat = layer(reg_feat)
            reg_feat = self._pool_flatten(reg_feat)
            for layer in self.reg_fcs:
                reg_feat = self.relu(layer(reg_feat))
        bbox_pred = self.fc_reg(reg_feat)

        # ---- embedding and logits ----
        embed = F.normalize(cls_feat, dim=-1)
        if self.embed_level and has_vlm:
            embed = embed + vlm_unit
        embed = F.normalize(embed, dim=-1)

        cls_score = (embed @ class_embed) * self.detect_temperature

        if self.training or not has_vlm:
            return cls_score, bbox_pred

        # ---- eval-time fusion with the VLM probabilities ----
        cls_score = cls_score.softmax(dim=-1)
        vlm_score = ((vlm_unit @ class_embed) * self.vlm_temperature).softmax(dim=-1)

        # Weighted geometric mean: base columns use alpha, novel columns beta.
        for cols, weight in ((self.base_idx, self.alpha),
                             (self.novel_idx, self.beta)):
            cls_score[:, cols] = (cls_score[:, cols] ** (1 - weight)
                                  * vlm_score[:, cols] ** weight)

        return cls_score, bbox_pred


@HEADS.register_module()
class MCTRoIHead(FViTRoIHead):
    """RoI head that can fuse detector RoI features with VLM RoI features.

    When ``feature_level == "sa_ca_ffn"`` a ``CustomSelfCrossAttentionLayer``
    is built and, whenever a VLM feature map is passed to ``_bbox_forward``,
    the detector RoI features (as queries) are passed through it together
    with the VLM RoI features before being handed to the box head.

    Args:
        vlm_roi_extractor (dict | None): forwarded to ``FViTRoIHead``.
        feature_level (str): ``"sa_ca_ffn"`` enables the attention layer;
            any other value disables it.
        num_heads (int): ``num_heads`` passed to the attention layer.
        **kwargs: forwarded to ``FViTRoIHead.__init__``.
    """

    def __init__(self, vlm_roi_extractor=None,
                 feature_level="sa_ca_ffn", num_heads=4,
                 **kwargs):
        super(MCTRoIHead, self).__init__(
            vlm_roi_extractor=vlm_roi_extractor,
            **kwargs)

        # NOTE: the attribute name is misspelled ("feauture"); it is kept
        # as-is so any external code reading it keeps working.
        self.feauture_level = feature_level
        if feature_level == "sa_ca_ffn":
            self.cross_attention_layer = CustomSelfCrossAttentionLayer(
                q_dim=self.bbox_roi_extractor.out_channels,
                kv_dim=self.vlm_roi_extractor.out_channels,
                out_dim=self.bbox_roi_extractor.out_channels,
                num_heads=num_heads)

    def _bbox_forward(self, x, rois, vlm_feat=None):
        """Box head forward function used in both training and testing.

        Args:
            x (list[Tensor]): multi-level image features.
            rois (Tensor): RoIs passed to both RoI extractors.
            vlm_feat (Tensor | None): single VLM feature map; when ``None``
                the box head is called without VLM box features.

        Returns:
            dict: ``cls_score``, ``bbox_pred`` and ``bbox_feats`` (the RoI
            features actually fed to the box head).
        """
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)

        vlm_roi_feats = None
        if vlm_feat is not None:
            vlm_roi_feats = self.vlm_roi_extractor([vlm_feat], rois)
            if self.feauture_level == "sa_ca_ffn":
                num_bboxes, bbox_channels, size1, size2 = bbox_feats.shape
                # (N, C, H, W) -> (N, H*W, C) token sequences.
                # ``flatten(2)`` is used instead of ``view(N, C, -1)``: it
                # also works for an empty RoI batch (where ``-1`` cannot be
                # inferred) and for non-contiguous inputs.
                bbox_tokens = bbox_feats.flatten(2).permute(0, 2, 1)
                vlm_tokens = vlm_roi_feats.flatten(2).permute(0, 2, 1)
                bbox_feats = self.cross_attention_layer(bbox_tokens, vlm_tokens)
                # Back to (N, C, H, W); ``reshape`` copies only if the layer
                # output layout does not allow a view.
                bbox_feats = bbox_feats.permute(0, 2, 1).reshape(
                    num_bboxes, bbox_channels, size1, size2)
            # Average over spatial positions -> one VLM vector per RoI.
            vlm_roi_feats = vlm_roi_feats.flatten(2).mean(dim=-1)
        cls_score, bbox_pred = self.bbox_head(bbox_feats, vlm_roi_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      **kwargs):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.
            **kwargs: may contain ``res_feats``; when present and not None,
                its last element is used as the VLM feature map for the
                box head.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None] * num_imgs
            sampling_results = []
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            res_feats = kwargs.get("res_feats", None)
            vlm_feat = res_feats[-1] if res_feats is not None else None
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas, vlm_feat=vlm_feat)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_results['bbox_feats'],
                                                    gt_masks, img_metas)
            losses.update(mask_results['loss_mask'])

        return losses

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas, vlm_feat=None):
        """Run forward function and calculate loss for box head in training.

        Returns the dict from ``_bbox_forward`` extended with ``loss_bbox``.
        """
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois, vlm_feat=vlm_feat)

        bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
                                                  gt_labels, self.train_cfg)
        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
                                        bbox_results['bbox_pred'], rois,
                                        *bbox_targets)

        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results


@HEADS.register_module()
class MCTTransferBBoxHead(MCTBBoxHead):
    """Inference-only variant of ``MCTBBoxHead``.

    Differences from the parent ``forward`` visible in this file:

    * logits are scaled by ``self.fixed_temperature`` rather than
      ``self.detect_temperature``;
    * detector and VLM probabilities are fused over *all* classes with the
      single exponent ``self.alpha`` (no base/novel split);
    * when ``unified_heads`` is False, a ReLU is applied to the
      classification embedding before normalisation.

    ``fixed_temperature`` must be given explicitly, either as the first
    positional argument or as a keyword.
    """

    def __init__(self, *args, **kwargs):
        super(MCTTransferBBoxHead, self).__init__(*args, **kwargs)
        # ``fixed_temperature`` is the first positional parameter of
        # ``MCTBBoxHead.__init__``, so accept it from either place instead of
        # only from ``kwargs``.
        if args:
            self.fixed_temperature = args[0]
        elif 'fixed_temperature' in kwargs:
            self.fixed_temperature = kwargs['fixed_temperature']
        else:
            raise KeyError(
                "MCTTransferBBoxHead requires an explicit 'fixed_temperature'")

    def forward(self, x, vlm_box_feats=None):
        """Return fused class probabilities and box regression outputs.

        Args:
            x (Tensor): RoI features fed to the shared/cls/reg layers.
            vlm_box_feats (Tensor): per-RoI VLM feature. Required; the
                default of ``None`` only exists to keep the parent signature.

        Returns:
            tuple: ``(cls_score, bbox_pred)`` where ``cls_score`` is
            ``softmax(det) ** (1 - alpha) * softmax(vlm) ** alpha``.

        Raises:
            AssertionError: if the module is in training mode.
            TypeError: if ``vlm_box_feats`` is ``None``.
        """
        # Explicit raise (not ``assert``) so the guard survives ``python -O``.
        if self.training:
            raise AssertionError(
                'MCTTransferBBoxHead only supports inference (eval mode)')
        # Fail early with a clear message rather than on ``None @ Tensor``.
        if vlm_box_feats is None:
            raise TypeError(
                'MCTTransferBBoxHead.forward requires vlm_box_feats, got None')

        all_embed = self.all_embed.type_as(x)
        normalized_vlm_box_feats = F.normalize(vlm_box_feats, dim=-1, p=2)

        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches
        x_cls = x
        x_reg = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs[:-1]:
            x_cls = self.relu(fc(x_cls))
        x_cls = self.cls_fcs[-1](x_cls)  # discard the last relu

        if self.unified_heads:
            x_reg = F.relu(x_cls)
        else:
            x_cls = F.relu(x_cls)  # keep same as f-vit
            for conv in self.reg_convs:
                x_reg = conv(x_reg)
            if x_reg.dim() > 2:
                if self.with_avg_pool:
                    x_reg = self.avg_pool(x_reg)
                x_reg = x_reg.flatten(1)
            for fc in self.reg_fcs:
                x_reg = self.relu(fc(x_reg))

        bbox_pred = self.fc_reg(x_reg)
        normalized_x_cls = F.normalize(x_cls, dim=-1)

        if self.embed_level:
            normalized_x_cls = normalized_x_cls + normalized_vlm_box_feats
        normalized_x_cls = F.normalize(normalized_x_cls, dim=-1)

        cls_score = normalized_x_cls @ all_embed * self.fixed_temperature
        cls_score = cls_score.softmax(dim=-1)
        vlm_score = normalized_vlm_box_feats @ all_embed * self.vlm_temperature
        vlm_score = vlm_score.softmax(dim=-1)

        # Weighted geometric mean of detector and VLM probabilities.
        cls_score = cls_score ** (1 - self.alpha) * vlm_score ** self.alpha

        return cls_score, bbox_pred
