import math

import torch
from torch import nn
from torch.nn import functional as F

from .models import register_meta_arch, make_backbone, make_neck, make_generator
from .blocks import MaskedConv1D, Scale, LayerNorm
from .losses import ctr_diou_loss_1d, sigmoid_focal_loss

from ..utils import batched_nms

class ConsistencyLoss:
    """Cross-entropy between ``logits`` and integer ``targets``.

    The per-element loss is ``-log_softmax(logits, dim=-1)[target]``, computed
    with ``F.nll_loss(reduction='none')`` and then averaged. When ``mask`` is
    given, each per-element loss is multiplied by ``mask`` first. The mean is
    still taken over *all* elements, so masked-out entries count as zeros in
    the denominator.
    """

    def __call__(self, logits, targets, mask=None):
        log_probs = F.log_softmax(logits, dim=-1)
        per_item = F.nll_loss(log_probs, targets, reduction='none')
        if mask is None:
            return per_item.mean()
        return (per_item * mask.float()).mean()

class PtTransformerClsHead(nn.Module):
    """
    1D Conv heads for classification.

    A stack of ``num_layers - 1`` masked 1D convolutions, each followed by an
    optional LayerNorm and the activation, then a final masked 1D conv that
    produces ``num_classes`` logits per time step. The same weights are
    applied to every FPN level.

    Args:
        input_dim: channels of the incoming features.
        feat_dim: channels of the intermediate conv layers.
        num_classes: number of output logits per time step.
        prior_prob: if > 0, the classifier bias is initialised to
            ``-log((1 - p) / p)`` so that the initial sigmoid output is ``p``.
        num_layers: total number of conv layers, including the classifier.
        kernel_size: kernel size of every conv; padding ``kernel_size // 2``
            is used so that odd kernels keep the length.
        act_layer: activation class, instantiated once and shared.
        with_ln: attach LayerNorm after each intermediate conv. Those convs
            then have no bias.
        empty_cls: class indices whose bias is forced to a large negative
            value, so their outputs stay close to zero. ``None`` (the default)
            means there are no such classes.
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        num_classes,
        prior_prob=0.01,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False,
        empty_cls=None
    ):
        super().__init__()
        # avoid a shared mutable default argument
        if empty_cls is None:
            empty_cls = []
        self.act = act_layer()

        # build the head: the first conv maps input_dim -> feat_dim, the
        # remaining ones keep feat_dim
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for idx in range(num_layers - 1):
            in_dim = input_dim if idx == 0 else feat_dim
            self.head.append(
                MaskedConv1D(
                    in_dim, feat_dim, kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=(not with_ln)
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # classifier; without intermediate layers it consumes the raw input
        # features directly, so its input width must then be input_dim
        cls_in_dim = feat_dim if num_layers > 1 else input_dim
        self.cls_head = MaskedConv1D(
            cls_in_dim, num_classes, kernel_size,
            stride=1, padding=kernel_size // 2
        )

        # use prior in model initialization to improve stability:
        # every class starts with sigmoid(bias) == prior_prob
        # (this overwrites the default bias init)
        if prior_prob > 0:
            bias_value = -(math.log((1 - prior_prob) / prior_prob))
            torch.nn.init.constant_(self.cls_head.conv.bias, bias_value)

        # a quick fix to empty categories: the weights associated with these
        # categories are left as initialised, and their bias is set to a large
        # negative value (sigmoid ~= 1e-6) to suppress their outputs
        if len(empty_cls) > 0:
            bias_value = -(math.log((1 - 1e-6) / 1e-6))
            for idx in empty_cls:
                torch.nn.init.constant_(self.cls_head.conv.bias[idx], bias_value)

    def forward(self, fpn_feats, fpn_masks):
        """Apply the head to every FPN level.

        Args:
            fpn_feats: sequence of per-level features (presumably
                (B, C, T_i) -- TODO confirm against MaskedConv1D).
            fpn_masks: sequence of per-level masks, same length as fpn_feats.

        Returns:
            tuple with one logits tensor per level, as produced by
            ``cls_head``. The masks are not returned (they are unchanged).
        """
        assert len(fpn_feats) == len(fpn_masks)

        out_logits = tuple()
        for cur_feat, cur_mask in zip(fpn_feats, fpn_masks):
            cur_out = cur_feat
            for conv, norm in zip(self.head, self.norm):
                cur_out, _ = conv(cur_out, cur_mask)
                cur_out = self.act(norm(cur_out))
            cur_logits, _ = self.cls_head(cur_out, cur_mask)
            out_logits += (cur_logits, )
        return out_logits


class PtTransformerRegHead(nn.Module):
    """
    Shared 1D Conv heads for regression.

    Similar logic as PtTransformerClsHead, with a separate implementation for
    clarity. A stack of ``num_layers - 1`` masked 1D convs, each followed by
    an optional LayerNorm and the activation, then a masked 1D conv with 2
    output channels. Each FPN level has its own learnable ``Scale``, and the
    result is passed through ReLU, so outputs are non-negative.

    Args:
        input_dim: channels of the incoming features.
        feat_dim: channels of the intermediate conv layers.
        fpn_levels: number of FPN levels (one ``Scale`` per level).
        num_layers: total number of conv layers, including the output conv.
        kernel_size: kernel size of every conv (padding ``kernel_size // 2``).
        act_layer: activation class, instantiated once and shared.
        with_ln: attach LayerNorm after each intermediate conv. Those convs
            then have no bias.
    """
    def __init__(
        self,
        input_dim,
        feat_dim,
        fpn_levels,
        num_layers=3,
        kernel_size=3,
        act_layer=nn.ReLU,
        with_ln=False
    ):
        super().__init__()
        self.fpn_levels = fpn_levels
        self.act = act_layer()

        # build the conv head: first conv input_dim -> feat_dim, then feat_dim
        self.head = nn.ModuleList()
        self.norm = nn.ModuleList()
        for idx in range(num_layers - 1):
            in_dim = input_dim if idx == 0 else feat_dim
            self.head.append(
                MaskedConv1D(
                    in_dim, feat_dim, kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                    bias=(not with_ln)
                )
            )
            self.norm.append(LayerNorm(feat_dim) if with_ln else nn.Identity())

        # one learnable output scale per FPN level
        self.scale = nn.ModuleList()
        for _ in range(fpn_levels):
            self.scale.append(Scale())

        # segment regression (2 outputs per time step); without intermediate
        # layers it consumes the raw input, so its input width is input_dim
        reg_in_dim = feat_dim if num_layers > 1 else input_dim
        self.offset_head = MaskedConv1D(
            reg_in_dim, 2, kernel_size,
            stride=1, padding=kernel_size // 2
        )

    def forward(self, fpn_feats, fpn_masks):
        """Apply the regression head to every FPN level.

        Args:
            fpn_feats: sequence of ``fpn_levels`` per-level features
                (presumably (B, C, T_i) -- TODO confirm).
            fpn_masks: sequence of per-level masks, same length as fpn_feats.

        Returns:
            tuple with one non-negative offset tensor per level
            (``relu(scale_l(offset_head(...)))``).
        """
        assert len(fpn_feats) == len(fpn_masks)
        assert len(fpn_feats) == self.fpn_levels

        out_offsets = tuple()
        for level, (cur_feat, cur_mask) in enumerate(zip(fpn_feats, fpn_masks)):
            cur_out = cur_feat
            for conv, norm in zip(self.head, self.norm):
                cur_out, _ = conv(cur_out, cur_mask)
                cur_out = self.act(norm(cur_out))
            cur_offsets, _ = self.offset_head(cur_out, cur_mask)
            out_offsets += (F.relu(self.scale[level](cur_offsets)), )
        return out_offsets
from torch.autograd import Function
# TODO: gradient reversal layer used by the adversarial domain branch of PtTransformer.forward
class GradientReversalFn(Function):
    """Gradient reversal layer (GRL).

    Forward: returns the input unchanged, as a view.
    Backward: multiplies the incoming gradient by ``-alpha``. The ``alpha``
    argument itself receives no gradient.

    Usage: ``GradientReversalFn.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # keep the scaling factor around for the backward pass
        ctx.alpha = alpha
        # identity in the forward direction
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign and scale; one gradient slot per forward input
        grad_x = grad_output.neg().mul(ctx.alpha)
        grad_alpha = None
        return grad_x, grad_alpha

@register_meta_arch("LocPointTransformer")
class PtTransformer(nn.Module):
    """
        Transformer based model for single stage action localization
    """
    def __init__(
        self,
        backbone_type,         # a string defining which backbone we use
        fpn_type,              # a string defining which fpn we use
        backbone_arch,         # a tuple defining #layers in embed / stem / branch
        scale_factor,          # scale factor between branch layers
        input_dim,             # input feat dim
        max_seq_len,           # max sequence length (used for training)
        max_buffer_len_factor, # max buffer size (defined as a factor of max_seq_len)
        n_head,                # number of heads for self-attention in transformer
        n_mha_win_size,        # window size for self attention; -1 to use full seq
        embd_kernel_size,      # kernel size of the embedding network
        embd_dim,              # output feat channel of the embedding network
        embd_with_ln,          # attach layernorm to embedding network
        fpn_dim,               # feature dim on FPN
        fpn_with_ln,           # if to apply layer norm at the end of fpn
        fpn_start_level,       # start level of fpn
        head_dim,              # feature dim for head
        regression_range,      # regression range on each level of FPN
        head_num_layers,       # number of layers in the head (including the classifier)
        head_kernel_size,      # kernel size for reg/cls heads
        head_with_ln,          # attach layernorm to reg/cls heads
        use_abs_pe,            # if to use abs position encoding
        use_rel_pe,            # if to use rel position encoding
        num_classes,           # number of action classes
        train_cfg,             # other cfg for training
        test_cfg               # other cfg for testing
    ):
        """Build backbone, neck, point generator, heads and training state.

        Keys read from ``train_cfg``: 'center_sample', 'center_sample_radius',
        'loss_weight', 'cls_prior_prob', 'dropout', 'droppath',
        'label_smoothing', 'head_empty_cls', 'init_loss_norm'.
        Keys read from ``test_cfg``: 'pre_nms_thresh', 'pre_nms_topk',
        'iou_threshold', 'min_score', 'max_seg_num', 'nms_method',
        'duration_thresh', 'multiclass_nms', 'nms_sigma', 'voting_thresh'.

        NOTE: the order in which submodules are created below fixes the
        parameter registration order (and presumably the RNG consumption of
        default weight init); avoid reordering it.
        """
        super().__init__()
        # re-distribute params to backbone / neck / head
        # stride of each FPN level w.r.t. the input: scale_factor**l for
        # l in [fpn_start_level, backbone_arch[-1]]
        self.fpn_strides = [scale_factor**i for i in range(
            fpn_start_level, backbone_arch[-1]+1
        )]
        self.reg_range = regression_range  # regression range for each FPN level
        assert len(self.fpn_strides) == len(self.reg_range)  # one regression range per FPN level
        self.scale_factor = scale_factor  # downsampling ratio between adjacent pyramid levels
        # number of classification outputs per time step.
        # NOTE(review): an older comment here said "num_classes + 1 with the
        # last category as background", but label_points_single_video one-hot
        # encodes labels over num_classes and represents background as an
        # all-zero row, so this appears to count action classes only -- confirm.
        self.num_classes = num_classes

        # check the feature pyramid and local attention window size
        self.max_seq_len = max_seq_len
        if isinstance(n_mha_win_size, int):  # same window size on every level
            self.mha_win_size = [n_mha_win_size]*(1 + backbone_arch[-1])
        else:  # one window size per level
            assert len(n_mha_win_size) == (1 + backbone_arch[-1])
            self.mha_win_size = n_mha_win_size
        # largest factor the padded input length must be divisible by
        max_div_factor = 1
        # NOTE(review): fpn_strides starts at fpn_start_level while
        # mha_win_size starts at level 0, so zip() pairs them with an offset
        # when fpn_start_level > 0 -- confirm this is intended.
        for l, (s, w) in enumerate(zip(self.fpn_strides, self.mha_win_size)):
            # with a local window (w > 1) the length must be divisible by the
            # level stride times the even-rounded window; otherwise by the
            # stride alone (see the assert message below)
            stride = s * (w // 2) * 2 if w > 1 else s
            assert max_seq_len % stride == 0, "max_seq_len must be divisible by fpn stride and window size"
            if max_div_factor < stride:
                max_div_factor = stride
        # used by preprocessing() to pad over-long inference inputs
        self.max_div_factor = max_div_factor

        # training time config: center sampling, loss weight, classifier
        # prior, dropout / droppath and label smoothing
        self.train_center_sample = train_cfg['center_sample']
        assert self.train_center_sample in ['radius', 'none']
        self.train_center_sample_radius = train_cfg['center_sample_radius']
        self.train_loss_weight = train_cfg['loss_weight']
        self.train_cls_prior_prob = train_cfg['cls_prior_prob']
        self.train_dropout = train_cfg['dropout']
        self.train_droppath = train_cfg['droppath']
        self.train_label_smoothing = train_cfg['label_smoothing']

        # test time config
        self.test_pre_nms_thresh = test_cfg['pre_nms_thresh']
        self.test_pre_nms_topk = test_cfg['pre_nms_topk']
        self.test_iou_threshold = test_cfg['iou_threshold']
        self.test_min_score = test_cfg['min_score']
        self.test_max_seg_num = test_cfg['max_seg_num']
        self.test_nms_method = test_cfg['nms_method']
        assert self.test_nms_method in ['soft', 'hard', 'none']
        self.test_duration_thresh = test_cfg['duration_thresh']
        self.test_multiclass_nms = test_cfg['multiclass_nms']
        self.test_nms_sigma = test_cfg['nms_sigma']
        self.test_voting_thresh = test_cfg['voting_thresh']

        # we will need a better way to dispatch the params to backbones / necks
        # backbone network: conv + transformer
        assert backbone_type in ['convTransformer', 'conv']
        if backbone_type == 'convTransformer':  # convolutional stem + transformer branch
            self.backbone = make_backbone(
                'convTransformer',
                **{
                    'n_in' : input_dim,
                    'n_embd' : embd_dim,
                    'n_head': n_head,
                    'n_embd_ks': embd_kernel_size,
                    'max_len': max_seq_len,
                    'arch' : backbone_arch,
                    'mha_win_size': self.mha_win_size,
                    'scale_factor' : scale_factor,
                    'with_ln' : embd_with_ln,
                    'attn_pdrop' : 0.0,
                    'proj_pdrop' : self.train_dropout,
                    'path_pdrop' : self.train_droppath,
                    'use_abs_pe' : use_abs_pe,
                    'use_rel_pe' : use_rel_pe
                }
            )
        else:  # convolution-only backbone
            self.backbone = make_backbone(
                'conv',
                **{
                    'n_in': input_dim,
                    'n_embd': embd_dim,
                    'n_embd_ks': embd_kernel_size,
                    'arch': backbone_arch,
                    'scale_factor': scale_factor,
                    'with_ln' : embd_with_ln
                }
            )
        # a list/tuple embd_dim is summed into a single channel count
        # (presumably the backbone concatenates those features -- TODO confirm)
        if isinstance(embd_dim, (list, tuple)):
            embd_dim = sum(embd_dim)

        # fpn network (neck), built by make_neck from fpn_type
        assert fpn_type in ['fpn', 'identity']
        self.neck = make_neck(
            fpn_type,
            **{
                'in_channels' : [embd_dim] * (backbone_arch[-1] + 1),
                'out_channel' : fpn_dim,
                'scale_factor' : scale_factor,
                'start_level' : fpn_start_level,
                'with_ln' : fpn_with_ln
            }
        )

        # location generator: per-level point coordinates together with
        # their regression range and stride
        self.point_generator = make_generator(
            'point',
            **{
                'max_seq_len' : max_seq_len * max_buffer_len_factor,
                'fpn_strides' : self.fpn_strides,
                'regression_range' : self.reg_range
            }
        )

        # classification and regression heads; the classification head
        # outputs num_classes logits per time step
        self.cls_head = PtTransformerClsHead(
            fpn_dim, head_dim, self.num_classes,
            kernel_size=head_kernel_size,
            prior_prob=self.train_cls_prior_prob,
            with_ln=head_with_ln,
            num_layers=head_num_layers,
            empty_cls=train_cfg['head_empty_cls']
        )
        self.reg_head = PtTransformerRegHead(  # regression head: 2 offsets per time step
            fpn_dim, head_dim, len(self.fpn_strides),
            kernel_size=head_kernel_size,
            num_layers=head_num_layers,
            with_ln=head_with_ln
        )
        # TODO: domain classifier for the adversarial (gradient-reversal)
        # branch in forward(); it reuses the regression-head architecture,
        # so it outputs 2 channels per time step
        self.domain_classifier = PtTransformerRegHead(
            fpn_dim, head_dim, len(self.fpn_strides),
            kernel_size=head_kernel_size,
            num_layers=head_num_layers,
            with_ln=head_with_ln
        )
        # maintain an EMA of #foreground to stabilize the loss normalizer
        # useful for small mini-batch training
        self.loss_normalizer = train_cfg['init_loss_norm']  # initial value of the EMA
        self.loss_normalizer_momentum = 0.9  # EMA momentum for loss_normalizer
        # TODO: FreeMatch self-adaptive thresholding state, initialised to a
        # uniform distribution. NOTE: these are plain tensors (not registered
        # buffers), so they do not follow model.to(); losses() moves them to
        # self.device explicitly.
        self.p_t=torch.ones(self.num_classes)/self.num_classes
        self.label_hist=torch.ones(self.num_classes)/self.num_classes
        self.tau_t=self.p_t.mean()
        self.sat_ema=0.999  # EMA momentum used by update_sat_params
        self.criterion = ConsistencyLoss()
        # weights of the unlabeled (loss_sat) and fairness (loss_saf) terms;
        # the latter is currently commented out in losses()
        self.ulb_loss_ratio=1
        self.ent_loss_ratio=0.01

    @property
    def device(self):
        # a hacky way to get the device type
        # will throw an error if parameters are on different devices
        return list(set(p.device for p in self.parameters()))[0]

    def forward(self, video_list, grl_lambda=0, domain=None, finetune=False,
                freeMatch=False, cluster=False, kmeans=None, pca=None,
                sum_score=None):
        """Run the detector on a list of videos.

        Args:
            video_list: list of per-video dicts. 'feats' is read via
                ``preprocessing``; 'segments' and 'labels' are read in
                training; 'video_id' is read when ``sum_score`` is given.
            grl_lambda: gradient-reversal coefficient passed to
                ``GradientReversalFn`` in the domain branch.
            domain: when 'target' (training and not finetuning), only the
                domain predictions are returned.
            finetune: bool; when True the domain branch is skipped and only
                the losses are returned in training.
            freeMatch: forwarded to ``self.losses``.
            cluster: bool; when True, the backbone features are returned
                right after the backbone.
            kmeans, pca: optional clustering models. When ``pca`` is given in
                training, per-time-step cluster labels are computed (they are
                not yet passed on to ``self.losses``).
            sum_score: optional mapping video_id -> score, used to flag
                source-similar videos (flags not yet passed to ``self.losses``).

        Returns:
            * ``cluster`` True: the backbone features.
            * training, not finetune, ``domain == 'target'``: list of
              per-level domain log-probabilities.
            * training, not finetune: ``(losses, domain_pred)``.
            * training, finetune: ``losses``.
            * eval: the output of ``self.inference``.
        """
        # split the batch into source-similar / source-dissimilar videos
        # NOTE(review): is_sim is meant for the (currently disabled) is_sim
        # argument of self.losses below.
        if sum_score is not None:
            is_sim = self._source_similarity_flags(video_list, sum_score)

        domain_pred = []
        # batch the video list into feats (B, C, T) and masks (B, 1, T)
        batched_inputs, batched_masks = self.preprocessing(video_list)
        # forward the network (backbone -> neck -> heads)
        feats, masks = self.backbone(batched_inputs, batched_masks)
        if cluster:
            return feats
        fpn_feats, fpn_masks = self.neck(feats, masks)
        # compute the point coordinate along the FPN
        # this is used for computing the GT or decode the final results
        # points: List[T x 4] with length = # fpn levels
        # (shared across all samples in the mini-batch)
        points = self.point_generator(fpn_feats)

        # adversarial domain branch (training only, skipped when finetuning)
        if self.training and not finetune:
            domain_pred = self._domain_predictions(fpn_feats, fpn_masks, grl_lambda)
            if domain == 'target':
                return domain_pred

        # out_cls: List[B, #cls, T_i]
        out_cls_logits = self.cls_head(fpn_feats, fpn_masks)
        # out_offset: List[B, 2, T_i]
        out_offsets = self.reg_head(fpn_feats, fpn_masks)

        # permute the outputs
        # out_cls: F List[B, #cls, T_i] -> F List[B, T_i, #cls]
        out_cls_logits = [x.permute(0, 2, 1) for x in out_cls_logits]
        # out_offset: F List[B, 2 (xC), T_i] -> F List[B, T_i, 2 (xC)]
        out_offsets = [x.permute(0, 2, 1) for x in out_offsets]
        # fpn_masks: F list[B, 1, T_i] -> F List[B, T_i]
        fpn_masks = [x.squeeze(1) for x in fpn_masks]

        if not self.training:
            # decode the actions (sigmoid / stride, etc)
            return self.inference(
                video_list, points, fpn_masks,
                out_cls_logits, out_offsets
            )

        # generate segment/label List[N x 2] / List[N] with length = B
        assert video_list[0]['segments'] is not None, "GT action labels does not exist"
        assert video_list[0]['labels'] is not None, "GT action labels does not exist"
        gt_segments = [x['segments'] for x in video_list]
        gt_labels = [x['labels'] for x in video_list]

        # compute the gt labels for cls & reg (list of prediction targets)
        gt_cls_labels, gt_offsets = self.label_points(
            points, gt_segments, gt_labels)

        if pca is not None:
            # NOTE(review): cluster_labels is not yet passed to self.losses.
            # fit_transform refits pca on every batch, so the projection fed
            # to kmeans changes between batches; pca.transform may be what
            # was intended -- confirm.
            feats4cluster = torch.cat(feats, 2).permute(0, 2, 1)
            # flatten to (B * sum_T, C) using the actual channel count
            feats4cluster = feats4cluster.reshape(-1, feats4cluster.shape[-1])
            feats4cluster = pca.fit_transform(feats4cluster.cpu().detach())
            cluster_labels = kmeans.predict(feats4cluster)

        # compute the loss and return
        losses = self.losses(
            fpn_masks,
            out_cls_logits, out_offsets,
            gt_cls_labels, gt_offsets, freeMatch=freeMatch,
            # cluster_labels=cluster_labels, is_sim=is_sim
        )
        if not finetune:
            return losses, domain_pred
        return losses

    @staticmethod
    def _source_similarity_flags(video_list, sum_score, num_ref=5):
        """Flag each video whose score is strictly above the ``num_ref``-th
        highest score in the batch (so at most ``num_ref - 1`` are flagged).
        With fewer than ``num_ref`` videos, all of them are within the top
        ``num_ref - 1`` and every video is flagged."""
        scores = sorted(sum_score[v['video_id']] for v in video_list)
        if len(scores) < num_ref:
            return [True] * len(video_list)
        threshold = scores[-num_ref]
        return [sum_score[v['video_id']] > threshold for v in video_list]

    def _domain_predictions(self, fpn_feats, fpn_masks, grl_lambda):
        """Per-level domain log-probabilities for the adversarial branch.

        NOTE(review): the 1x1 projection is a freshly initialised nn.Conv1d on
        every call. It is not registered on the module, so its weights are
        random and never trained. Fixing this requires creating the
        projections in __init__, which adds new state_dict keys.
        NOTE(review): the gradient reversal is applied to the *output* of
        domain_classifier, so the classifier itself also receives reversed
        gradients. DANN-style setups usually reverse the features *before*
        the domain classifier -- confirm this is intended.
        """
        domain_pred = []
        for pred in self.domain_classifier(fpn_feats, fpn_masks):
            # (B, 2, T_i) -> (B, T_i, 2), then reverse the gradient
            pred = GradientReversalFn.apply(pred.permute(0, 2, 1), grl_lambda)
            _, in_c, _ = pred.shape
            # follow the features' device instead of assuming the current GPU
            proj = nn.Conv1d(in_channels=in_c, out_channels=1, kernel_size=1).to(pred.device)
            pred = proj(pred).squeeze(dim=1)
            domain_pred.append(F.log_softmax(pred, dim=1))
        return domain_pred

    @torch.no_grad()
    def preprocessing(self, video_list, padding_val=0.0):
        """
            Generate batched features and masks from a list of dict items
        """
        feats = [x['feats'] for x in video_list]
        feats_lens = torch.as_tensor([feat.shape[-1] for feat in feats])
        max_len = feats_lens.max(0).values.item()

        if self.training:
            assert max_len <= self.max_seq_len, "Input length must be smaller than max_seq_len during training"
            # if max_len > self.max_seq_len:
            #     print("warning!!!! Input length must be smaller than max_seq_len during training")
            #     feats=feats[:,:,:self.max_seq_len-1]
            # set max_len to self.max_seq_len
            max_len = self.max_seq_len
            # batch input shape B, C, T
            batch_shape = [len(feats), feats[0].shape[0], max_len]
            batched_inputs = feats[0].new_full(batch_shape, padding_val)
            for feat, pad_feat in zip(feats, batched_inputs):
                pad_feat[..., :feat.shape[-1]].copy_(feat)
        else:
            assert len(video_list) == 1, "Only support batch_size = 1 during inference"
            # input length < self.max_seq_len, pad to max_seq_len
            if max_len <= self.max_seq_len:
                max_len = self.max_seq_len
            else:
                # pad the input to the next divisible size
                stride = self.max_div_factor
                max_len = (max_len + (stride - 1)) // stride * stride
            padding_size = [0, max_len - feats_lens[0]]
            batched_inputs = F.pad(
                feats[0], padding_size, value=padding_val).unsqueeze(0)

        # generate the mask
        batched_masks = torch.arange(max_len)[None, :] < feats_lens[:, None]

        # push to device
        batched_inputs = batched_inputs.to(self.device)
        batched_masks = batched_masks.unsqueeze(1).to(self.device)

        return batched_inputs, batched_masks

    @torch.no_grad()
    def label_points(self, points, gt_segments, gt_labels):
        # concat points on all fpn levels List[T x 4] -> F T x 4
        # This is shared for all samples in the mini-batch
        num_levels = len(points)
        concat_points = torch.cat(points, dim=0)
        gt_cls, gt_offset = [], []

        # loop over each video sample
        for gt_segment, gt_label in zip(gt_segments, gt_labels):
            cls_targets, reg_targets = self.label_points_single_video(
                concat_points, gt_segment, gt_label
            )
            # append to list (len = # images, each of size FT x C)
            gt_cls.append(cls_targets)
            gt_offset.append(reg_targets)

        return gt_cls, gt_offset

    @torch.no_grad()
    def label_points_single_video(self, concat_points, gt_segment, gt_label):
        """Assign classification and regression targets to every point.

        Args:
            concat_points: points of all FPN levels concatenated. Columns read
                below are 0: position t, 1/2: regression-range min/max,
                3: stride.
            gt_segment: ground-truth [start, end] per segment, shape (N, 2).
            gt_label: integer class index per segment, passed to F.one_hot.
                Presumably 1-D of length N -- TODO confirm (the inline note
                below says N x 1, which would not fit the one-hot matmul).

        Returns:
            cls_targets: (num_pts, num_classes) multi-hot targets. All-zero
                rows mean background.
            reg_targets: (num_pts, 2) distances to the left/right boundary of
                the selected segment, divided by the point's stride. For
                points with no valid segment the row still holds distances to
                whichever index ``lens.min`` returned, so callers should mask
                it out.
        """
        # concat_points : F T x 4 (t, regression range, stride)
        # gt_segment : N (#Events) x 2
        # gt_label : N (#Events) x 1
        num_pts = concat_points.shape[0]  # number of points over all FPN levels
        num_gts = gt_segment.shape[0]  # number of ground-truth segments in this video

        # corner case where current sample does not have actions
        if num_gts == 0:
            cls_targets = gt_segment.new_full((num_pts, self.num_classes), 0)
            reg_targets = gt_segment.new_zeros((num_pts, 2))
            return cls_targets, reg_targets

        # compute the lengths of all segments -> F T x N
        lens = gt_segment[:, 1] - gt_segment[:, 0]
        lens = lens[None, :].repeat(num_pts, 1)

        # compute the distance of every point to each segment boundary
        # auto broadcasting for all reg target-> F T x N x2
        gt_segs = gt_segment[None].expand(num_pts, num_gts, 2)
        left = concat_points[:, 0, None] - gt_segs[:, :, 0]  # distance to each segment's start
        right = gt_segs[:, :, 1] - concat_points[:, 0, None]  # distance to each segment's end
        reg_targets = torch.stack((left, right), dim=-1)

        if self.train_center_sample == 'radius':
            # center of all segments F T x N
            center_pts = 0.5 * (gt_segs[:, :, 0] + gt_segs[:, :, 1])
            # center sampling based on stride radius
            # compute the new boundaries:
            # concat_points[:, 3] stores the stride
            t_mins = \
                center_pts - concat_points[:, 3, None] * self.train_center_sample_radius
            t_maxs = \
                center_pts + concat_points[:, 3, None] * self.train_center_sample_radius
            # prevent t_mins / maxs from over-running the action boundary
            # left: torch.maximum(t_mins, gt_segs[:, :, 0])
            # right: torch.minimum(t_maxs, gt_segs[:, :, 1])
            # F T x N (distance to the new boundary)
            cb_dist_left = concat_points[:, 0, None] \
                           - torch.maximum(t_mins, gt_segs[:, :, 0])
            # distance to the shrunken (center-sampled) right boundary
            cb_dist_right = torch.minimum(t_maxs, gt_segs[:, :, 1]) \
                            - concat_points[:, 0, None]
            # F T x N x 2
            center_seg = torch.stack(
                (cb_dist_left, cb_dist_right), -1)
            # F T x N: point lies strictly inside the center-sampled region
            inside_gt_seg_mask = center_seg.min(-1)[0] > 0
        else:
            # inside an gt action
            inside_gt_seg_mask = reg_targets.min(-1)[0] > 0

        # limit the regression range for each location:
        # the farther of the two boundary distances must fall inside the
        # [min, max] regression range of the point's FPN level
        max_regress_distance = reg_targets.max(-1)[0]
        # F T x N
        inside_regress_range = torch.logical_and(
            (max_regress_distance >= concat_points[:, 1, None]),
            (max_regress_distance <= concat_points[:, 2, None])
        )

        # if there are still more than one actions for one moment
        # pick the one with the shortest duration (easiest to regress)
        # invalid (point, segment) pairs get length +inf so they never win
        lens.masked_fill_(inside_gt_seg_mask==0, float('inf'))  # point not inside the segment
        lens.masked_fill_(inside_regress_range==0, float('inf'))  # outside this level's regression range
        # F T x N -> F T; min_len_inds is the index of the shortest valid
        # segment (arbitrary when every entry of the row is inf)
        min_len, min_len_inds = lens.min(dim=1)
        # corner case: multiple actions with very similar durations (e.g., THUMOS14)
        # min_len_mask[t, n] is 1 for every valid segment n whose length ties
        # (within 1e-3) with the shortest valid length at point t. Multiplying
        # by the one-hot labels below then gives a multi-hot row, so a point
        # covered by several equally short segments with different labels is
        # positive for all of those labels. Rows with no valid segment are all
        # zero (the `lens < inf` term), i.e. background.
        min_len_mask = torch.logical_and(
            (lens <= (min_len[:, None] + 1e-3)), (lens < float('inf'))
        ).to(reg_targets.dtype)

        # cls_targets: F T x C; reg_targets F T x 2
        gt_label_one_hot = F.one_hot(
            gt_label, self.num_classes
        ).to(reg_targets.dtype)
        cls_targets = min_len_mask @ gt_label_one_hot
        # to prevent multiple GT actions with the same label and boundaries
        # (the matmul above would sum them to values > 1)
        cls_targets.clamp_(min=0.0, max=1.0)
        # OK to use min_len_inds: pick the (left, right) distances of the
        # selected segment for every point
        reg_targets = reg_targets[range(num_pts), min_len_inds]
        # normalization based on stride, so targets are in units of the
        # point's FPN-level stride
        reg_targets /= concat_points[:, 3, None]
        return cls_targets, reg_targets

    def update_sat_params(self,logits_ulb_w,tau_t,p_t,label_hist):
        probs_ulb_w=torch.softmax(logits_ulb_w,dim=-1)
        max_probs_w,max_idx_w=torch.max(probs_ulb_w,dim=-1)
        tau_t=tau_t*self.sat_ema+(1.-self.sat_ema)*max_probs_w.mean()
        p_t=p_t*self.sat_ema+(1.-self.sat_ema)*probs_ulb_w.mean(dim=0)
        histogram = torch.bincount(max_idx_w, minlength=p_t.shape[0]).to(p_t.dtype)
        label_hist = label_hist * self.sat_ema + (1. - self.sat_ema) * (histogram / histogram.sum())
        return tau_t, p_t, label_hist
    
    @staticmethod
    def __check__nans__(x):
        x[x == float('inf')] = 0.0
        return x

    def losses(
        self, fpn_masks,
        out_cls_logits, out_offsets,
        gt_cls_labels, gt_offsets,freeMatch=False,cluster_labels=None,is_sim=None
    ):
        # fpn_masks, out_*: F (List) [B, T_i, C]
        # gt_* : B (list) [F T, C]
        # fpn_masks -> (B, FT)
        valid_mask = torch.cat(fpn_masks, dim=1)

        # 1. classification loss
        # stack the list -> (B, FT) -> (# Valid, )
        gt_cls = torch.stack(gt_cls_labels)
        pos_mask = torch.logical_and((gt_cls.sum(-1) > 0), valid_mask)
        # 得到的 pos_mask 是一个形状为 (B, F) 的布尔掩码，其中 True 表示对应的视频段中至少存在一个正类别，而 False 表示没有正类别
        # cat the predicted offsets -> (B, FT, 2 (xC)) -> # (#Pos, 2 (xC))
        # ls=torch.cat(out_offsets, dim=1)#此时ls形状为[2,1134,2] pos_mask形状为[2,1134]
        pred_offsets = torch.cat(out_offsets, dim=1)[pos_mask]
        gt_offsets = torch.stack(gt_offsets)[pos_mask]

        # update the loss normalizer
        num_pos = pos_mask.sum().item()
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos, 1)
        
        # TODO
        # 0. loss_sat,loss_saf for 伪标签 
        if freeMatch == True:

            # 0.1 loss_sat
            logits_ulb_w=torch.sigmoid(torch.cat(out_cls_logits, dim=1))
            logits_ulb_s=gt_cls #伪标签 batchsize*feat_len*class_num
            logits_ulb_w=torch.flatten(logits_ulb_w, start_dim=0, end_dim=-2)
            logits_ulb_s=torch.flatten(logits_ulb_s, start_dim=0, end_dim=-2)
            # self.p_t=self.p_t.to(f'cuda:{logits_ulb_w.device.index}')
            self.p_t=self.p_t.to(self.device)
            # self.label_hist=self.label_hist.to(f'cuda:{logits_ulb_w.device.index}')
            # self.tau_t=self.tau_t.to(f'cuda:{logits_ulb_w.device.index}')
            self.label_hist=self.label_hist.to(self.device)
            self.tau_t=self.tau_t.to(self.device)

            tau_t, p_t, label_hist = self.update_sat_params(logits_ulb_w, self.tau_t, self.p_t, self.label_hist)
            
            logits_ulb_w = logits_ulb_w.detach()
            probs_ulb_w=torch.softmax(logits_ulb_w,dim=-1)
            max_probs_w,max_idx_w=torch.max(probs_ulb_w,dim=-1)
            tau_t_c = (p_t / torch.max(p_t, dim=-1)[0])
            mask = max_probs_w.ge(tau_t * tau_t_c[max_idx_w]).to(max_probs_w.dtype)
            loss_sat = self.criterion(logits_ulb_s, max_idx_w, mask=mask)
            self.p_t=p_t
            # self.label_hist=label_hist
            # self.tau_t=tau_t

            # 0.2 loss_saf
            # Take high confidence examples based on Eq 7 of the paper
            # logits_ulb_s = logits_ulb_s[mask.bool()]
            # probs_ulb_s = torch.softmax(logits_ulb_s, dim=-1)
            # max_idx_s = torch.argmax(probs_ulb_s, dim=-1)
            
            # # Calculate the histogram of strong logits acc. to Eq. 9
            # # Cast it to the dtype of the strong logits to remove the error of division of float by long
            # histogram = torch.bincount(max_idx_s, minlength=logits_ulb_s.shape[1]).to(logits_ulb_s.dtype)
            # histogram /= histogram.sum()

            # # Eq. 11 of the paper.
            # p_t = p_t.reshape(1, -1)
            # label_hist = label_hist.reshape(1, -1)
            
            # # Divide by the Sum Norm for both the weak and strong augmentations
            # scaler_p_t = self.__check__nans__(1 / label_hist).detach()
            # modulate_p_t = p_t * scaler_p_t
            # modulate_p_t /= modulate_p_t.sum(dim=-1, keepdim=True)
            
            # scaler_prob_s = self.__check__nans__(1 / histogram).detach()
            # modulate_prob_s = probs_ulb_s.mean(dim=0, keepdim=True) * scaler_prob_s
            # modulate_prob_s /= modulate_prob_s.sum(dim=-1, keepdim=True)
            
            # # Cross entropy loss between two Sum Norm logits. 
            # loss_saf = (modulate_p_t * torch.log(modulate_prob_s + 1e-9)).sum(dim=1).mean()

            # loss_mean_teacher=self.ulb_loss_ratio*loss_sat+self.ent_loss_ratio*loss_saf
            # loss_mean_teacher=0


        # gt_cls is already one hot encoded now, simply masking out
        gt_target = gt_cls[valid_mask]

        # optinal label smoothing
        gt_target *= (1 - self.train_label_smoothing)
        gt_target += self.train_label_smoothing / (self.num_classes + 1)

        # 0.5 BCE损失
        #首先根据聚类结果计算源不相似集的
        if cluster_labels!=None:
            unsim_index=[]
            for i in range(len(is_sim)):
                if is_sim[i] == False:
                    unsim_index.append(i)
            unsim_bec_loss=0
            for i in unsim_index:
                logits_1video=[]
                for t in out_cls_logits:
                    logits_1video.append(t[i])
                cluster_label_index=0
                loss_1video=0
                for logits_1size in logits_1video:
                    logits_1size=F.softmax(logits_1size, dim=1)
                    batch_logits_t=[]
                    batch_logits_t1=[]
                    batch_cluster_label=[]
                    batch_cluster_label1=[]
                    for t in range(0,len(logits_1size),2):
                        batch_logits_t.append(logits_1size[t])
                        batch_logits_t1.append(logits_1size[t+1])
                        batch_cluster_label.append(cluster_labels[i*valid_mask.shape[-1]+cluster_label_index])
                        batch_cluster_label1.append(cluster_labels[i*valid_mask.shape[-1]+cluster_label_index+1])
                        cluster_label_index+=2
                    batch_logits_t=torch.stack(batch_logits_t)
                    batch_logits_t1=torch.stack(batch_logits_t1)
                    # batch_logits_t_m = batch_logits_t.repeat(batch_logits_t.size(0), 1)
                    # batch_logits_t1_m = batch_logits_t1.repeat(1, batch_logits_t1.size(0)).view(-1, batch_logits_t1.size(1))
                    simi=[]
                    for t in range(len(batch_cluster_label)):
                        if batch_cluster_label[t]==batch_cluster_label1[t]:
                            simi.append(1)
                        else:
                            simi.append(-1)
                    simi=torch.tensor(simi).to(self.device)
                    assert len(batch_logits_t)==len(batch_logits_t1)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(batch_logits_t)),str(len(batch_logits_t1)),str(len(simi)))
                    P = batch_logits_t.mul_(batch_logits_t1)
                    P = P.sum(1)
                    P.mul_(simi).add_(simi.eq(-1).type_as(P))
                    neglogP = -P.add_(1e-7).log_()
                    loss_1size=neglogP.mean()
                    loss_1video+=loss_1size
                unsim_bec_loss+=(loss_1video/len(out_cls_logits))
            unsim_bec_loss/=len(unsim_index)
            #根据伪标签结果计算源相似集的
            sim_index=[]
            for i in range(len(is_sim)):
                if is_sim[i] == True:
                    sim_index.append(i)
            sim_bec_loss=0
            for i in sim_index:
                logits_1video=[]
                for t in out_cls_logits:
                    logits_1video.append(t[i])
                pseudo_label_index=0
                loss_1video=0
                for logits_1size in logits_1video:
                    logits_1size=F.softmax(logits_1size, dim=1)
                    batch_logits_t=[]
                    batch_logits_t1=[]
                    batch_pseudo_label=[]
                    batch_pseudo_label1=[]
                    for t in range(0,len(logits_1size),2):
                        batch_logits_t.append(logits_1size[t])
                        batch_logits_t1.append(logits_1size[t+1])
                        batch_pseudo_label.append(torch.argmax(gt_cls_labels[i][pseudo_label_index], -1))
                        batch_pseudo_label1.append(torch.argmax(gt_cls_labels[i][pseudo_label_index+1], -1))
                        pseudo_label_index+=2
                    batch_logits_t=torch.stack(batch_logits_t)
                    batch_logits_t1=torch.stack(batch_logits_t1)
                    # batch_logits_t_m = batch_logits_t.repeat(batch_logits_t.size(0), 1)
                    # batch_logits_t1_m = batch_logits_t1.repeat(1, batch_logits_t1.size(0)).view(-1, batch_logits_t1.size(1))
                    simi=[]
                    for t in range(len(batch_pseudo_label)):
                        if batch_pseudo_label[t]==batch_pseudo_label1[t]:
                            simi.append(1)
                        else:
                            simi.append(-1)
                    simi=torch.tensor(simi).to(self.device)
                    assert len(batch_logits_t)==len(batch_logits_t1)==len(simi), 'Wrong input size:{0},{1},{2}'.format(str(len(batch_logits_t)),str(len(batch_logits_t1)),str(len(simi)))
                    P = batch_logits_t.mul_(batch_logits_t1)
                    P = P.sum(1)
                    P.mul_(simi).add_(simi.eq(-1).type_as(P))
                    neglogP = -P.add_(1e-7).log_()
                    loss_1size=neglogP.mean()
                    loss_1video+=loss_1size
                sim_bec_loss+=(loss_1video/len(out_cls_logits))
            sim_bec_loss/=len(sim_index)
                # logits_1video=torch.cat(logits_1video,1)

        # 1. focal loss
        cls_loss = sigmoid_focal_loss(
            torch.cat(out_cls_logits, dim=1)[valid_mask],
            gt_target,
            reduction='sum'
        )
        cls_loss /= self.loss_normalizer

        # 2. regression using IoU/GIoU loss (defined on positive samples)
        if num_pos == 0:
            reg_loss = 0 * pred_offsets.sum()
        else:
            # giou loss defined on positive samples
            reg_loss = ctr_diou_loss_1d(
                pred_offsets,
                gt_offsets,
                reduction='sum'
            )
            reg_loss /= self.loss_normalizer

        if self.train_loss_weight > 0:
            loss_weight = self.train_loss_weight
        else:
            loss_weight = cls_loss.detach() / max(reg_loss.item(), 0.01)

        # return a dict of losses
        # final_loss = cls_loss + reg_loss * loss_weight+loss_mean_teacher + sim_bec_loss + unsim_bec_loss
        final_loss = cls_loss + reg_loss * loss_weight
        return {'cls_loss'   : cls_loss,
                'reg_loss'   : reg_loss,
                'final_loss' : final_loss}

    @torch.no_grad()
    def inference(
        self,
        video_list,
        points, fpn_masks,
        out_cls_logits, out_offsets
    ):
        """Decode detections for every video in a batch and post-process them.

        Args:
            video_list: list of per-video dicts; each must provide the keys
                'video_id', 'fps', 'duration', 'feat_stride' and
                'feat_num_frames'.
            points: per-FPN-level points, shared by every video in the batch.
            fpn_masks: per-FPN-level masks, indexed by video on the first axis.
            out_cls_logits: per-FPN-level classification logits, indexed by
                video on the first axis.
            out_offsets: per-FPN-level regression offsets, indexed by video on
                the first axis.

        Returns:
            Whatever ``self.postprocessing`` returns for the list of per-video
            result dicts (each one augmented with that video's meta info).
        """
        meta_keys = ('video_id', 'fps', 'duration', 'feat_stride', 'feat_num_frames')
        # read all meta information up front, before any decoding happens
        metas = [{key: video[key] for key in meta_keys} for video in video_list]

        per_video_results = []
        for vid_pos, meta in enumerate(metas):
            # slice this video's outputs out of every FPN level
            vid_cls = [level[vid_pos] for level in out_cls_logits]
            vid_offsets = [level[vid_pos] for level in out_offsets]
            vid_masks = [level[vid_pos] for level in fpn_masks]
            decoded = self.inference_single_video(
                points, vid_masks, vid_cls, vid_offsets
            )
            # pass the meta information through to post-processing
            decoded.update(meta)
            per_video_results.append(decoded)

        # timestamps are still defined on feature grids at this point
        return self.postprocessing(per_video_results)

    @torch.no_grad()
    def inference_single_video(
        self,
        points,
        fpn_masks,
        out_cls_logits,
        out_offsets,
    ):
        """Decode candidate segments for a single video across all FPN levels.

        For each level, candidates are kept when their score exceeds
        ``self.test_pre_nms_thresh``, truncated to the
        ``self.test_pre_nms_topk`` highest scores, decoded into [left, right]
        boundaries on the feature grid, and dropped when their length does not
        exceed ``self.test_duration_thresh``.

        Args:
            points: per-level point tensors; column 0 is the anchor position
                and column 3 is the factor the offsets are multiplied by
                (the stride, per the original "denorm by stride" note).
            fpn_masks: per-level masks for this video, one entry per point.
            out_cls_logits: per-level logits; flattened indices are split into
                (point, class) using ``self.num_classes``.
            out_offsets: per-level offsets; column 0 is the distance to the
                left boundary and column 1 the distance to the right boundary.

        Returns:
            dict with 'segments' (N x 2), 'scores' (N) and 'labels' (N),
            concatenated over all FPN levels.
        """
        kept_segs, kept_scores, kept_labels = [], [], []

        for logits, offsets_lvl, pts_lvl, mask_lvl in zip(
            out_cls_logits, out_offsets, points, fpn_masks
        ):
            # sigmoid scores with masked positions zeroed, flattened over (point, class)
            scores = (logits.sigmoid() * mask_lvl.unsqueeze(-1)).flatten()

            # keep only candidates above the confidence threshold
            above = scores > self.test_pre_nms_thresh
            scores = scores[above]
            flat_idx = above.nonzero(as_tuple=True)[0]

            # retain only the top-k highest scoring candidates
            k = min(self.test_pre_nms_topk, flat_idx.size(0))
            scores, order = scores.sort(descending=True)
            top = order[:k]
            scores = scores[:k].clone()
            flat_idx = flat_idx[top].clone()

            # split each flattened index into its point index and class index
            point_idx = torch.div(
                flat_idx, self.num_classes, rounding_mode='floor'
            )
            labels = torch.fmod(flat_idx, self.num_classes)

            # decode offsets into boundaries, scaled by column 3 of each point
            chosen_pts = pts_lvl[point_idx]
            chosen_offsets = offsets_lvl[point_idx]
            scale = chosen_pts[:, 3]
            left = chosen_pts[:, 0] - chosen_offsets[:, 0] * scale
            right = chosen_pts[:, 0] + chosen_offsets[:, 1] * scale

            # drop segments whose length is not above the duration threshold
            long_enough = (right - left) > self.test_duration_thresh

            kept_segs.append(torch.stack((left, right), -1)[long_enough])
            kept_scores.append(scores[long_enough])
            kept_labels.append(labels[long_enough])

        return {
            'segments': torch.cat(kept_segs),
            'scores': torch.cat(kept_scores),
            'labels': torch.cat(kept_labels),
        }

    @torch.no_grad()
    def postprocessing(self, results):
        """Move per-video detections to CPU, optionally run NMS, and convert to seconds.

        Args:
            results: list of per-video dicts, each providing 'video_id', 'fps',
                'duration', 'feat_stride', 'feat_num_frames', 'segments',
                'scores' and 'labels' (segments on the feature grid).

        Returns:
            list of dicts with keys 'video_id', 'segments', 'scores' and
            'labels', where non-empty segments are converted to seconds and
            truncated to [0, duration].
        """
        processed_results = []
        for results_per_vid in results:
            # unpack the meta info
            vidx = results_per_vid['video_id']
            fps = results_per_vid['fps']
            vlen = results_per_vid['duration']
            stride = results_per_vid['feat_stride']
            nframes = results_per_vid['feat_num_frames']
            # 1: unpack the results and move to CPU
            segs = results_per_vid['segments'].detach().cpu()
            scores = results_per_vid['scores'].detach().cpu()
            labels = results_per_vid['labels'].detach().cpu()
            if self.test_nms_method != 'none':
                # 2: batched nms (only implemented on CPU)
                segs, scores, labels = batched_nms(
                    segs, scores, labels,
                    self.test_iou_threshold,
                    self.test_min_score,
                    self.test_max_seg_num,
                    use_soft_nms = (self.test_nms_method == 'soft'),
                    multiclass = self.test_multiclass_nms,
                    sigma = self.test_nms_sigma,
                    voting_thresh = self.test_voting_thresh
                )
            # 3: convert from feature grids to seconds
            if segs.shape[0] > 0:
                # grid units -> frames (stride per step, plus half of
                # feat_num_frames — presumably centring each feature on its
                # frame window; TODO confirm), then frames -> seconds
                segs = (segs * stride + 0.5 * nframes) / fps
                # truncate all boundaries within [0, duration]; clamp yields
                # +0.0 for negative boundaries (multiplying by 0.0 gave -0.0)
                # and leaves NaN untouched, like the previous masked updates
                segs = segs.clamp(min=0.0, max=float(vlen))

            # 4: repack the results
            processed_results.append(
                {'video_id' : vidx,
                 'segments' : segs,
                 'scores'   : scores,
                 'labels'   : labels}
            )

        return processed_results