# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os, xml.etree.ElementTree as ET
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Linear
from mmengine.model import constant_init
from mmengine.structures import InstanceData
from torch import Tensor
from mmengine import Config
from mmdet.models.losses import QualityFocalLoss
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps
from mmdet.utils import InstanceList, ConfigType, reduce_mean,  OptInstanceList
from ..utils import multi_apply
from ..layers import inverse_sigmoid
from mmdet.structures.bbox import bbox2roi
# from mmdet.models.layers.deconfMLP import DeconfMLP
from .atss_vlfusion_head import convert_grounding_to_cls_scores
# from .gdino_head_inc_distn import GroundingDINOHead_inc_distn
from .incre_ddetr_head import incre_DeformableDETRHead
from mmdet.models.losses.distn_loss import inter_query_relation, inter_text_relation_partial, query_feat,detr_query_feat, text_feat,new_text_feat

@MODELS.register_module()
class incre_gcd_DeformableDETRHead(incre_DeformableDETRHead):
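    """Deformable DETR head for incremental detection with distillation.

    Extends ``incre_DeformableDETRHead`` with pseudo-label generation from
    the original (old) model's predictions and with logit / box distillation
    losses computed against those predictions.

    Args:
        distn_cfg (dict): Distillation config. The methods of this class read
            ``label_distn.type``, ``label_distn.mode`` and
            ``label_distn.loss_ld.type`` from ``self.distn_cfg``.
        **kwargs: Forwarded unchanged to ``incre_DeformableDETRHead``.
    """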

    def __init__(self, distn_cfg, **kwargs):
        """Initialise the head and keep the distillation config.

        Args:
            distn_cfg: Distillation config. Methods of this class read
                ``self.distn_cfg.label_distn`` (``type``, ``mode`` and
                ``loss_ld.type``).
            **kwargs: Forwarded unchanged to the parent head.
        """
        super().__init__(**kwargs)
        # The argument used to be accepted and then dropped, although
        # ``self.distn_cfg`` is read by several methods below.
        # NOTE(review): the parent class is not visible from this file; if it
        # already populates ``self.distn_cfg`` from another source, confirm
        # that the value passed here is the one that should win.
        self.distn_cfg = distn_cfg

    @torch.no_grad()
    def generate_pseudo_label(self, all_layers_cls_scores, all_layers_bbox_preds,
                               batch_data_samples,
                               test=None, mode='None'):
        """Build pseudo labels from the original model's last-layer outputs.

        Only the last entry of ``all_layers_cls_scores`` and
        ``all_layers_bbox_preds`` is used. ``generate_pesudo_label_single`` is
        run once per image, and its pseudo instances are then concatenated
        after the ground-truth instances of the same image.

        Side effects: stores the per-image max class probabilities and their
        indices in ``self.batch_prob_maxval`` / ``self.batch_prob_maxidx``.

        Args:
            all_layers_cls_scores: Per-layer classification outputs of the
                original model; indexed with ``[-1]``.
            all_layers_bbox_preds: Per-layer box outputs of the original
                model; indexed with ``[-1]``.
            batch_data_samples: Data samples providing ``metainfo`` and
                ``gt_instances``.
            test: Forwarded to ``generate_pesudo_label_single``.
            mode: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(topk_query, batch_pseudo_instances,
            batch_all_instances)``. Each element of ``batch_all_instances`` is
            an ``InstanceData`` with ``labels``, ``bboxes``, ``bbox_weights``,
            ``cls_weights`` and ``iou_weight``, ground truth first.
        """
        batch_img_metas = [sample.metainfo for sample in batch_data_samples]
        batch_gt_instances = [sample.gt_instances
                              for sample in batch_data_samples]

        (topk_query, batch_pseudo_instances, self.batch_prob_maxval,
         self.batch_prob_maxidx) = multi_apply(
            self.generate_pesudo_label_single,
            all_layers_cls_scores[-1],
            all_layers_bbox_preds[-1],
            batch_gt_instances,
            batch_data_samples,
            batch_img_metas,
            test=test)

        def _weight_or_ones(instances, name, count, device):
            # Use the per-instance weight when present, otherwise weight 1.
            if hasattr(instances, name):
                return getattr(instances, name)
            return torch.ones(count, device=device)

        batch_all_instances = []
        for gt_instances, pseudo_instances in zip(batch_gt_instances,
                                                  batch_pseudo_instances):
            device = gt_instances.bboxes.device
            num_pseudo = pseudo_instances.bboxes.size(0)
            # Ground-truth instances always get weight 1 for every loss term.
            gt_ones = torch.ones(gt_instances.bboxes.size(0), device=device)

            pseudo_bbox_w = _weight_or_ones(
                pseudo_instances, 'bbox_weight', num_pseudo, device)
            pseudo_cls_w = _weight_or_ones(
                pseudo_instances, 'cls_weight', num_pseudo, device)
            pseudo_iou_w = _weight_or_ones(
                pseudo_instances, 'iou_weight', num_pseudo, device)

            # NOTE(review): when generate_pesudo_label_single takes its
            # fallback branch it returns the ground-truth instances
            # themselves, so they would be concatenated twice here whenever
            # they are non-empty - confirm this is intended.
            batch_all_instances.append(
                InstanceData(
                    labels=torch.cat(
                        (gt_instances.labels, pseudo_instances.labels), 0),
                    bboxes=torch.cat(
                        (gt_instances.bboxes, pseudo_instances.bboxes), 0),
                    bbox_weights=torch.cat([gt_ones, pseudo_bbox_w], 0),
                    cls_weights=torch.cat([gt_ones, pseudo_cls_w], 0),
                    iou_weight=torch.cat([gt_ones, pseudo_iou_w], 0)))

        return topk_query, batch_pseudo_instances, batch_all_instances
    @torch.no_grad()
    def generate_pesudo_label_single(self, ori_cls_scores: Tensor,
                                    ori_bbox_preds: Tensor, gt_instances: InstanceData,
                                    batch_data_samples,
                                    img_metas: dict, test=False):
        """Select pseudo labels for one image from the original model output.

        Queries are selected by their maximum sigmoid score, either the top
        ``self.sigma`` queries (``'topk_pseudo'``) or every query whose score
        is at least ``self.sigma`` (``'threshold_pseudo'``). Selected queries
        whose box has IoU >= ``self.label_iou_th`` with any ground-truth box
        are discarded.

        Pseudo labels are only generated when the image has at least one
        ground-truth label >= ``self.trunc_class[0]``.

        Side effects: sets ``self.thre_giou`` to ``0.0``.

        Args:
            ori_cls_scores: Raw class scores of the original model for one
                image (assumed ``[num_queries, num_classes]`` - TODO confirm).
            ori_bbox_preds: Normalised ``cxcywh`` boxes for the same queries.
            gt_instances: Ground truth of the image (``bboxes``, ``labels``).
            batch_data_samples: Unused; kept for interface compatibility.
            img_metas: Image meta dict; ``img_shape`` is read as ``(h, w)``.
            test: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(topk_idx, instances, prob_maxval, prob_maxidx)``. In the
            fallback branch ``topk_idx`` is ``None``, ``instances`` is the
            unchanged ``gt_instances`` and both probability tensors are empty.

        Raises:
            ValueError: If ``distn_cfg.label_distn.type`` is not one of the
                two supported selection types.
        """
        gt_bboxes, gt_labels = gt_instances.bboxes, gt_instances.labels
        img_h, img_w = img_metas['img_shape']
        factor = ori_bbox_preds.new_tensor(
            [img_w, img_h, img_w, img_h]).unsqueeze(0)
        self.thre_giou = 0.0

        if gt_labels.size(0) != 0 and gt_labels.max() >= self.trunc_class[0]:
            prob_maxval, prob_maxidx = ori_cls_scores.sigmoid().max(dim=-1)

            select_type = self.distn_cfg.label_distn.type
            if select_type == 'topk_pseudo':
                _, topk_idx = torch.topk(prob_maxval, self.sigma, dim=-1)
            elif select_type == 'threshold_pseudo':
                topk_idx = torch.where(prob_maxval >= self.sigma)[0]
            else:
                raise ValueError('not implement')

            ref_labels = prob_maxidx[topk_idx]
            selected_scores = ori_cls_scores[topk_idx]
            ref_positive_map = torch.zeros_like(selected_scores)
            if self.distn_cfg.label_distn.mode == 'response':
                # Soft targets: copy the raw scores, leaving the last column
                # at zero (the original comment says it marks EOS - TODO
                # confirm this still applies to this head).
                valid_len = ori_cls_scores.size(1) - 1
                ref_positive_map[:, :valid_len] = selected_scores[:, :valid_len]
            else:
                # Hard targets: one-hot at the arg-max class. A single indexed
                # assignment replaces a Python loop that synchronised with the
                # device once per selected query via ``.item()``.
                rows = torch.arange(
                    ref_labels.size(0), device=ref_labels.device)
                ref_positive_map[rows, ref_labels] = 1

            # Convert every prediction to absolute xyxy once and clip to the
            # image; the selected boxes are then a plain index into that.
            ref_box_all = bbox_cxcywh_to_xyxy(ori_bbox_preds) * factor
            ref_box_all[:, 0::2].clamp_(min=0, max=img_w)
            ref_box_all[:, 1::2].clamp_(min=0, max=img_h)
            ref_box_list = ref_box_all[topk_idx]

            iou_all = bbox_overlaps(ref_box_all, gt_bboxes)
            ioumax, _ = iou_all.max(dim=1)
            # Keep only candidates that do not already overlap a GT box.
            keep = ~(ioumax[topk_idx] >= self.label_iou_th)

            pseudo_bboxes = ref_box_list[keep]
            if pseudo_bboxes.dtype == torch.float16:
                pseudo_bboxes = pseudo_bboxes.to(torch.float32)

            topk_idx = topk_idx[keep]
            gt_instances = InstanceData(
                labels=ref_labels[keep],
                positive_maps=ref_positive_map[keep],
                bboxes=pseudo_bboxes)
        else:
            # NOTE(review): the unchanged ground-truth instances are returned
            # as the "pseudo" instances here; the caller concatenates them
            # with the ground truth again - confirm this is intended.
            topk_idx = None
            prob_maxval = ori_cls_scores.new_tensor([])
            prob_maxidx = ori_cls_scores.new_tensor([])

        return topk_idx, gt_instances, prob_maxval, prob_maxidx

    def loss(self, new_head_inputs_dict,  # outputs of the new model
             old_head_inputs_dict,  # old queries / old references run through the new model
             ori_head_inputs_dict,  # outputs of the original (old) model
             ori_cls_scores_deconf=None,
             ori_bbox_deconf=None,
             old_head_inputs_dict_deconf=None,
             output_query=None,
             old_label_local=None,
             label_weights_local=None,
             mask_low_conf=None,
             mask_high_conf_low_iou=None,
             mask_high_conf_high_iou=None,
             query_deconf=None,
             batch_data_samples=None) -> tuple:
        """Compute the detection loss plus the optional deconf distillation.

        The head is run on the new-model inputs and on the old-query inputs.
        ``loss_by_feat_old`` supplies the query features returned to the
        caller, and ``loss_by_feat_new`` supplies the returned loss dict.

        Side effects: stores ``ori_head_inputs_dict['ori_topk_query']`` in
        ``self.ori_topk_query``.

        Args:
            new_head_inputs_dict: Must contain ``hidden_states``,
                ``references``, ``enc_outputs_class``, ``enc_outputs_coord``.
            old_head_inputs_dict: Must contain ``hidden_states`` and
                ``references``.
            ori_head_inputs_dict: Must contain ``ori_topk_query``,
                ``all_layers_ori_cls_scores``, ``all_layers_ori_bbox_preds``;
                may contain ``batch_pseudo_instances`` together with
                ``batch_all_instances``.
            ori_cls_scores_deconf: Per-layer scores forwarded to the
                distillation losses.
            output_query: Forwarded to ``loss_by_feat_old``.
            query_deconf: When not ``None`` the per-layer
                ``loss_distill_deconf`` terms are added.
            batch_data_samples: Data samples providing ``metainfo`` and
                ``gt_instances``; effectively required.
            ori_bbox_deconf, old_head_inputs_dict_deconf, old_label_local,
            label_weights_local, mask_low_conf, mask_high_conf_low_iou,
            mask_high_conf_high_iou: Unused; kept for interface
                compatibility.

        Returns:
            tuple: ``(loss_dict, query_feat_single, ori_query_feat_single)``.
        """
        batch_gt_instances = [sample.gt_instances
                              for sample in batch_data_samples]
        batch_img_metas = [sample.metainfo for sample in batch_data_samples]

        self.ori_topk_query = ori_head_inputs_dict['ori_topk_query']

        new_hidden_states = new_head_inputs_dict['hidden_states']
        hidden_states = old_head_inputs_dict['hidden_states']
        # The forward outputs are splatted into the leading
        # (cls_scores, bbox_preds) positions of loss_by_feat_new / _old.
        new_outs = self(new_hidden_states, new_head_inputs_dict['references'])
        old_outs = self(hidden_states, old_head_inputs_dict['references'])

        new_enc_cls_scores = new_head_inputs_dict['enc_outputs_class']
        new_enc_bbox_preds = new_head_inputs_dict['enc_outputs_coord']

        all_layers_ori_cls_scores = ori_head_inputs_dict['all_layers_ori_cls_scores']
        all_layers_ori_bbox_preds = ori_head_inputs_dict['all_layers_ori_bbox_preds']
        # ``ori_hidden_states`` used to be referenced without ever being
        # assigned (NameError). Neither loss_by_feat_old nor loss_by_feat_new
        # reads it in the visible code, so a missing entry is harmless.
        # NOTE(review): the key name is an assumption - confirm against the
        # code that builds ``ori_head_inputs_dict``.
        ori_hidden_states = ori_head_inputs_dict.get('ori_hidden_states')

        if 'batch_pseudo_instances' in ori_head_inputs_dict:
            batch_pseudo_instances = ori_head_inputs_dict['batch_pseudo_instances']
            batch_all_instances = ori_head_inputs_dict['batch_all_instances']
        else:
            batch_pseudo_instances = None
            batch_all_instances = None

        detr_loss_inputs = new_outs + (
            new_enc_cls_scores, new_enc_bbox_preds, new_hidden_states,
            batch_gt_instances, batch_img_metas, batch_all_instances,
            all_layers_ori_cls_scores, all_layers_ori_bbox_preds,
            batch_pseudo_instances, ori_hidden_states)

        distn_loss_inputs = old_outs + (
            batch_gt_instances, batch_img_metas, hidden_states,
            all_layers_ori_cls_scores, all_layers_ori_bbox_preds,
            batch_pseudo_instances, batch_all_instances, ori_hidden_states,
            output_query, ori_cls_scores_deconf, new_outs[0],
            new_hidden_states)

        # NOTE(review): the loss dict returned by loss_by_feat_old
        # (loss_ld_cls / loss_ld_bbox / loss_ld_iou) is not merged into the
        # returned losses - confirm this is intended.
        _loss_dict_old, query_feat_single, ori_query_feat_single = \
            self.loss_by_feat_old(*distn_loss_inputs)
        loss_dict_new = self.loss_by_feat_new(*detr_loss_inputs)

        # The original code also tried to add ``10 * loss_align_deconf`` here,
        # but no such value is ever computed in this method; the reference
        # raised NameError and has been removed.
        # NOTE(review): reinstate once the alignment loss exists.

        if query_deconf is not None:
            # ``loss_by_cls_ld_distn_single`` is not defined in this class;
            # presumably inherited - verify against the parent head.
            (loss_distill_deconf,) = multi_apply(
                self.loss_by_cls_ld_distn_single,
                ori_cls_scores_deconf,
                old_outs[0],
            )
            # Last decoder layer under the plain key, earlier layers under
            # the ``d{i}.`` prefix.
            loss_dict_new['loss_distill_deconf'] = loss_distill_deconf[-1]
            for layer_idx, layer_loss in enumerate(loss_distill_deconf[:-1]):
                loss_dict_new[f'd{layer_idx}.loss_distill_deconf'] = layer_loss

        return loss_dict_new, query_feat_single, ori_query_feat_single
    
    def loss_by_feat_old(
        self,
        all_layers_cls_scores: Tensor,    # old queries run through the new model
        all_layers_bbox_preds: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        hidden_states: Tensor = None,
        all_layers_ori_cls_scores: Tensor = None, # original model input for distn
        all_layers_ori_bbox_preds: Tensor = None, # original model input for distn
        batch_pseudo_instances: OptInstanceList = None,
        batch_all_instances: OptInstanceList = None,
        ori_hidden_states: Tensor = None,
        output_query: Tensor = None,
        ori_cls_scores_deconf: Tensor = None,
        all_layers_new_cls_scores: Tensor=None,
        new_hidden_states: Tensor=None,
        batch_gt_instances_ignore: OptInstanceList = None,
    ) -> Dict[str, Tensor]:
        """Logit / box distillation losses against the original model.

        ``loss_by_feat_ld_distn_single`` is applied to every decoder layer;
        only the last layer's losses are put into the returned dict. When
        ``output_query`` is given, query features are additionally extracted
        through ``rank_sim_distn_single``.

        Side effects: sets ``self.start`` / ``self.end`` from
        ``self.trunc_class``; ``self.start`` is read by
        ``loss_by_feat_ld_distn_single``.

        Args:
            all_layers_cls_scores: Per-layer class scores to be distilled.
            all_layers_bbox_preds: Per-layer box predictions to be distilled.
            batch_gt_instances: Ground-truth instances per image.
            batch_img_metas: Meta information per image.
            all_layers_ori_cls_scores: Per-layer class scores of the original
                model (distillation target).
            all_layers_ori_bbox_preds: Per-layer boxes of the original model.
            batch_pseudo_instances: Pseudo instances per image, or ``None``.
            batch_all_instances: GT + pseudo instances; replaced by
                ``batch_gt_instances`` when there are no pseudo instances.
            output_query: Original-model query features; enables the query
                feature extraction when not ``None``.
            ori_cls_scores_deconf: Per-layer scores, iterated alongside the
                other per-layer inputs; must therefore not be ``None``.
            all_layers_new_cls_scores: Per-layer class scores of the new
                model; only ``[-1]`` is used.
            new_hidden_states: New-model hidden states; only ``[-1]`` is used.
            hidden_states, ori_hidden_states, batch_gt_instances_ignore:
                Unused; kept for interface compatibility.

        Returns:
            tuple: ``(loss_dict, query_feat_single, ori_query_feat_single)``;
            the two feature entries are ``None`` when ``output_query`` is
            ``None``.
        """
        loss_dict = dict()
        self.start = self.trunc_class[0]
        self.end = self.trunc_class[1]
        if batch_pseudo_instances is None:
            batch_all_instances = batch_gt_instances

        # ===== logits distn loss =====
        # The per-layer inputs used to be passed as the undefined names
        # ``all_layers_aux_cls_scores`` / ``all_layers_aux_bbox_preds``
        # (NameError); the method's own parameters are what is meant.
        (ld_losses_cls, ld_losses_bbox, ld_losses_iou, _ld_losses_cls_bg,
         valid_mask_list) = multi_apply(
            self.loss_by_feat_ld_distn_single,
            all_layers_cls_scores,
            all_layers_bbox_preds,
            all_layers_ori_cls_scores,
            all_layers_ori_bbox_preds,
            ori_cls_scores_deconf,
            batch_gt_instances=batch_gt_instances,
            batch_pseudo_instances=batch_pseudo_instances,
            batch_img_metas=batch_img_metas)

        # loss from the last decoder layer
        loss_dict['loss_ld_cls'] = ld_losses_cls[-1]
        loss_dict['loss_ld_bbox'] = ld_losses_bbox[-1]
        loss_dict['loss_ld_iou'] = ld_losses_iou[-1]

        layer_id = 5
        query_feat_single, ori_query_feat_single = None, None
        if output_query is not None:
            # ``all_layers_matching_bbox_preds`` was undefined (NameError).
            # rank_sim_distn_single does not read its bbox argument, so the
            # available predictions are passed instead.
            # NOTE(review): ``valid_mask_list`` holds one mask per decoder
            # layer while every other argument is last-layer only - confirm
            # this is what detr_query_feat expects.
            query_feat_single, ori_query_feat_single = \
                self.rank_sim_distn_single(
                    new_hidden_states[-1], output_query.detach(),
                    all_layers_new_cls_scores[-1], all_layers_bbox_preds[-1],
                    ori_cls_scores_deconf[-1], all_layers_ori_bbox_preds[-1],
                    valid_mask_list, layer_id,
                    batch_pseudo_instances, batch_all_instances,
                    batch_img_metas)
        return loss_dict, query_feat_single, ori_query_feat_single
    
    def loss_by_feat_new(
        self,
        all_layers_cls_scores: Tensor,
        all_layers_bbox_preds: Tensor,
        new_enc_cls_scores: Tensor,
        new_enc_bbox_preds: Tensor,
        new_hidden_states: Tensor,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_all_instances: InstanceList,
        all_layers_ori_cls_scores: Tensor = None, # original model input for distn
        all_layers_ori_bbox_preds: Tensor = None, # original model input for distn
        batch_pseudo_instances: OptInstanceList = None,
        ori_hidden_states: Tensor = None,
    ) -> Dict[str, Tensor]:
        """Detection loss of the new model, delegated to the parent head.

        The targets are ``batch_all_instances`` when pseudo instances exist
        and the plain ``batch_gt_instances`` otherwise. ``new_hidden_states``,
        the ``all_layers_ori_*`` tensors and ``ori_hidden_states`` are
        accepted but not used here.

        Returns:
            Whatever the parent's ``loss_by_feat`` returns for the decoder
            and encoder predictions against the chosen targets.
        """
        if batch_pseudo_instances is None:
            targets = batch_gt_instances
        else:
            targets = batch_all_instances

        return super().loss_by_feat(
            all_layers_cls_scores,
            all_layers_bbox_preds,
            new_enc_cls_scores,
            new_enc_bbox_preds,
            targets,
            batch_img_metas)

    def loss_by_feat_ld_distn_single(self, cls_scores: Tensor, bbox_preds: Tensor,
                                    ori_cls_scores: Tensor, ori_bbox_preds: Tensor,
                                    ori_cls_scores_deconf: Tensor,
                                    batch_pseudo_instances: InstanceList,
                                    batch_gt_instances: InstanceList,
                                    batch_img_metas: List[dict],
                                    weighted=True) -> Tuple[Tensor]:
        """Distillation losses of one decoder layer against the old model.

        A query takes part in the distillation when (a) ``weighted`` is false
        or its maximum original-model sigmoid score exceeds ``mean + 2 * std``
        of that image's scores, and (b) its original-model box does not
        overlap a ground-truth box with IoU > ``self.label_iou_th``.

        Args:
            cls_scores: Class scores to be distilled; only the first
                ``self.start`` columns are used.
            bbox_preds: Normalised ``cxcywh`` boxes to be distilled
                (assumed ``[batch, num_queries, 4]`` - TODO confirm).
            ori_cls_scores: Class scores of the original model (target).
            ori_bbox_preds: Boxes of the original model (target).
            ori_cls_scores_deconf: Currently unused: none of the returned
                values depends on it. Kept for interface compatibility.
            batch_pseudo_instances: Unused; kept for interface compatibility.
            batch_gt_instances: Ground-truth instances per image.
            batch_img_metas: Meta information per image; ``img_shape`` is
                read as ``(h, w)``.
            weighted: Whether to restrict to high-confidence queries.

        Returns:
            tuple: ``(loss_cls, loss_bbox, loss_iou, loss_cls_bg,
            valid_mask_list)`` where ``loss_cls_bg`` is always ``None`` and
            ``valid_mask_list`` is the per-query 0/1 selection mask.

        Raises:
            ValueError: If ``distn_cfg.label_distn.loss_ld.type`` is neither
                ``'L2Loss'`` nor ``'KnowledgeDistillationKLDivLoss'``.
        """
        # Per-image factors that rescale normalised boxes to pixels. They are
        # kept as a list as well, because the overlap filter below needs the
        # factor that belongs to each image.
        per_image_factors = []
        for img_meta, bbox_pred in zip(batch_img_metas, bbox_preds):
            img_h, img_w = img_meta['img_shape']
            per_image_factors.append(
                bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
                .unsqueeze(0).repeat(bbox_pred.size(0), 1))
        factors = torch.cat(per_image_factors, 0)

        # Mark original-model boxes that overlap a ground-truth box.
        # Two fixes relative to the previous version:
        #  * the loop used to iterate over the concatenated ``factors``
        #    tensor, i.e. image i received row i of that tensor (image 0's
        #    factor) instead of its own factor;
        #  * the clamp used the stale ``img_meta`` left over from the loop
        #    above (always the last image) instead of this image's meta.
        overlap_masks = []
        for instance, ori_bboxes, factor, img_meta in zip(
                batch_gt_instances, ori_bbox_preds, per_image_factors,
                batch_img_metas):
            if len(instance.labels) > 0:
                img_h, img_w = img_meta['img_shape']
                ori_xyxy = bbox_cxcywh_to_xyxy(ori_bboxes) * factor
                ori_xyxy[:, 0::2].clamp_(min=0, max=img_w)
                ori_xyxy[:, 1::2].clamp_(min=0, max=img_h)
                ioumax_val, _ = bbox_overlaps(
                    ori_xyxy, instance.bboxes).max(dim=1)
                overlap_masks.append(ioumax_val > self.label_iou_th)
            else:
                # NOTE(review): images without ground truth mark every query
                # as overlapping, which removes them from the distillation
                # entirely - confirm this is intended.
                overlap_masks.append(
                    ori_bboxes.new_ones(ori_bboxes.size(0)).bool())
        overlap_list = torch.stack(overlap_masks, dim=0)

        max_scores, _ = ori_cls_scores.sigmoid().max(dim=-1)
        if weighted:
            # Per-image adaptive threshold: mean + 2 * std of the max scores.
            cls_thr = (max_scores.mean(dim=-1, keepdim=True)
                       + 2 * max_scores.std(dim=-1, keepdim=True))
            keep = max_scores > cls_thr
        else:
            keep = torch.ones_like(max_scores, dtype=torch.bool)
        keep = keep & ~overlap_list

        valid_mask_list = keep.to(max_scores.dtype)
        label_weights = valid_mask_list.clone()
        bbox_weights = valid_mask_list.clone()

        cls_scores = cls_scores[..., :self.start].contiguous()
        ori_cls_scores = ori_cls_scores.contiguous()

        cls_avg_factor = valid_mask_list.sum() * 1.0
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        # Always defined, so that every supported loss type can return it
        # (it used to be unbound in the 'L2Loss' branch).
        loss_cls_bg = None
        ld_type = self.distn_cfg.label_distn.loss_ld.type
        if ld_type == 'L2Loss':
            label_weights = label_weights[..., None].repeat(
                1, 1, ori_cls_scores.size(-1))
            loss_cls = self.loss_ld(
                cls_scores, ori_cls_scores, label_weights,
                avg_factor=cls_avg_factor)
        elif ld_type == 'KnowledgeDistillationKLDivLoss':
            loss_cls = self.loss_ld(
                cls_scores.view(-1, self.start),
                ori_cls_scores.view(-1, self.start),
                label_weights.flatten(),
                avg_factor=cls_avg_factor)
        else:
            raise ValueError(f'unsupported loss_ld type: {ld_type}')

        if ori_bbox_preds.dtype == torch.float16:
            ori_bbox_preds = ori_bbox_preds.to(torch.float32)

        bbox_preds = bbox_preds.reshape(-1, 4)
        ori_bbox_preds = ori_bbox_preds.reshape(-1, 4)
        bbox_weights = bbox_weights.unsqueeze(-1).repeat(1, 1, 4).reshape(-1, 4)
        bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
        bboxes_gt = bbox_cxcywh_to_xyxy(ori_bbox_preds) * factors

        # NOTE(review): both box losses are normalised by ``cls_avg_factor``,
        # as before. A separately reduced positive count used to be computed
        # here but was never used, so it has been dropped.
        loss_iou = self.loss_iou(
            bboxes, bboxes_gt, bbox_weights, avg_factor=cls_avg_factor)

        # regression L1 loss
        loss_bbox = self.loss_bbox(
            bbox_preds, ori_bbox_preds, bbox_weights, avg_factor=cls_avg_factor)

        return loss_cls, loss_bbox, loss_iou, loss_cls_bg, valid_mask_list

    
    def rank_sim_distn_single(self, batch_query_feats, batch_ori_query_feats,
                                    batch_cls_scores, batch_bbox_preds, 
                                    batch_ori_cls_scores, batch_ori_bbox_preds, 
                                    valid_score_mask,layer_id,
                                    batch_pseudo_instances, batch_all_instances, batch_img_metas):
        """Extract query features for the pseudo-labelled classes.

        Collects the pseudo labels of every image, determines the unique
        labels and their counts over the whole batch, and delegates to
        ``detr_query_feat``.

        Args:
            batch_query_feats: Query features of the new model.
            batch_ori_query_feats: Query features of the original model.
            batch_cls_scores: Class scores of the new model.
            batch_ori_cls_scores: Class scores of the original model.
            valid_score_mask: Selection mask(s) forwarded to
                ``detr_query_feat``.
            batch_pseudo_instances: Pseudo instances per image; their
                ``labels`` are read. Must not be ``None``.
            batch_bbox_preds, batch_ori_bbox_preds, layer_id,
            batch_all_instances, batch_img_metas: Unused; kept for interface
                compatibility.

        Returns:
            tuple: ``(query_feat_single, ori_query_feat_single)`` as returned
            by ``detr_query_feat``.
        """
        ori_pseudo_labels_list = [
            pseudo_instance.labels for pseudo_instance in batch_pseudo_instances
        ]
        unique_pseudo_labels, unique_pseudo_counts = torch.unique(
            torch.cat(ori_pseudo_labels_list, dim=0), return_counts=True)

        # The unique labels over ``batch_all_instances`` used to be computed
        # here as well, but the result was never used.
        query_feat_single, ori_query_feat_single = detr_query_feat(
            unique_pseudo_labels,
            batch_cls_scores, batch_ori_cls_scores,
            batch_query_feats, batch_ori_query_feats,
            ori_pseudo_labels_list, unique_pseudo_counts, valid_score_mask)
        return query_feat_single, ori_query_feat_single