from time import time
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn

from utils.box_utils import generalized_box_iou, box_cxcywh_to_xyxy, box_iou, bbox_inner_iou
from utils.comm import is_dist_avail_and_initialized, get_world_size


class VideoSTGLoss(nn.Module):
    """Loss criterion for the VideoSTG model.

    The computation happens in two parts:
        1) spatial: predicted boxes vs. ground-truth boxes on the frames of the
           ground-truth temporal segment (L1 + Inner-GIoU, box confidence,
           image-guided attention);
        2) temporal: predicted start/end ("sted") distributions, per-frame
           actioness and guided temporal attention vs. the ground-truth segment.

    Which terms are computed is selected by ``losses`` (see ``get_loss``).
    """

    def __init__(self, cfg, losses):
        """Create the criterion.

        Args:
            cfg: config node. ``cfg.SOLVER.EOS_COEF`` (weight of frames outside
                the ground-truth segment in the actioness loss) is read here;
                ``cfg.SOLVER.SIGMA`` is read by ``loss_sted``.
            losses: iterable of loss names to compute; each must be a key of
                the dispatch table in ``get_loss``.
        """
        super().__init__()
        self.cfg = cfg
        self.losses = losses
        self.eos_coef = cfg.SOLVER.EOS_COEF

    def loss_boxes(self, outputs, targets, num_boxes):
        """Compute the box regression losses: L1 and Inner-GIoU.

        ``outputs["pred_boxes"]`` is compared element-wise with the
        concatenation of ``target["boxs"].bbox`` over ``targets``; ``forward``
        slices the predictions to the ground-truth frames beforehand so the
        rows line up (presumably one box per annotated frame).
        Boxes are expected as (center_x, center_y, w, h), normalized by the
        image size.

        Returns:
            dict with ``loss_bbox`` (summed L1 / num_boxes) and ``loss_giou``
            (summed ``1 - inner GIoU`` / num_boxes).
        """
        assert "pred_boxes" in outputs

        src_boxes = outputs["pred_boxes"]
        target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0)
        normalizer = max(num_boxes, 1)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / normalizer

        # Inner-GIoU (auxiliary boxes scaled by ratio 1.15) is used instead of
        # plain GIoU; the result is still stored under the "loss_giou" key.
        inner_giou = bbox_inner_iou(
            src_boxes, target_boxes, xywh=True, GIoU=True, ratio=1.15
        ).squeeze(-1)
        losses["loss_giou"] = (1 - inner_giou).sum() / normalizer
        return losses

    def loss_img_attn(self, outputs, targets, num_boxes, temperature=0.2):
        """Image-guided attention loss: KL(box-mask distribution || attention).

        Each ground-truth box is rasterized into a binary mask at image
        resolution, resized (nearest) to the attention-map resolution, and
        both the mask and the attention weights are turned into distributions
        over spatial positions with a temperature softmax.

        Args:
            outputs: needs ``imgweights`` (squeezed on dim 1 to [N, HW]) and
                ``spatial_shapes`` (candidate (H, W) resolutions; the last one
                with H * W == HW is used, otherwise the first).
            targets: each needs ``boxs``; ``targets[0]["img_size"]`` is used
                as (H, W) for the whole batch.
            num_boxes: normalizer for the summed KL.
            temperature: softmax temperature for both distributions.
        """
        assert "imgweights" in outputs
        spatial_shapes = outputs["spatial_shapes"]
        imgweights = outputs["imgweights"].squeeze(1)  # -> [N, HW]

        target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0)
        img_sizes = targets[0]["img_size"]

        # Normalized cxcywh -> absolute xyxy pixels. reshape(-1, 4) (instead of
        # squeeze()) keeps a single box 2-D so the column indexing below works.
        scale = torch.Tensor(
            [img_sizes[1], img_sizes[0], img_sizes[1], img_sizes[0]]
        ).to(target_boxes.device)
        bboxs = box_cxcywh_to_xyxy(target_boxes).clamp(min=0).reshape(-1, 4) * scale
        # NOTE(review): x2 uses ceil while x1, y1, y2 use round; confirm intended.
        corners = torch.stack(
            [bboxs[:, 0].round(), bboxs[:, 1].round(), bboxs[:, 2].ceil(), bboxs[:, 3].round()],
            dim=-1,
        ).long()

        # Binary mask with the same size as the input image, one per box. Built
        # on the CPU from plain ints (one transfer instead of per-element syncs).
        mask = torch.zeros((imgweights.shape[0], img_sizes[0], img_sizes[1]), dtype=torch.float32)
        for i, (x1, y1, x2, y2) in enumerate(corners.tolist()):
            mask[i, y1:y2, x1:x2] = 1.0

        # Resize the mask to the resolution of the attention map.
        size = spatial_shapes[0]
        for spatial_shape in spatial_shapes:
            if imgweights.shape[1] == spatial_shape[0] * spatial_shape[1]:
                size = spatial_shape
        size = tuple(int(s) for s in size)  # accept tuple or tensor shapes
        resized_mask = F.interpolate(mask.unsqueeze(1), size=size, mode="nearest").squeeze(1).flatten(1)

        # KL(true || pred) computed in log space for numerical stability.
        log_true = F.log_softmax(resized_mask / temperature, dim=1).to(imgweights.device)
        log_pred = F.log_softmax(imgweights / temperature, dim=1)
        loss_KL = torch.sum(log_true.exp() * (log_true - log_pred), dim=1)

        return {"imgguided_attn": loss_KL.sum() / max(num_boxes, 1)}

    def loss_conf(self, outputs, targets, num_boxes, gt_index):
        """Box-confidence loss.

        The IoU between each predicted box and its own ground-truth box is used
        as a soft target for ``outputs["boxes_conf"][gt_index]`` with binary
        cross-entropy on logits (mean over boxes).

        Args:
            gt_index: indices (``gt_bbox_slice`` from ``forward``) selecting the
                confidence entries of the ground-truth frames.
            num_boxes: unused by the active variant.
        """
        assert "boxes_conf" in outputs
        src_boxes = outputs["pred_boxes"]
        target_boxes = torch.cat([target["boxs"].bbox for target in targets], dim=0)
        iou, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
        iou = torch.diag(iou)  # IoU of each prediction with its matching ground truth
        conf = outputs["boxes_conf"][gt_index]
        # NOTE(review): the IoU target is not detached, so gradients also flow
        # into pred_boxes through the target; confirm this is intended.
        # An alternative ("vid") variant used
        # F.smooth_l1_loss(conf, iou, reduction="none").sum() / max(num_boxes, 1).
        return {"loss_conf": F.binary_cross_entropy_with_logits(conf, iou)}

    def loss_actioness(self, outputs, targets, gt_temp_bound, time_mask=None):
        """Per-frame actioness loss (weighted BCE on logits).

        Frames inside ``gt_temp_bound[b]`` (inclusive) get weight 1, all others
        ``self.eos_coef``. When ``time_mask`` is given, padded frames are
        zeroed before the mean.
        """
        assert "pred_actioness" in outputs
        pred_actioness = outputs["pred_actioness"].squeeze(-1)
        target_actioness = torch.stack([target["actioness"] for target in targets], dim=0).float()

        weight = torch.full(pred_actioness.shape, self.eos_coef, device=pred_actioness.device)
        for i_b in range(weight.shape[0]):
            start, end = gt_temp_bound[i_b][0], gt_temp_bound[i_b][1]
            weight[i_b, start:end + 1] = 1

        loss_actioness = F.binary_cross_entropy_with_logits(
            pred_actioness, target_actioness, weight=weight, reduction="none"
        )
        if time_mask is not None:
            loss_actioness = loss_actioness * time_mask
        return {"loss_actioness": loss_actioness.mean()}

    def loss_sted(self, outputs, num_boxes, gt_temp_bound, positive_map, time_mask=None):
        """Start/end localization loss.

        For both start and end, a Gaussian (std ``cfg.SOLVER.SIGMA``) centered
        on the ground-truth frame is the target distribution ``q``; the loss is
        the element-wise term ``p * log((p + eps) / q)`` with ``p`` the
        predicted softmax over frames, masked by ``time_mask`` and averaged.

        ``num_boxes`` and ``positive_map`` are unused; they are accepted so
        ``get_loss`` can call this like ``loss_guided_attn``.
        """
        assert "pred_sted" in outputs
        sted = outputs["pred_sted"]
        if time_mask is None:
            time_mask = torch.ones(sted.shape[:2], dtype=torch.bool, device=sted.device)

        target_start = torch.tensor([x[0] for x in gt_temp_bound], dtype=torch.long, device=sted.device)
        target_end = torch.tensor([x[1] for x in gt_temp_bound], dtype=torch.long, device=sted.device)
        # Very low logit on padded frames so they get ~0 probability after softmax.
        sted = sted.masked_fill(~time_mask[:, :, None], -1e32)
        eps = 1e-6
        sigma = self.cfg.SOLVER.SIGMA
        positions = torch.arange(sted.shape[1], device=sted.device)[None, :]

        def kl_term(logits, target):
            # Gaussian target centered on the ground-truth frame, L1-normalized.
            target_distrib = (-((positions - target[:, None]) ** 2) / (2 * sigma ** 2)).exp()
            target_distrib = F.normalize(target_distrib + eps, p=1, dim=1)
            pred_prob = logits.softmax(1)
            return pred_prob * ((pred_prob + eps) / target_distrib).log()

        loss_start = kl_term(sted[:, :, 0], target_start) * time_mask
        loss_end = kl_term(sted[:, :, 1], target_end) * time_mask
        return {"loss_sted": (loss_start + loss_end).mean()}

    def loss_guided_attn(
        self, outputs, num_boxes, gt_temp_bound, positive_map, time_mask=None
    ):
        """Guided temporal attention loss.

        ``outputs["weights"]`` is a [B, T, T] attention tensor. Rows whose frame
        is inside the ground-truth segment (``positive_map``) or is padding
        (``~time_mask``) are ignored; the remaining rows are penalized with
        ``-log(1 - w)``, normalized per video by the number of such frames,
        summed over frames and averaged over the batch.
        """
        weights = outputs["weights"]  # B x T x T
        if time_mask is not None:
            # Padded positions have to be taken out as well.
            positive_map = positive_map | (~time_mask)
        eps = 1e-6  # avoid log(0) and division by 0

        loss = -(1 - weights + eps).log()
        loss = loss.masked_fill(positive_map[:, :, None], 0)
        nb_neg = (~positive_map).sum(1) + eps
        loss = loss.sum(2) / nb_neg[:, None]  # sum over the last dim, normalize by #negatives
        loss = loss.sum(1)  # sum over frames
        return {"loss_guided_attn": loss.mean()}  # mean over the batch

    def Crossmodal_constrastive_loss(self, flang_attn, region_feature, neg_region_feature, T=0.07):
        """InfoNCE-style loss between a query and several positive region features.

        For every positive ``region_feature[:, j]`` the logits are
        [cos(q, pos_j), cos(q, neg_1..K)] / T with target class 0; the
        cross-entropy is averaged over j.

        Args:
            flang_attn: query features [N, C].
            region_feature: positive features [N, P, C].
            neg_region_feature: negative features [N, K, C].
            T: softmax temperature.
        """
        q = F.normalize(flang_attn, dim=1)
        neg = F.normalize(neg_region_feature.permute(0, 2, 1), dim=1)  # [N, C, K]
        l_neg = torch.einsum("nc,nck->nk", [q, neg])
        # The positive logit is always column 0; labels follow the input device.
        labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
        num_pos = region_feature.shape[1]
        total = 0
        for jj in range(num_pos):
            pos = F.normalize(region_feature[:, jj, :], dim=1)
            l_pos = torch.einsum("nc,nc->n", [q, pos]).unsqueeze(-1)
            logits = torch.cat([l_pos, l_neg], dim=1) / T
            total = F.cross_entropy(logits, labels) + total
        return {"Crossmodal": total * 1.0 / num_pos}

    def Interframe_contrastive_loss(self, frame_feature, corre_feature, neg_feature, T=0.07):
        """InfoNCE-style loss between a frame feature and its corresponding feature.

        Logits are [cos(q, k), cos(q, neg_1..K)] / T with target class 0.

        Args:
            frame_feature: query features [N, 1, C] (dim 1 is squeezed).
            corre_feature: positive features [N, 1, C] (dim 1 is squeezed).
            neg_feature: negative features [N, K, C].
            T: softmax temperature.
        """
        q = F.normalize(frame_feature.squeeze(1), dim=1)
        k = F.normalize(corre_feature.squeeze(1), dim=1)
        neg = F.normalize(neg_feature.permute(0, 2, 1), dim=1)  # [N, C, K]
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        l_neg = torch.einsum("nc,nck->nk", [q, neg])
        logits = torch.cat([l_pos, l_neg], dim=1) / T
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return {"interframe": F.cross_entropy(logits, labels)}

    def get_loss(
        self, loss, outputs, targets, num_boxes, gt_bbox_slice,
        gt_temp_bound=None, positive_map=None, time_mask=None, **kwargs,
    ):
        """Compute one named loss, passing each implementation the arguments it expects.

        Args:
            loss: one of "boxes", "sted", "guided_attn", "actioness", "conf",
                "imgguided_attn".
            gt_bbox_slice: flat indices of the ground-truth frames (used by "conf").
            gt_temp_bound: per-video [start, end] ground-truth frames
                ("sted", "guided_attn", "actioness").
            positive_map: [B, T] bool mask of ground-truth frames ("sted", "guided_attn").
            time_mask: [B, T] bool mask of valid (non-padded) frames.
            **kwargs: forwarded to the loss (except "conf").
        """
        loss_map = {
            "boxes": self.loss_boxes,
            "sted": self.loss_sted,
            "guided_attn": self.loss_guided_attn,
            "actioness": self.loss_actioness,
            "conf": self.loss_conf,
            "imgguided_attn": self.loss_img_attn,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"

        if loss in ("sted", "guided_attn"):
            return loss_map[loss](outputs, num_boxes, gt_temp_bound, positive_map, time_mask, **kwargs)
        if loss == "actioness":
            return loss_map[loss](outputs, targets, gt_temp_bound, time_mask, **kwargs)
        if loss == "conf":
            return loss_map[loss](outputs, targets, num_boxes, gt_bbox_slice)
        return loss_map[loss](outputs, targets, num_boxes, **kwargs)

    def forward(self, outputs, targets, durations):
        """Compute all requested losses, plus one copy per auxiliary decoder layer.

        Args:
            outputs: model output dict. ``pred_boxes`` rows are indexed as
                ``video_index * max(durations) + frame``; optional
                ``aux_outputs`` is a list of dicts with the same layout.
            targets: list of per-video dicts, each with ``actioness``
                (nonzero on ground-truth frames; at least one is assumed),
                ``boxs``, and whatever the selected losses need.
            durations: number of valid frames of each video.

        Returns:
            dict of losses; auxiliary-layer losses get a ``_{i}`` suffix.

        Note:
            ``outputs["pred_boxes"]`` and every aux ``pred_boxes`` are replaced
            in place by their rows for the ground-truth frames.
            The contrastive losses are not called here; they need feature
            tensors that ``forward`` does not receive.
        """
        max_duration = max(durations)
        device = outputs["pred_boxes"].device

        # Ground-truth segment [first, last] (inclusive) per video and the flat
        # indices of its frames in the prediction layout.
        gt_bbox_slice, gt_temp_bound = [], []
        for i_dur, (_, target) in enumerate(zip(durations, targets)):
            inter = torch.where(target["actioness"])[0].cpu().tolist()
            gt_temp_bound.append([inter[0], inter[-1]])
            offset = i_dur * max_duration
            gt_bbox_slice.extend(range(offset + inter[0], offset + inter[-1] + 1))
        gt_bbox_slice = torch.LongTensor(gt_bbox_slice).to(device)

        outputs["pred_boxes"] = outputs["pred_boxes"][gt_bbox_slice]
        aux_outputs_list = outputs.get("aux_outputs", [])
        for aux_outputs in aux_outputs_list:
            aux_outputs["pred_boxes"] = aux_outputs["pred_boxes"][gt_bbox_slice]

        # Average number of target boxes across all nodes, for normalization.
        num_boxes = sum(len(target["boxs"]) for target in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Valid (non-padded) frames of each video.
        time_mask = torch.zeros(len(durations), max_duration, dtype=torch.bool, device=device)
        for i_dur, duration in enumerate(durations):
            time_mask[i_dur, :duration] = True

        # Frames inside the ground-truth segment of each video.
        positive_map = torch.zeros(time_mask.shape, dtype=torch.bool)
        for k, (start, end) in enumerate(gt_temp_bound):
            positive_map[k, start:end + 1] = True
        positive_map = positive_map.to(device)

        temporal_kwargs = {
            "gt_temp_bound": gt_temp_bound,
            "positive_map": positive_map,
            "time_mask": time_mask,
        }

        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, num_boxes, gt_bbox_slice, **temporal_kwargs)
            )

        # Repeat for the output of each intermediate (auxiliary) layer.
        for i, aux_outputs in enumerate(aux_outputs_list):
            for loss in self.losses:
                l_dict = self.get_loss(
                    loss, aux_outputs, targets, num_boxes, gt_bbox_slice, **temporal_kwargs
                )
                losses.update({k + f"_{i}": v for k, v in l_dict.items()})

        return losses