# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import os
import json
import copy
from typing import Dict, List, Optional, Tuple, Union
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmengine.runner import load_checkpoint, load_state_dict
from mmengine.model import is_model_wrapper
from mmengine import Config
from mmengine.structures import InstanceData
from mmdet.utils import ConfigType, OptConfigType, InstanceList
from mmdet.structures.bbox import bbox2roi
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps
from mmdet.models.dense_heads.atss_vlfusion_head import convert_grounding_to_cls_scores
from mmdet.structures import OptSampleList, SampleList
from ..layers import SinePositionalEncoding, CdnQueryGenerator
from ..layers import inverse_sigmoid
from .incre_ddetr_incre import incre_incre_DeformableDETR

@MODELS.register_module()
class incre_gcd_DeformableDETR(incre_incre_DeformableDETR):
    """Incremental Deformable DETR trained against a frozen original model.

    Extends ``incre_incre_DeformableDETR``. During training, ``loss`` first
    runs the original model (``self.ori_model``) under ``torch.no_grad()`` to
    obtain its predictions, decoder queries and (optionally) pseudo labels.
    This model then runs its decoder on the shared encoder memory:

    * with its own queries (``pre_decoder``);
    * with queries derived from the original model (``pre_decoder_old``);
    * optionally (``self.deconf_l``) with "de-confused" queries produced by
      ``self.deconf``.

    NOTE(review): ``incre_cfg``, ``ori_model``, ``deconf_l``, ``deconf``,
    ``loss_rank_kd`` and ``contrastive_pull_together`` are not defined in
    this class. They are presumably provided by the parent class / config,
    so confirm.
    """

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
        aux_dict: Dict = None,
        deconf_query: Tensor = None,
    ) -> Union[Dict, Tuple]:
        """Run the encoder and one or more decoder passes.

        Args:
            img_feats: Multi-level image features from ``extract_feat``.
            batch_data_samples: Forwarded to ``pre_transformer`` and
                ``pre_decoder_old``.
            aux_dict: Original-model query information consumed by
                ``pre_decoder_old`` (training only).
            deconf_query: Optional de-confused queries. When given (training
                only), a third decoder pass is run with them.

        Returns:
            In training mode, a 4-tuple ``(new_head_inputs_dict,
            old_head_inputs_dict, new_query, old_head_inputs_dict_deconf)``:

            * ``new_query`` is a clone of the queries produced by
              ``pre_decoder``.
            * ``old_head_inputs_dict_deconf`` is ``None`` when
              ``deconf_query`` is ``None``.

            In eval mode, a single ``head_inputs_dict``.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)
        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        if not self.training:
            tmp_dec_in, head_inputs_dict = self.pre_decoder(
                **encoder_outputs_dict)
            decoder_inputs_dict.update(tmp_dec_in)
            head_inputs_dict.update(
                self.forward_decoder(**decoder_inputs_dict))
            return head_inputs_dict

        # Each decoder pass works on its own shallow copy so the ``update``
        # calls do not leak queries from one pass into another.
        # Pass 1: this model's own queries.
        new_decoder_inputs_dict = decoder_inputs_dict.copy()
        new_tmp_dec_in, new_head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict)
        new_decoder_inputs_dict.update(new_tmp_dec_in)
        new_query = new_tmp_dec_in['query'].clone()
        new_head_inputs_dict.update(
            self.forward_decoder(**new_decoder_inputs_dict))

        # Pass 2: queries derived from the original model.
        old_decoder_inputs_dict = decoder_inputs_dict.copy()
        old_tmp_dec_in, old_head_inputs_dict = self.pre_decoder_old(
            **encoder_outputs_dict, aux_dict=aux_dict,
            batch_data_samples=batch_data_samples)
        old_decoder_inputs_dict.update(old_tmp_dec_in)
        old_head_inputs_dict.update(
            self.forward_decoder(**old_decoder_inputs_dict))

        # Pass 3 (optional): de-confused queries.
        old_head_inputs_dict_deconf = None
        if deconf_query is not None:
            aux_dict_deconf = aux_dict.copy()
            # NOTE(review): the de-confused query is stored under the key
            # 'query', whereas ``loss`` builds ``aux_dict`` with the key
            # 'aux_query'. Whether ``generate_incre_points`` reads 'query'
            # cannot be verified from this file. If it only reads
            # 'aux_query', this override is silently ignored. Confirm.
            aux_dict_deconf['query'] = deconf_query.detach()
            old_tmp_dec_in_deconf, old_head_inputs_dict_deconf = \
                self.pre_decoder_old(
                    **encoder_outputs_dict, aux_dict=aux_dict_deconf,
                    batch_data_samples=batch_data_samples)
            old_decoder_inputs_dict_deconf = decoder_inputs_dict.copy()
            old_decoder_inputs_dict_deconf.update(old_tmp_dec_in_deconf)
            old_head_inputs_dict_deconf.update(
                self.forward_decoder(**old_decoder_inputs_dict_deconf))

        return (new_head_inputs_dict, old_head_inputs_dict, new_query,
                old_head_inputs_dict_deconf)

    def _two_stage_encoder_outputs(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Score and regress encoder proposals (two-stage mode).

        Uses the head branches at index ``self.decoder.num_layers``
        (presumably the encoder-proposal branches, as in mmdet's
        DeformableDETR).

        Returns:
            tuple: ``(enc_outputs_class, enc_outputs_coord_unact,
            enc_outputs_coord)``, where ``enc_outputs_coord`` is the sigmoid
            of ``enc_outputs_coord_unact``.
        """
        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals
        return (enc_outputs_class, enc_outputs_coord_unact,
                enc_outputs_coord_unact.sigmoid())

    def pre_decoder_old(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        aux_dict: Dict = None,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        """Build decoder inputs from original-model queries in ``aux_dict``.

        The decoder ``query``, ``query_pos`` and ``reference_points`` come
        from ``generate_incre_points(self.incre_cfg, aux_dict)``. They do not
        come from the encoder proposals or the learned query embedding. In
        two-stage mode the encoder proposal outputs are still computed, so
        the head receives ``enc_outputs_class`` / ``enc_outputs_coord``
        during training.

        Args:
            memory: Encoder output memory; unpacked as ``(bs, num_feat, c)``
                elsewhere in this class.
            memory_mask: Encoder padding mask.
            spatial_shapes: Spatial shapes of the feature levels.
            aux_dict: Original-model queries / references / positional
                embeddings, as built in ``loss``.
            batch_data_samples: Unused; kept for interface compatibility.

        Returns:
            tuple[dict, dict]: ``decoder_inputs_dict`` (``query``,
            ``query_pos``, ``memory``, ``reference_points``) and
            ``head_inputs_dict`` (encoder outputs in training, empty
            otherwise).
        """
        # NOTE(review): ``generate_incre_points`` is neither defined nor
        # imported in this file. Unless the name is provided some other way,
        # this raises NameError. Add the missing import.
        aux_query, aux_reference, aux_query_pos = generate_incre_points(
            self.incre_cfg, aux_dict)
        if self.as_two_stage:
            enc_outputs_class, _, enc_outputs_coord = \
                self._two_stage_encoder_outputs(
                    memory, memory_mask, spatial_shapes)
        else:
            enc_outputs_class, enc_outputs_coord = None, None

        decoder_inputs_dict = dict(
            query=aux_query,
            query_pos=aux_query_pos,
            memory=memory,
            reference_points=aux_reference)
        head_inputs_dict = dict(
            enc_outputs_class=enc_outputs_class,
            enc_outputs_coord=enc_outputs_coord) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def pre_decoder_new(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        """Build decoder inputs the standard Deformable DETR way.

        * In two-stage mode, the ``self.num_queries`` top-scoring encoder
          proposals provide the reference points and (through
          ``pos_trans_fc`` / ``pos_trans_norm``) the query and positional
          embedding.
        * Otherwise, the learned ``query_embedding`` is split into
          ``query_pos`` and ``query``, and reference points come from
          ``reference_points_fc``.

        Args:
            memory: Encoder output memory, unpacked as ``(bs, num_feat, c)``.
            memory_mask: Encoder padding mask.
            spatial_shapes: Spatial shapes of the feature levels.
            batch_data_samples: Unused; kept for interface compatibility.

        Returns:
            tuple[dict, dict]: ``decoder_inputs_dict`` and
            ``head_inputs_dict`` (encoder outputs in training, empty
            otherwise).
        """
        batch_size, _, c = memory.shape
        if self.as_two_stage:
            enc_outputs_class, enc_outputs_coord_unact, enc_outputs_coord = \
                self._two_stage_encoder_outputs(
                    memory, memory_mask, spatial_shapes)
            # We only use the first channel in enc_outputs_class as foreground,
            # the other (num_classes - 1) channels are actually not used.
            # Its targets are set to be 0s, which indicates the first
            # class (foreground) because we use [0, num_classes - 1] to
            # indicate class labels, background class is indicated by
            # num_classes (similar convention in RPN).
            # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
            # This follows the official implementation of Deformable DETR.
            topk_proposals = torch.topk(
                enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
            topk_coords_unact = torch.gather(
                enc_outputs_coord_unact, 1,
                topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            pos_trans_out = self.pos_trans_fc(
                self.get_proposal_pos_embed(topk_coords_unact))
            pos_trans_out = self.pos_trans_norm(pos_trans_out)
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            enc_outputs_class, enc_outputs_coord = None, None
            query_embed = self.query_embedding.weight
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1)
            query = query.unsqueeze(0).expand(batch_size, -1, -1)
            reference_points = self.reference_points_fc(query_pos).sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            query_pos=query_pos,
            memory=memory,
            reference_points=reference_points)
        head_inputs_dict = dict(
            enc_outputs_class=enc_outputs_class,
            enc_outputs_coord=enc_outputs_coord) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_ori_model(
            self,
            img_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None,
        ):
        """Run the original (old) model's transformer on ``img_feats``.

        Clones of the original model's decoder ``query``,
        ``reference_points`` and ``query_pos`` are stored as ``aux_query``,
        ``aux_reference`` and ``aux_query_pos``. The decoder's
        ``hidden_states`` / ``references`` are renamed to
        ``ori_hidden_states`` / ``ori_references``.

        Returns:
            tuple[dict, dict]: ``(head_inputs_dict, decoder_inputs_dict)``.
            The second item is the original model's full decoder input dict;
            ``loss`` reuses it to re-run that decoder with other queries.
        """
        encoder_inputs_dict, decoder_inputs_dict = \
            self.ori_model.pre_transformer(img_feats, batch_data_samples)

        encoder_outputs_dict = self.ori_model.forward_encoder(
            **encoder_inputs_dict)

        tmp_dec_in, head_inputs_dict = self.ori_model.pre_decoder(
            **encoder_outputs_dict)
        decoder_inputs_dict.update(tmp_dec_in)

        # Keep copies of the old model's decoder inputs; ``loss`` uses them
        # to initialise auxiliary / de-confused queries.
        head_inputs_dict['aux_query'] = tmp_dec_in['query'].clone()
        head_inputs_dict['aux_reference'] = \
            tmp_dec_in['reference_points'].clone()
        head_inputs_dict['aux_query_pos'] = tmp_dec_in['query_pos'].clone()

        decoder_outputs_dict = self.ori_model.forward_decoder(
            **decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)

        head_inputs_dict['ori_hidden_states'] = \
            head_inputs_dict.pop('hidden_states')
        head_inputs_dict['ori_references'] = head_inputs_dict.pop('references')

        return head_inputs_dict, decoder_inputs_dict

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Compute the incremental-learning training losses.

        Steps:

        1. Under ``torch.no_grad()``, run the original model to obtain its
           predictions and decoder queries. For ``label_incre.type`` of
           ``'topk_pseudo'`` / ``'threshold_pseudo'``, also generate pseudo
           labels; otherwise the pseudo-label entries are ``None``.
        2. For ``query_incre.type == 'seperate_queryinit'``, build
           ``aux_dict`` from the original model's queries.
        3. If ``self.deconf_l``:

           a. split the old queries by confidence / IoU;
           b. build de-confused queries;
           c. re-run the original decoder / head with them;
           d. compute ``loss_confMlp``.

        4. Run this model's transformer and ``bbox_head.loss``. With
           ``deconf_l``, also add ``loss_confMlp`` and (when both feature
           collections are non-empty) ``loss_rank_sim``.

        Returns:
            dict: Loss dict from ``self.bbox_head.loss`` plus the extra
            losses above.

        Raises:
            ValueError: If ``deconf_l`` is enabled without
                ``query_incre.type == 'seperate_queryinit'``. The
                de-confusion path needs the original queries collected in
                that mode.
        """
        with torch.no_grad():
            # Original-model forward.
            ori_img_features = self.ori_model.extract_feat(batch_inputs)

            ori_head_inputs_dict, decoder_inputs_dict = \
                self.forward_ori_model(ori_img_features, batch_data_samples)
            all_layers_ori_cls_scores, all_layers_ori_bbox_preds = \
                self.ori_model.bbox_head(
                    ori_head_inputs_dict['ori_hidden_states'],
                    ori_head_inputs_dict['ori_references'])
            ori_head_inputs_dict['all_layers_ori_cls_scores'] = \
                all_layers_ori_cls_scores
            ori_head_inputs_dict['all_layers_ori_bbox_preds'] = \
                all_layers_ori_bbox_preds

            # Default to None so label-increment types without pseudo labels
            # do not reference unassigned names below.
            topk_query = batch_pseudo_instances = batch_all_instances = None
            if self.incre_cfg.label_incre.type in ('topk_pseudo',
                                                   'threshold_pseudo'):
                topk_query, batch_pseudo_instances, batch_all_instances = \
                    self.bbox_head.generate_pseudo_label(
                        all_layers_ori_cls_scores,
                        all_layers_ori_bbox_preds,
                        batch_data_samples)

            ori_head_inputs_dict['batch_pseudo_instances'] = \
                batch_pseudo_instances
            ori_head_inputs_dict['batch_all_instances'] = batch_all_instances
            ori_head_inputs_dict['ori_topk_query'] = topk_query

        img_features = self.extract_feat(batch_inputs)

        aux_dict = None
        ori_query = None
        if self.incre_cfg.query_incre.type == 'seperate_queryinit':
            num_incre_queries = self.incre_cfg.query_incre.num_aux_query
            assert num_incre_queries <= \
                self.incre_cfg.query_incre.num_matching_query
            ori_query = ori_head_inputs_dict['aux_query'].clone()
            aux_query = ori_head_inputs_dict['aux_query'].clone()
            aux_query = aux_query[:, :num_incre_queries]
            # NOTE(review): unlike aux_query / aux_reference, aux_query_pos is
            # not truncated to num_incre_queries. Confirm that this is what
            # generate_incre_points expects.
            aux_query_pos = ori_head_inputs_dict['aux_query_pos'].clone()
            # aux_reference is removed from ori_head_inputs_dict before that
            # dict is passed to bbox_head.loss.
            aux_reference = ori_head_inputs_dict.pop('aux_reference')
            aux_reference = aux_reference[:, :num_incre_queries]

            aux_dict = dict(aux_query=aux_query, aux_query_pos=aux_query_pos,
                            aux_reference=aux_reference,
                            batch_pseudo_instances=batch_pseudo_instances)

        deconf_query = None
        if self.deconf_l:
            if ori_query is None:
                raise ValueError(
                    "deconf_l requires incre_cfg.query_incre.type == "
                    "'seperate_queryinit'")
            (mask_low_conf,
             mask_high_conf_low_iou,
             mask_high_conf_high_iou,
             old_label_local,        # [B, Nq], old-class index 0..C_old-1
             label_weights_local,    # [B, Nq], old model's max class score
             _new_label) = self.split_queries_by_conf_and_iou(
                 all_layers_ori_cls_scores,
                 all_layers_ori_bbox_preds,
                 batch_data_samples)
            deconf_query = self.deconf(ori_query)
            # Where the old model is confident but the box's max IoU with the
            # GT boxes is low, keep the original query; elsewhere use the
            # de-confused query. freeze_g2=True (recommended) detaches the
            # original query so this path does not update the
            # corresponding (g2) tokens.
            freeze_g2 = True
            src = ori_query.detach() if freeze_g2 else ori_query
            deconf_query = torch.where(
                mask_high_conf_low_iou.unsqueeze(-1), src, deconf_query)

        new_head_inputs_dict, old_head_inputs_dict, _, \
            old_head_inputs_dict_deconf = self.forward_transformer(
                img_features, batch_data_samples, aux_dict, deconf_query)

        if self.deconf_l:
            # Re-run the original model's decoder / head with the de-confused
            # queries, reusing its decoder inputs from the no-grad pass.
            decoder_inputs_dict['query'] = deconf_query
            decoder_outputs = self.ori_model.forward_decoder(
                **decoder_inputs_dict)
            output_query = decoder_outputs['hidden_states'][-1]
            ori_cls_scores_deconf, ori_bbox_deconf = \
                self.ori_model.bbox_head(decoder_outputs['hidden_states'],
                                         decoder_outputs['references'])
            loss_conf = self.contrastive_pull_together(
                ori_query=ori_query,
                deconf_query=deconf_query,
                mask_low_conf=mask_low_conf,
                mask_high_conf_low_iou=mask_high_conf_low_iou,
                mask_high_conf_high_iou=mask_high_conf_high_iou,
                old_label=old_label_local,
                label_weight=label_weights_local,
                tau=1.
            )
            losses, query_feat_single, ori_query_feat_single = \
                self.bbox_head.loss(
                    new_head_inputs_dict,
                    old_head_inputs_dict,
                    ori_head_inputs_dict,
                    ori_cls_scores_deconf=ori_cls_scores_deconf,
                    ori_bbox_deconf=ori_bbox_deconf,
                    old_head_inputs_dict_deconf=old_head_inputs_dict_deconf,
                    output_query=output_query,
                    old_label_local=old_label_local,
                    label_weights_local=label_weights_local,
                    mask_low_conf=mask_low_conf,
                    mask_high_conf_low_iou=mask_high_conf_low_iou,
                    mask_high_conf_high_iou=mask_high_conf_high_iou,
                    query_deconf=deconf_query,
                    batch_data_samples=batch_data_samples)
        else:
            losses, query_feat_single, ori_query_feat_single = \
                self.bbox_head.loss(
                    new_head_inputs_dict,
                    old_head_inputs_dict,
                    ori_head_inputs_dict,
                    batch_data_samples=batch_data_samples)

        if self.deconf_l:
            losses['loss_confMlp'] = loss_conf
            # Inter-query rank-similarity distillation with a hard-coded
            # weight of 6. Only added when both feature collections are
            # non-empty.
            if len(query_feat_single) > 0 and len(ori_query_feat_single) > 0:
                losses['loss_rank_sim'] = 6 * self.loss_rank_kd(
                    query_feat_single, ori_query_feat_single)
        return losses

    def split_queries_by_conf_and_iou(
        self,
        all_layers_ori_cls_scores,
        all_layers_ori_bbox_preds,
        batch_data_samples,
        iou_thr_share=0.7,
        score_thr_old=0.4,
        use_dynamic_old_thr=True,
        dynamic_sigma_k=2.0,

        topk_limit=300,
    ):
        """Partition the original model's last-layer queries by confidence / IoU.

        Confidence is the max sigmoid class score over the old classes. IoU
        is the max IoU between each predicted box and the GT boxes of its
        image. Before the IoU is computed, each box is converted from
        (cx, cy, w, h), scaled by ``img_shape`` and clipped to the image.

        Args:
            all_layers_ori_cls_scores: Per-layer class logits; only the last
                layer is used, indexed as ``[B, Nq, C_old]``.
            all_layers_ori_bbox_preds: Per-layer box predictions; only the
                last layer is used, ``[B, Nq, 4]``. Presumably normalized
                (cx, cy, w, h), given the conversion and scaling below.
            batch_data_samples: Provide ``metainfo['img_shape']`` and
                ``gt_instances`` (``bboxes``, ``labels``).
            iou_thr_share: IoU threshold separating high- from low-IoU
                queries.
            score_thr_old: Fixed high-confidence threshold, used only when
                ``use_dynamic_old_thr`` is False.
            use_dynamic_old_thr: If True, the high-confidence threshold is
                ``mean + dynamic_sigma_k * std`` of each image's max scores.
            dynamic_sigma_k: Std multiplier for the dynamic threshold.
            topk_limit: Only the ``topk_limit`` highest-scoring queries per
                image can be marked high- or low-confidence.

        Returns:
            tuple: ``(mask_low_conf, mask_high_conf_low_iou,
            mask_high_conf_high_iou, old_label, old_score, new_label)``:

            * three ``[B, Nq]`` bool masks;
            * ``old_label``: the argmax old class, ``[B, Nq]``;
            * ``old_score``: its sigmoid score, ``[B, Nq]``;
            * ``new_label``: the label of the max-IoU GT box, ``[B, Nq]``
              (``-1`` for images without GT).
        """
        cls_scores = all_layers_ori_cls_scores[-1]     # [B, Nq, C_old]
        bbox_preds = all_layers_ori_bbox_preds[-1]     # [B, Nq, 4]
        B, Nq, C_old = cls_scores.shape
        device = cls_scores.device
        dtype = cls_scores.dtype

        metas = [s.metainfo for s in batch_data_samples]
        gts = [s.gt_instances for s in batch_data_samples]

        # Scale normalized boxes to absolute pixels and clip to the image.
        factors = []
        for meta, bp in zip(metas, bbox_preds):
            h, w = meta['img_shape'][:2]
            factors.append(
                bp.new_tensor([w, h, w, h]).unsqueeze(0).repeat(Nq, 1))
        factors = torch.stack(factors, 0)                         # [B, Nq, 4]
        boxes_xyxy = bbox_cxcywh_to_xyxy(bbox_preds) * factors    # [B, Nq, 4]
        for b, meta in enumerate(metas):
            h, w = meta['img_shape'][:2]
            boxes_xyxy[b, :, 0::2].clamp_(0, w)
            boxes_xyxy[b, :, 1::2].clamp_(0, h)

        # Max IoU with the image's GT boxes and the label of that GT box.
        iou_max_list, newlab_list = [], []
        for b, inst in enumerate(gts):
            if len(inst) == 0 or inst.bboxes.numel() == 0:
                iou_max_list.append(
                    torch.zeros(Nq, device=device, dtype=dtype))
                newlab_list.append(
                    torch.full((Nq,), -1, device=device, dtype=torch.long))
                continue
            ious = bbox_overlaps(boxes_xyxy[b], inst.bboxes)      # [Nq, M]
            iou_max, iou_argmax = ious.max(dim=1)                 # [Nq], [Nq]
            iou_max_list.append(iou_max)
            newlab_list.append(inst.labels[iou_argmax])           # [Nq]
        iou_max = torch.stack(iou_max_list, 0)                    # [B, Nq]
        new_label = torch.stack(newlab_list, 0)                   # [B, Nq]

        old_scores = cls_scores.reshape(-1, C_old).sigmoid().reshape(
            B, Nq, -1)                                            # [B, Nq, C_old]
        old_max_score, old_argmax_local = old_scores.max(dim=-1)  # [B, Nq]

        mu = old_max_score.mean(dim=1, keepdim=True)              # [B, 1]
        if use_dynamic_old_thr:
            sig = old_max_score.std(dim=1, keepdim=True, unbiased=False)
            conf_thr = mu + dynamic_sigma_k * sig                 # [B, 1]
            conf_thr = torch.nan_to_num(conf_thr, nan=float('inf'),
                                        posinf=float('inf'), neginf=0.0)
            mask_high_conf_base = old_max_score > conf_thr
        else:
            conf_thr = torch.full((B, 1), float(score_thr_old),
                                  device=device, dtype=dtype)
            mask_high_conf_base = old_max_score >= conf_thr
        # Low confidence: below the per-image mean score, in both modes.
        # NOTE(review): the fixed-threshold mode reuses the mean-based rule;
        # a fixed low threshold may be preferable there.
        mask_low_conf_base = old_max_score < mu

        # Restrict both confidence groups to the top-K scoring queries.
        mask_topk = torch.zeros((B, Nq), dtype=torch.bool, device=device)
        K = min(int(topk_limit), Nq)
        if K > 0:
            topk_idx = torch.topk(old_max_score, k=K, dim=1,
                                  largest=True, sorted=False).indices
            mask_topk.scatter_(1, topk_idx, True)

        mask_high_conf = mask_topk & mask_high_conf_base

        # NOTE(review): despite its name, this mask is not gated by
        # mask_high_conf; it marks every query with high IoU. Confirm this
        # is intended.
        mask_high_conf_high_iou = iou_max >= iou_thr_share
        mask_high_conf_low_iou = mask_high_conf & (iou_max < iou_thr_share)
        mask_low_conf = (mask_topk & mask_low_conf_base
                         & (iou_max < iou_thr_share))

        return (
            mask_low_conf,
            mask_high_conf_low_iou,
            mask_high_conf_high_iou,
            old_argmax_local,
            old_max_score,
            new_label,
        )
