import torch
from torch import nn
from torch.nn import functional as F

from .models import register_generator


class BufferList(nn.Module):#构建模型缓冲区，存储每一层对应的时间点序列以及回归范围信息
    """
    Similar to nn.ParameterList, but for buffers

    Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/modeling/anchor_generator.py
    """

    def __init__(self, buffers):
        super().__init__()
        for i, buffer in enumerate(buffers):
            # Use non-persistent buffer so the values are not saved in checkpoint
            self.register_buffer(str(i), buffer, persistent=False)

    def __len__(self):
        return len(self._buffers)

    def __iter__(self):
        return iter(self._buffers.values())

@register_generator('point')
class PointGenerator(nn.Module):
    """
        A generator for temporal "points"
        
        max_seq_len can be much larger than the actual seq length
    """
    def __init__(
        self,
        max_seq_len,        # max sequence length that the generator will buffer
        fpn_strides,        # strides of fpn levels
        regression_range,   # regression range (on feature grids)
        use_offset=False    # if to align the points at grid centers
    ):
        super().__init__()
        # sanity check, # fpn levels and length divisible
        fpn_levels = len(fpn_strides)#fpn_levels 被设置为 fpn_strides 的长度，以确定生成的点的数量
        assert len(regression_range) == fpn_levels

        # save params
        self.max_seq_len = max_seq_len
        self.fpn_levels = fpn_levels
        self.fpn_strides = fpn_strides
        self.regression_range = regression_range
        self.use_offset = use_offset

        # generate all points and buffer the list 生成所有的位置点，并将其缓存
        self.buffer_points = self._generate_points()

    def _generate_points(self):
        points_list = []#创建了一个空列表 points_list，用于存储每个FPN级别生成的位置点
        # loop over all points at each pyramid level
        for l, stride in enumerate(self.fpn_strides):#遍历每个FPN级别的步长 (stride) 和回归范围 (reg_range
            reg_range = torch.as_tensor(
                self.regression_range[l], dtype=torch.float)#使用 torch.arange 函数生成一个从 0 开始，以步长 stride 递增，直到 self.max_seq_len 结束的序列，表示该级别上的时间点
            fpn_stride = torch.as_tensor(stride, dtype=torch.float)
            points = torch.arange(0, self.max_seq_len, stride)[:, None]#表示该特征金字塔上的时间点序列，如[0-13824]
            # add offset if necessary (not in our current model)
            if self.use_offset:#如果需要对位置点进行偏移，即 self.use_offset 为 True，则将每个位置点偏移一半的步长。
                points += 0.5 * stride
            # pad the time stamp with additional regression range / stride 扩展回归范围和步长，以便与每个位置点对应。这样，每个位置点的回归范围和步长信息也就保存在了相应的张量中。
            reg_range = reg_range[None].repeat(points.shape[0], 1)
            fpn_stride = fpn_stride[None].repeat(points.shape[0], 1)
            # size: T x 4 (ts, reg_range, stride)
            points_list.append(torch.cat((points, reg_range, fpn_stride), dim=1))#将每个级别的位置点信息添加到 points_list 中。

        return BufferList(points_list)

    def forward(self, feats):
        # feats will be a list of torch tensors
        assert len(feats) == self.fpn_levels
        pts_list = []
        feat_lens = [feat.shape[-1] for feat in feats]
        for feat_len, buffer_pts in zip(feat_lens, self.buffer_points):
            assert feat_len <= buffer_pts.shape[0], "Reached max buffer length for point generator"
            pts = buffer_pts[:feat_len, :]
            pts_list.append(pts)
        return pts_list