#---------------------------------------------------------------------------------------#
# LaneSegNet: Map Learning with Lane Segment Perception for Autonomous Driving          #
# Source code: https://github.com/OpenDriveLab/LaneSegNet                               #
# Copyright (c) OpenDriveLab. All rights reserved.                                      #
#---------------------------------------------------------------------------------------#

import copy
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import mmcv
from mmcv.cnn import Linear, build_activation_layer
from mmcv.runner import auto_fp16, force_fp32
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
from mmcv.runner import auto_fp16
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet.models.dense_heads import AnchorFreeHead
import matplotlib.pyplot as plt
from projects.plugin.models.utils.memory_buffer import StreamTensorMemory, StreamListMemory
from projects.plugin.models.utils.query_update import LaneSegMotionMLP
from matplotlib.path import Path
from projects.plugin.models.utils.utils import gen_sineembed_for_position, SinePositionalEncoding, gen_3D_sineembed_for_position
from projects.plugin.models.utils.query_denoising_new import LanednQueryGenerator
from projects.plugin.models.utils.dn_memory_buffer import DNStreamTensorMemory
from projects.lanesegnet.core.visualizer.lane_vis import draw_annotation_bev
from mmcv.cnn import xavier_init
from projects.lanesegnet.utils.builder import build_wm_bev_constructor
@HEADS.register_module()
class StreamLaneSegHead(AnchorFreeHead):

    def __init__(self,
                 num_classes,
                 in_channels,
                 dn_iter=0,
                 dn_cls_num=3,
                 tolerant_noise=0.2,
                 num_points = 10,
                 noise_decay_scale=[0.7, 0.7, 0.7],
                 stream_dn=False,
                 chamfer_thresh=None,
                 roi_size=(60, 30),
                 num_query=200,
                 with_box_refine=False,
                 with_shared_param=None,
                 transformer=None,
                 wm_bev_constructor=None,
                 bbox_coder=None,
                 num_reg_fcs=2,
                 code_weights=None,
                 bev_h=30,
                 bev_w=30,
                 num_traj_modal=1,
                 pc_range=None,
                 pts_dim=3,
                 sync_cls_avg_factor=False,
                 num_lane_type_classes=3,
                 streaming_cfg=dict(),
                 dn_cfg=dict(),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_bbox=dict(type='L1Loss', loss_weight=5.0),
                 loss_mask=dict(type='CrossEntropyLoss', loss_weight=3.0),
                 loss_dice=dict(type='DiceLoss', loss_weight=3.0),
                 loss_lane_type=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_dn_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.5),
                 loss_dn_bbox=dict(type='L1Loss', loss_weight=0.025),
                 loss_dn_lane_type=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.1),
                 loss_dn_mask=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=3.0),
                 loss_dn_dice=dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=3.0),
                 loss_plan_reg = dict(type='L1Loss', loss_weight=1.0),
                 loss_reg=dict(
                     type='LinesL1Loss',
                     loss_weight=50.0,
                     beta=0.01,
                 ),
                 train_cfg=dict(
                     assigner=dict(
                         type='HungarianAssigner',
                         cls_cost=dict(type='ClassificationCost', weight=1.),
                         reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                         iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0)
                     )),
                 test_cfg=dict(max_per_img=100),
                 init_cfg=None,
                 pred_mask=False,
                 **kwargs):
        """Streaming lane-segment head with a query world-model update.

        Selected args (see the signature for defaults):
            num_classes (int): number of lane-segment classes.
            in_channels (int): input feature channels (stored as-is).
            num_points (int): points per polyline used by the reshapes in
                this head.
            roi_size (tuple): (x, y) extent used to (de)normalise points; the
                normalisation origin is (-x/2, -y/2, pc_range[2]).
            num_query (int): number of learnable queries.
            with_shared_param (bool | None): share branch modules across
                decoder layers; defaults to ``not with_box_refine``.
            transformer (dict): transformer config; required (``act_cfg`` is
                read from it and it is built here).
            pc_range (list): indexed as [x_min, y_min, z_min, x_max, y_max, ...].
            pts_dim (int): 2 or 3.
            streaming_cfg (dict): falsy disables streaming. Otherwise
                ``'streaming'`` is read; when truthy, ``'batch_size'`` and
                ``'topk'`` are required and ``'trans_loss_weight'`` is
                optional (default 0.0).
            wm_bev_constructor (dict | None): built only when streaming.
            loss_* (dict): loss configs built with ``build_loss``. When
                ``train_cfg`` is given, the ``loss_cls`` / ``loss_bbox``
                weights must equal the assigner's ``cls_cost`` /
                ``reg_cost`` weights.
            kwargs: may carry ``code_size`` (default ``pts_dim * 30``).

        The dn_* / loss_dn_* / loss_plan_reg / num_traj_modal arguments are
        accepted for config compatibility; the loss_dn_* and loss_plan_reg
        configs are not built here.
        """
        # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
        # since it brings inconvenience when the initialization of
        # `AnchorFreeHead` is called.
        super(AnchorFreeHead, self).__init__(init_cfg)
        self.bg_cls_weight = 0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        if train_cfg:
            assert 'assigner' in train_cfg, 'assigner should be provided '\
                'when train_cfg is set.'
            assigner = train_cfg['assigner']
            assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
                'The classification weight for loss and matcher should be' \
                'exactly the same.'
            assert loss_bbox['loss_weight'] == assigner['reg_cost'][
                'weight'], 'The regression L1 weight for loss and matcher ' \
                'should be exactly the same.'

            self.assigner = build_assigner(assigner)
            # DETR sampling=False, so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.num_query = num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_lane_type = build_loss(loss_lane_type)
        self.loss_mask = build_loss(loss_mask)
        self.loss_dice = build_loss(loss_dice)

        self.loss_plan_rec = nn.MSELoss(reduction='none')

        self.loss_reg = build_loss(loss_reg)
        self.pred_mask = pred_mask
        self.loss_mask_type = loss_mask['type']

        if self.loss_cls.use_sigmoid:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        # Read the flag from the built loss module (as for loss_cls above):
        # `loss_lane_type` may be a plain dict, which has no attribute access.
        if self.loss_lane_type.use_sigmoid:
            self.cls_lane_type_out_channels = num_lane_type_classes
        else:
            self.cls_lane_type_out_channels = num_lane_type_classes + 1

        self.act_cfg = transformer.get('act_cfg',
                                       dict(type='ReLU', inplace=True))
        self.activate = build_activation_layer(self.act_cfg)
        self.transformer = build_transformer(transformer)
        self.embed_dims = self.transformer.embed_dims

        self.bev_h = bev_h
        self.bev_w = bev_w

        assert pts_dim in (2, 3)
        self.pts_dim = pts_dim

        self.with_box_refine = with_box_refine
        if with_shared_param is not None:
            self.with_shared_param = with_shared_param
        else:
            self.with_shared_param = not self.with_box_refine
        self.as_two_stage = False

        if 'code_size' in kwargs:
            self.code_size = kwargs['code_size']
        else:
            self.code_size = pts_dim * 30
        if code_weights is not None:
            self.code_weights = code_weights
        else:
            self.code_weights = [1.0, ] * self.code_size
        self.code_weights = nn.Parameter(torch.tensor(
            self.code_weights, requires_grad=False), requires_grad=False)
        self.gt_c_save = self.code_size

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = pc_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]
        self.num_reg_fcs = num_reg_fcs
        self.num_lane_type_classes = num_lane_type_classes

        # --- streaming / denoising settings ---
        self.iter = 0
        self.dn_iter = dn_iter
        self.stream_dn = stream_dn
        self.tolerant_noise = tolerant_noise
        self.noise_decay_scale = noise_decay_scale
        self.chamfer_thresh = chamfer_thresh
        self.dn_cfg = dn_cfg
        self.num_queries = num_query
        self.dn_cls_num = dn_cls_num
        self.num_points = num_points
        if streaming_cfg:
            self.streaming_query = streaming_cfg['streaming']
        else:
            self.streaming_query = False

        if self.streaming_query:
            self.batch_size = streaming_cfg['batch_size']
            self.topk_query = streaming_cfg['topk']
            self.trans_loss_weight = streaming_cfg.get('trans_loss_weight', 0.0)
            self.query_memory = StreamTensorMemory(
                self.batch_size,
            )

            self.reference_points_memory = StreamTensorMemory(
                self.batch_size,
            )
            self.lane_ref_points_memory = StreamTensorMemory(
                self.batch_size,
            )
            # Lane ids of the previous frame: used to match the same instance
            # across consecutive frames as supervision for the transfer loss.
            self.lane_id_memory = StreamListMemory(
                self.batch_size,
            )

            c_dim = 12

            self.query_update = LaneSegMotionMLP(c_dim=c_dim, f_dim=self.embed_dims, identity=True)
            self.target_memory = StreamTensorMemory(self.batch_size)
            self.target_label_memory = StreamTensorMemory(self.batch_size)
            self.target_left_type_memory = StreamTensorMemory(self.batch_size)
            self.target_right_type_memory = StreamTensorMemory(self.batch_size)
            self.target_mask_memory = StreamTensorMemory(self.batch_size)

        self.register_buffer('roi_size', torch.tensor(roi_size, dtype=torch.float32))
        origin = (-roi_size[0]/2, -roi_size[1]/2, self.pc_range[2])
        self.register_buffer('origin', torch.tensor(origin, dtype=torch.float32))
        self.map_size = [-50, -25, 50, 25]

        # --- world model ---
        if self.streaming_query and wm_bev_constructor is not None:
            self.wm_bev_constructor = build_wm_bev_constructor(wm_bev_constructor)

        self._init_layers()

    def _init_layers(self):
        """Build the per-decoder-layer prediction branches and auxiliary modules.

        Branches (``num_pred`` entries, one per decoder layer):
          * ``cls_branches``: lane-segment classification
            (``cls_out_channels`` outputs).
          * ``reg_branches``: ``code_size // 3`` outputs (used as centre-line
            points in ``propagate_worldmodel_pre``).
          * ``reg_branches_offset``: ``code_size // 3`` outputs (added to /
            subtracted from the centre line for the left/right boundaries in
            ``propagate_worldmodel_pre``).
          * ``cls_left_type_branches`` / ``cls_right_type_branches``:
            ``cls_lane_type_out_channels`` outputs each.
          * ``mask_embed``: 3-layer MLP producing per-query mask embeddings
            consumed by ``_forward_mask_head``.
        With ``self.with_shared_param`` every layer references the same module
        instance; otherwise each layer gets an independent deep copy.

        Also builds the query embedding, the reference-point projection and the
        world-model modules that refine propagated queries.

        NOTE(review): parameter initialisers consume the global RNG at module
        construction, so reordering the statements below changes initial
        weights for a fixed seed; keep the order stable.
        """
        cls_branch = []
        for _ in range(self.num_reg_fcs):
            cls_branch.append(Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(Linear(self.embed_dims, self.cls_out_channels))
        fc_cls = nn.Sequential(*cls_branch)

        reg_branch = []
        for _ in range(self.num_reg_fcs):
            reg_branch.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU())
        reg_branch.append(Linear(self.embed_dims, self.code_size // 3))
        reg_branch = nn.Sequential(*reg_branch)

        reg_branch_offset = []
        for _ in range(self.num_reg_fcs):
            reg_branch_offset.append(Linear(self.embed_dims, self.embed_dims))
            reg_branch_offset.append(nn.ReLU())
        reg_branch_offset.append(Linear(self.embed_dims, self.code_size // 3))
        reg_branch_offset = nn.Sequential(*reg_branch_offset)

        def _get_clones(module, N):
            # N independent deep copies (no parameter sharing).
            return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

        # 2 * embed_dims per query; split into two embed_dims halves downstream
        # (original note: one half content, one half positional embedding).
        self.query_embedding = nn.Embedding(self.num_query, self.embed_dims * 2)

        cls_left_type_branch = []
        for _ in range(self.num_reg_fcs):
            cls_left_type_branch.append(Linear(self.embed_dims, self.embed_dims))
            cls_left_type_branch.append(nn.LayerNorm(self.embed_dims))
            cls_left_type_branch.append(nn.ReLU(inplace=True))
        cls_left_type_branch.append(Linear(self.embed_dims, self.cls_lane_type_out_channels))
        fc_cls_left_type = nn.Sequential(*cls_left_type_branch)

        cls_right_type_branch = []
        for _ in range(self.num_reg_fcs):
            cls_right_type_branch.append(Linear(self.embed_dims, self.embed_dims))
            cls_right_type_branch.append(nn.LayerNorm(self.embed_dims))
            cls_right_type_branch.append(nn.ReLU(inplace=True))
        cls_right_type_branch.append(Linear(self.embed_dims, self.cls_lane_type_out_channels))
        fc_cls_right_type = nn.Sequential(*cls_right_type_branch)

        mask_branch = nn.Sequential(
            nn.Linear(self.embed_dims, self.embed_dims), nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims), nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims))

        # One branch per decoder layer (+1 only in two-stage mode, which
        # __init__ sets to False).
        num_pred = (self.transformer.decoder.num_layers + 1) if \
            self.as_two_stage else self.transformer.decoder.num_layers

        if not self.with_shared_param:
            self.cls_branches = _get_clones(fc_cls, num_pred)
            self.reg_branches = _get_clones(reg_branch, num_pred)
            self.reg_branches_offset = _get_clones(reg_branch_offset, num_pred)
            self.cls_left_type_branches = _get_clones(fc_cls_left_type, num_pred)
            self.cls_right_type_branches = _get_clones(fc_cls_right_type, num_pred)
            self.mask_embed = _get_clones(mask_branch, num_pred)
        else:
            # Same module object repeated -> parameters shared across layers.
            self.cls_branches = nn.ModuleList(
                [fc_cls for _ in range(num_pred)])
            self.reg_branches = nn.ModuleList(
                [reg_branch for _ in range(num_pred)])
            self.reg_branches_offset = nn.ModuleList(
                [reg_branch_offset for _ in range(num_pred)])
            self.cls_left_type_branches = nn.ModuleList(
                [fc_cls_left_type for _ in range(num_pred)])
            self.cls_right_type_branches = nn.ModuleList(
                [fc_cls_right_type for _ in range(num_pred)])
            self.mask_embed = nn.ModuleList(
                [mask_branch for _ in range(num_pred)])

        # Projects the positional half of each query to a pts_dim point.
        self.reference_points_embed = nn.Linear(self.embed_dims, self.pts_dim)
        xavier_init(self.reference_points_embed, distribution='uniform', bias=0.)

        if self.streaming_query:
            if isinstance(self.query_update, LaneSegMotionMLP):
                self.query_update.init_weights()
            # `query_alpha` is not created in the visible code; this branch
            # only runs if it is defined elsewhere.
            if hasattr(self, 'query_alpha'):
                for m in self.query_alpha:
                    for param in m.parameters():
                        if param.dim() > 1:
                            nn.init.zeros_(param)


        # World-model decoder (2 batch-first layers) that refines the
        # pose-conditioned propagated queries.
        wm_decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.embed_dims,
            nhead=8,
            dim_feedforward=1024,
            dropout=0.1,
            batch_first=True,
        )
        self._query_wm_decoder = nn.TransformerDecoder(wm_decoder_layer, 2)
        # Input is features concatenated with a 12-d (6*2) pose vector;
        # propagate_worldmodel_pre feeds the flattened 3x4 prev->curr
        # transform to the query encoder.
        self.action_aware_bev_encoder = nn.Sequential(
            nn.Linear(self.embed_dims + 6*2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims)
        )

        self.action_aware_query_encoder = nn.Sequential(
            nn.Linear(self.embed_dims + 6*2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims)
        )

        # 2x2 average pooling with stride 2.
        self.down_sample = nn.AvgPool2d(kernel_size=2, stride=2)
        
    def _forward_mask_head(self, output, bev_feats, lvl):
        """Compute per-query BEV mask scores for decoder level ``lvl``.

        Args:
            output (Tensor): query features contracted as ``bqc`` in the
                einsum, i.e. (bs, num_query, embed_dims).
            bev_feats (Tensor): BEV features whose elements can be viewed as
                (bs, bev_h, bev_w, embed_dims).
            lvl (int): index into ``self.mask_embed``.

        Returns:
            Tensor: (bs, num_query, bev_h, bev_w) raw (un-activated) scores,
            the dot product of each query's mask embedding with every BEV cell.
        """
        batch = bev_feats.shape[0]
        # Channels-last view -> channels-first map for the contraction.
        bev_map = bev_feats.view(
            [batch, self.bev_h, self.bev_w, self.embed_dims]
        ).permute(0, 3, 1, 2).contiguous()
        query_mask_embed = self.mask_embed[lvl](output)
        return torch.einsum("bqc,bchw->bqhw", query_mask_embed, bev_map)


    def propagate_worldmodel_pre(self, query_embedding, bev_feats, img_metas, same_id_gt_lane = None, return_loss=True):
        """Propagate the previous frame's top-k queries/reference points into the current frame.

        For every sample that is not the first frame of its sequence:
          * the memorised queries are concatenated with the flattened 3x4
            prev->curr ego transform, encoded by
            ``action_aware_query_encoder`` and refined by ``_query_wm_decoder``;
          * the memorised last-layer centre-line and lane reference points are
            warped into the current frame, renormalised and clipped to [0, 1];
          * when ``return_loss`` is True, a transfer loss compares the
            last-layer predictions of the refined queries with the previous
            frame's targets warped into the current frame (optionally replaced
            by tracked GT from ``same_id_gt_lane``).
        First-frame samples are zero-padded in every stacked output.

        Args:
            query_embedding (Tensor): (bs, num_query, 2 * embed_dims) queries;
                the first ``embed_dims`` channels feed
                ``reference_points_embed``.
            bev_feats (Tensor): current BEV features, indexed per sample and
                passed to ``_forward_mask_head``.
            img_metas (list[dict]): per-sample metas with
                ``lidar2global_translation`` / ``lidar2global_rotation``.
            same_id_gt_lane (list | None): per-sample tracked GT lanes. An
                entry may be a Tensor (all-zero rows are ignored), ``None`` or
                ``[None]``. A ``None`` argument or a list shorter than the
                batch means "no tracked lanes" for the missing samples.
            return_loss (bool): whether to compute the transfer loss.

        Returns:
            tuple: ``(query_embedding, prop_query_embedding,
            init_reference_points, prop_ref_pts, prop_centerline_ref_pts,
            memory_query_embedding, is_first_frame_list)``, followed by
            ``(trans_loss, prev2curr_matrix_list)`` when ``return_loss`` is
            True. ``prev2curr_matrix_list`` only holds matrices for
            non-first-frame samples.
        """
        bs = query_embedding.shape[0]
        num_pts = self.num_points
        # Boundary reference points stored per lane in lane_ref_points_memory
        # (fixed by the shape assertions below).
        num_lane_ref_pts = 8
        propagated_query_list = []
        prop_reference_points_list = []
        prop_centerline_reference_points_list = []
        prev2curr_matrix_list = []

        # The memory `get` calls keep their original order; pose_memory ends up
        # holding the metas returned by the last call.
        tmp = self.query_memory.get(img_metas)
        query_memory, pose_memory = tmp['tensor'], tmp['img_metas']

        tmp = self.reference_points_memory.get(img_metas)
        ref_pts_memory, pose_memory = tmp['tensor'], tmp['img_metas']

        tmp = self.lane_ref_points_memory.get(img_metas)
        lane_ref_pts_memory, pose_memory = tmp['tensor'], tmp['img_metas']

        if return_loss:
            target_memory = self.target_memory.get(img_metas)['tensor']
            target_label_memory = self.target_label_memory.get(img_metas)['tensor']
            target_left_type_memory = self.target_left_type_memory.get(img_metas)['tensor']
            target_right_type_memory = self.target_right_type_memory.get(img_metas)['tensor']
            trans_loss = query_embedding.new_zeros((1,))
            trans_cls_loss = query_embedding.new_zeros((1,))
            trans_left_type = query_embedding.new_zeros((1,))
            trans_right_type = query_embedding.new_zeros((1,))
            trans_loss_dice = query_embedding.new_zeros((1,))
            trans_loss_mask = query_embedding.new_zeros((1,))

        is_first_frame_list = tmp['is_first_frame']

        for i in range(bs):
            if is_first_frame_list[i]:
                # Nothing to propagate: zero-pad every per-sample output so the
                # torch.stack calls below always receive `bs` entries.
                propagated_query_list.append(
                    query_embedding.new_zeros((self.topk_query, self.embed_dims)))
                prop_reference_points_list.append(
                    self.roi_size.new_zeros((self.topk_query, num_lane_ref_pts, 3)))
                prop_centerline_reference_points_list.append(
                    self.roi_size.new_zeros((self.topk_query, num_pts, 3)))
                continue

            # --- prev -> curr ego transform (float64 for precise coordinates) ---
            prev_e2g_trans = self.roi_size.new_tensor(pose_memory[i]['lidar2global_translation'], dtype=torch.float64)
            prev_e2g_rot = self.roi_size.new_tensor(pose_memory[i]['lidar2global_rotation'], dtype=torch.float64)
            curr_e2g_trans = self.roi_size.new_tensor(img_metas[i]['lidar2global_translation'], dtype=torch.float64)
            curr_e2g_rot = self.roi_size.new_tensor(img_metas[i]['lidar2global_rotation'], dtype=torch.float64)

            prev_e2g_matrix = torch.eye(4, dtype=torch.float64).to(query_embedding.device)
            prev_e2g_matrix[:3, :3] = prev_e2g_rot
            prev_e2g_matrix[:3, 3] = prev_e2g_trans

            curr_g2e_matrix = torch.eye(4, dtype=torch.float64).to(query_embedding.device)
            curr_g2e_matrix[:3, :3] = curr_e2g_rot.T
            curr_g2e_matrix[:3, 3] = -(curr_e2g_rot.T @ curr_e2g_trans)

            prev2curr_matrix = curr_g2e_matrix @ prev_e2g_matrix
            # The flattened top 3x4 block (12 values) conditions the queries.
            pos_encoding = prev2curr_matrix.float()[:3].view(-1)
            prev2curr_matrix_list.append(prev2curr_matrix)

            # --- world-model update of the memorised queries ---
            prop_q = query_memory[i].unsqueeze(0)
            _, num_tokens, _ = prop_q.shape
            pos_encoding = pos_encoding.view(1, -1).repeat(1, num_tokens, 1)
            pre_query_feat_with_trajs = torch.cat([prop_q, pos_encoding], dim=-1)
            action_aware_query_latent = self.action_aware_query_encoder(pre_query_feat_with_trajs)
            query_memory_updated = self._query_wm_decoder(action_aware_query_latent, action_aware_query_latent)
            query_memory_updated = query_memory_updated.squeeze(0)
            propagated_query_list.append(query_memory_updated.clone())

            pred = self.reg_branches[-1](query_memory_updated)
            pred_offset = self.reg_branches_offset[-1](query_memory_updated)
            pred_offset = pred_offset.view(self.topk_query, num_pts, 3)
            assert list(pred.shape) == [self.topk_query, 3 * num_pts]

            # --- warp the last-layer centre-line reference points ---
            prev_ref_pts = ref_pts_memory[i][-1, :, :, :]
            assert list(prev_ref_pts.shape) == [self.topk_query, num_pts, 3]
            denormed_prev_ref_pts = prev_ref_pts * self.roi_size + self.origin
            denormed_prev_ref_pts = torch.cat([
                denormed_prev_ref_pts,
                denormed_prev_ref_pts.new_ones((self.topk_query, num_pts, 1)),  # homogeneous coordinate
            ], dim=-1)  # (topk, num_pts, 4)
            curr_centerline_ref_pts = torch.einsum(
                'lk,ijk->ijl', prev2curr_matrix, denormed_prev_ref_pts.double()).float()
            curr_centerline_ref_pts = (curr_centerline_ref_pts[..., :3] - self.origin) / self.roi_size
            curr_centerline_ref_pts = torch.clip(curr_centerline_ref_pts, min=0., max=1.)
            prop_centerline_reference_points_list.append(curr_centerline_ref_pts)

            # --- denormalised predictions: centre line + left/right boundaries ---
            pred = pred.view(self.topk_query, num_pts, 3)
            pred = pred + inverse_sigmoid(curr_centerline_ref_pts)
            pred = pred.sigmoid().clone()
            denormed_pred = (pred * self.roi_size + self.origin).contiguous()
            denormed_pred_left = denormed_pred + pred_offset
            denormed_pred_right = denormed_pred - pred_offset
            denormed_lane_pred = torch.cat([denormed_pred, denormed_pred_left, denormed_pred_right], dim=1)
            assert list(denormed_lane_pred.shape) == [self.topk_query, num_pts * 3, 3]

            # --- warp the last-layer lane (boundary) reference points ---
            denormed_lane_ref_pts = lane_ref_pts_memory[i][-1, :, :, :] * self.roi_size + self.origin
            assert list(denormed_lane_ref_pts.shape) == [self.topk_query, num_lane_ref_pts, 3]
            denormed_lane_ref_pts = torch.cat([
                denormed_lane_ref_pts,
                denormed_lane_ref_pts.new_ones((self.topk_query, num_lane_ref_pts, 1)),  # homogeneous coordinate
            ], dim=-1)
            assert list(denormed_lane_ref_pts.shape) == [self.topk_query, num_lane_ref_pts, 4]
            curr_lane_ref_pts = torch.einsum(
                'lk,ijk->ijl', prev2curr_matrix, denormed_lane_ref_pts.double()).float()
            normed_lane_ref_pts = (curr_lane_ref_pts[..., :3] - self.origin) / self.roi_size
            normed_lane_ref_pts = torch.clip(normed_lane_ref_pts, min=0., max=1.)
            prop_reference_points_list.append(normed_lane_ref_pts)

            if not return_loss:
                continue

            # --- transfer loss against the previous frame's targets ---
            targets = target_memory[i]  # previous-frame GT lanes, already denormalised
            target_labels = target_label_memory[i]
            target_left_types = target_left_type_memory[i]
            target_right_types = target_right_type_memory[i]

            if same_id_gt_lane is not None and i < len(same_id_gt_lane):
                same_id_gt_lane_single = same_id_gt_lane[i]
            else:
                # No tracking information supplied for this sample.
                same_id_gt_lane_single = [None]

            weights = targets.new_ones((self.topk_query, 9 * num_pts))
            label_weights = torch.ones_like(target_labels).to(torch.float32)
            bg_idx = torch.all(targets.view(self.topk_query, -1) == 0.0, dim=1)
            num_pos = self.topk_query - bg_idx.sum()
            # Avoid dividing by zero when every memorised target is padding
            # (all weights are then zero, so the loss is simply 0).
            avg_factor = num_pos.clamp(min=1)
            weights[bg_idx, :] = 0.0
            pos_idx = torch.nonzero(~bg_idx, as_tuple=True)[0]
            label_weights[bg_idx] = 0.0

            homog_targets = torch.cat([
                targets,
                targets.new_ones((self.topk_query, 3 * num_pts, 1)),  # homogeneous coordinate
            ], dim=-1)
            assert list(homog_targets.shape) == [self.topk_query, 3 * num_pts, 4]
            curr_targets = torch.einsum('lk,ijk->ijl', prev2curr_matrix.float(), homog_targets)
            curr_frame_targets = curr_targets[..., :3]

            # Prefer tracked current-frame GT over the warped target when available.
            if isinstance(same_id_gt_lane_single, torch.Tensor):
                for idx, pos_id in enumerate(pos_idx):
                    if torch.all(same_id_gt_lane_single[idx] == 0.0):
                        continue
                    curr_frame_targets[pos_id] = same_id_gt_lane_single.reshape(
                        len(same_id_gt_lane_single), 3 * num_pts, 3)[idx]
            elif same_id_gt_lane_single is None or same_id_gt_lane_single == [None]:
                pass
            else:
                print('same_id_gt_lane_single', same_id_gt_lane_single)
                raise RuntimeError(f"unexpected value in same_id_gt_lane_single")

            trans_loss += self.loss_reg(
                denormed_lane_pred.reshape(-1, 9 * num_pts),
                curr_frame_targets.reshape(-1, 9 * num_pts),
                weights, avg_factor=avg_factor)

            output_class = self.cls_branches[-1](query_memory_updated)
            output_left_type = self.cls_left_type_branches[-1](query_memory_updated)
            output_right_type = self.cls_right_type_branches[-1](query_memory_updated)

            trans_cls_loss += self.loss_cls(output_class, target_labels, label_weights, avg_factor=avg_factor)
            trans_left_type += self.loss_lane_type(output_left_type, target_left_types, label_weights, avg_factor=avg_factor)
            trans_right_type += self.loss_lane_type(output_right_type, target_right_types, label_weights, avg_factor=avg_factor)

            pos_target_lanes = curr_frame_targets[~bg_idx].detach().cpu().numpy()
            if len(pos_target_lanes) == 0:
                continue

            trans_mask = self.generate_lanesegment_instance_mask(pos_target_lanes, device=output_class.device)
            outputs_mask = self._forward_mask_head(
                query_memory_updated[~bg_idx].unsqueeze(0), bev_feats[i].unsqueeze(0), -1)
            outputs_mask = outputs_mask.squeeze(0)
            trans_loss_dice += self.loss_dice(outputs_mask, trans_mask, avg_factor=avg_factor)

            # Read the spatial size BEFORE flattening: after reshape(-1, 1),
            # shape[-2:] would be (num_pos * h * w, 1) and inflate avg_factor
            # by an extra factor of num_pos.
            h, w = outputs_mask.shape[-2:]
            outputs_mask = outputs_mask.reshape(-1, 1)
            trans_mask = trans_mask.reshape(-1)

            if self.loss_mask_type == 'FocalLoss':
                trans_mask = (1 - trans_mask).long()
            if self.loss_mask_type == 'CrossEntropyLoss':
                trans_mask = trans_mask.reshape(outputs_mask.shape).bool()

            trans_loss_mask += self.loss_mask(
                outputs_mask, trans_mask, avg_factor=avg_factor * h * w
            )

        prop_query_embedding = torch.stack(propagated_query_list)  # (bs, topk, embed_dims)
        prop_ref_pts = torch.stack(prop_reference_points_list)  # (bs, topk, 8, 3)
        prop_centerline_ref_pts = torch.stack(prop_centerline_reference_points_list)  # (bs, topk, num_pts, 3)
        assert list(prop_query_embedding.shape) == [bs, self.topk_query, self.embed_dims]
        assert list(prop_ref_pts.shape) == [bs, self.topk_query, num_lane_ref_pts, 3]
        assert list(prop_centerline_ref_pts.shape) == [bs, self.topk_query, num_pts, 3]

        # Initial reference points come from the first (positional) half of the queries.
        query_pos, _ = torch.split(query_embedding, self.embed_dims, dim=2)
        init_reference_points = self.reference_points_embed(query_pos)  # (bs, num_query, pts_dim)
        init_reference_points = init_reference_points.repeat(1, 1, num_pts)
        init_reference_points = init_reference_points.sigmoid()
        bs, num_query, _ = init_reference_points.shape
        init_reference_points = init_reference_points.view(bs, num_query, num_pts, self.pts_dim)
        memory_query_embedding = None

        if return_loss:
            trans_loss = self.trans_loss_weight * (trans_loss + trans_cls_loss + trans_left_type + trans_right_type + trans_loss_dice + trans_loss_mask)
            return query_embedding, prop_query_embedding, init_reference_points, prop_ref_pts, prop_centerline_ref_pts, memory_query_embedding, is_first_frame_list, trans_loss, prev2curr_matrix_list
        else:
            return query_embedding, prop_query_embedding, init_reference_points, prop_ref_pts, prop_centerline_ref_pts, memory_query_embedding, is_first_frame_list

    def generate_lanesegment_instance_mask(self, results, device):
        """Rasterise lane segments into binary BEV instance masks.

        ``results`` is reshaped to ``(num_lanes, 3, self.num_points, 3)``.
        Polyline 1 is the left boundary and polyline 2 the right boundary;
        polyline 0 is not used here. For each lane, the polygon formed by the
        left boundary followed by the reversed right boundary is filled on a
        ``(bev_h, bev_w)`` grid. x/y are mapped to pixels using
        ``self.map_size`` ([x_min, y_min, x_max, y_max]) with the ego origin
        at the grid centre.

        Args:
            results (np.ndarray): lane segments (CPU array-like) with
                ``num_lanes * 3 * num_points * 3`` elements.
            device: torch device of the returned tensor.

        Returns:
            Tensor: float32 masks of shape ``(num_lanes, bev_h, bev_w)`` with
            values in {0., 1.}; shape ``(0, bev_h, bev_w)`` for empty input.
        """
        gt_lanes = np.asarray(results).reshape(-1, 3, self.num_points, 3)
        bev_h, bev_w = self.bev_h, self.bev_w
        if len(gt_lanes) == 0:
            return torch.zeros((0, bev_h, bev_w), dtype=torch.float32, device=device)

        origin = np.array([bev_w // 2, bev_h // 2])
        scale = np.array([
            bev_w / (self.map_size[2] - self.map_size[0]),
            bev_h / (self.map_size[3] - self.map_size[1]),
        ])

        inst_masks = []
        for left_line, right_line in zip(gt_lanes[:, 1], gt_lanes[:, 2]):
            # Closed polygon: left boundary forward, right boundary backward,
            # then back to the first left point.
            segment_boundary = np.concatenate((left_line, right_line[::-1], left_line[0:1]), axis=0)
            draw_coor = (segment_boundary[:, :2] * scale + origin).astype(np.int32)
            mask = np.zeros((bev_h, bev_w), dtype=np.uint8)
            inst_masks.append(cv2.fillPoly(mask, [draw_coor], 1))

        # One host->device copy for all masks instead of one per lane.
        return torch.from_numpy(np.stack(inst_masks, 0)).to(device=device, dtype=torch.float32)
    
    @auto_fp16(apply_to=('mlvl_feats'))
    def forward_train(self, mlvl_feats, bev_feats, img_metas,gt_lanes_3d = None, gt_lane_labels_3d = None, gt_instance_masks = None, gt_lane_left_type = None, gt_lane_right_type = None, future_data = None, gt_ego_fut_cmd = None):
        """Training forward pass: decode lane segments for every decoder layer and compute losses.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 5D-tensor with shape
                (B, N, C, H, W).
            bev_feats (Tensor): BEV features passed to the transformer and to
                the mask head. The first decoder layer's mask head only sees
                ``bev_feats[:bs]``; the other layers see all of it.
            img_metas (list[dict]): Per-sample meta info. ``len(img_metas)`` is
                used as the batch size. In streaming mode each entry is read
                for a ``'lane_id'`` key.
            gt_lanes_3d, gt_lane_labels_3d, gt_instance_masks,
            gt_lane_left_type, gt_lane_right_type (list): Per-sample ground
                truth forwarded to ``self.loss``. NOTE: in streaming mode these
                lists are extended IN PLACE with themselves (each is doubled)
                between the first-layer loss and the loss of the remaining
                layers, so the caller's lists are modified.
            future_data, gt_ego_fut_cmd: Not used in this method.

        Returns:
            tuple: In streaming mode
                ``(outs_first_layer, outs_other_layer, lane_losses,
                first_layer_lane_assign_result, other_layer_lane_assign_result)``;
                otherwise ``(outs, lane_losses, lane_assign_result)``.
                Each ``outs*`` dict has the keys ``all_cls_scores``,
                ``all_lanes_preds`` (centerline, left and right lines
                concatenated on the last dim), ``all_mask_preds``,
                ``all_lanes_left_type``, ``all_lanes_right_type`` and
                ``history_states``. The prediction tensors are stacked over
                decoder layers on their leading dim.

        Side effects:
            Increments ``self.iter``.
        """

        bs = len(img_metas)
        dtype = mlvl_feats[0].dtype
        # One copy of the learned query table per sample (split on dim=2 below).

        object_query_embeds = self.query_embedding.weight.to(dtype)[None, ...].repeat(bs, 1, 1)

        if self.streaming_query:
            same_id_gt_lane = []
            tmp = self.lane_id_memory.get(img_metas)
            last_lane_id, pose_memory = tmp['tensor'], tmp['img_metas'] # lane ids from the previous frame (not used further in this method)
            match_current_gt_target = None

            query_embedding, prop_query_embedding, init_reference_points, prop_ref_pts, prop_centerline_ref_pts, memory_query, is_first_frame_list, trans_loss, prev2curr_matrix_list = \
                self.propagate_worldmodel_pre(object_query_embeds, bev_feats, img_metas, same_id_gt_lane, return_loss=True)
        else:
            query_pos, _ = torch.split(object_query_embeds, self.embed_dims, dim=2)


            init_reference_points = self.reference_points_embed(query_pos) # e.g. (2, 200, 256) -> (2, 200, 3)
            # Same initialization as LaneSegNet: every point of a query starts at the same location.
            init_reference_points = init_reference_points.repeat(1, 1, self.num_points) # e.g. (2, 200, 3) -> (2, 200, 30)
            init_reference_points = init_reference_points.sigmoid()
            bs, num_qeury, _ = init_reference_points.shape
            init_reference_points = init_reference_points.view(bs, num_qeury, self.num_points, self.pts_dim)
            prop_query_embedding = None
            prop_ref_pts = None
            prop_centerline_ref_pts = None
            is_first_frame_list = [True for i in range(bs)]
            query_embedding = object_query_embeds


        assert list(init_reference_points.shape) == [bs, self.num_queries, self.num_points, self.pts_dim] # e.g. [1, 200, 10, 3]



        # Denoising queries are disabled here (dn_num = 0 keeps all queries).
        stream_self_attn_mask = None
        stream_dn_query = None
        stream_denoise_refers = None
        dn_num = 0


        outputs = self.transformer(
            mlvl_feats,
            bev_feats,
            object_query_embeds,
            reference_points = init_reference_points,
            prop_query = prop_query_embedding,
            prop_reference_points = prop_ref_pts,
            prop_centerline_reference_points = prop_centerline_ref_pts,
            is_first_frame_list=is_first_frame_list,
            # query_key_padding_mask=query_embedding.new_zeros((bs, self.num_queries), dtype=torch.bool), # mask used in self-attn,
            bev_h=self.bev_h,
            bev_w=self.bev_w,
            reg_branches=(self.reg_branches, self.reg_branches_offset) if self.with_box_refine else None,  # noqa:E501
            cls_branches=self.cls_branches,
            img_metas=img_metas
        )

        # Transformer outputs: per-layer queries, the first layer's shared reference,
        # normalized intermediate centerline references, and intermediate reference
        # points sampled on the left/right lines.
        init_hs, inter_hs, init_reference, init_inter_reference_points, inter_references, init_inter_lane_reference_points, inter_lanepts_references = outputs
        init_hs = init_hs.permute(0, 2, 1, 3)
        inter_hs = inter_hs.permute(0, 2, 1, 3)

        outputs_classes = []
        outputs_coord = []
        outputs_masks = []
        output_left_types = []
        output_right_types = []
        # Decode every decoder layer. Layer 0 uses init_reference / init_hs.
        # In streaming mode, layer 1 uses init_inter_reference_points and layer k >= 2 uses
        # inter_references[k-2]; otherwise layer k >= 1 uses inter_references[k-1].
        for lvl in range(len(self.cls_branches)):
            if lvl == 0:
                reference = init_reference[:, dn_num:,:]
                hs = init_hs[lvl]

                bev_feats_mask_input = bev_feats[:bs,:,:]

            if lvl ==1 and self.streaming_query:
                reference = init_inter_reference_points
                hs = inter_hs[lvl-1]
                bev_feats_mask_input = bev_feats
            elif lvl ==1:
                reference = inter_references[lvl-1]
                hs = inter_hs[lvl-1]
                bev_feats_mask_input = bev_feats
            if lvl !=0 and lvl !=1 and self.streaming_query:
                reference = inter_references[lvl-2]
                hs = inter_hs[lvl-1]
                bev_feats_mask_input = bev_feats
            elif lvl !=0 and lvl !=1:
                reference = inter_references[lvl-1]
                hs = inter_hs[lvl-1]
                bev_feats_mask_input = bev_feats

            # References are in sigmoid space; refine in logit space, then squash back.
            reference = inverse_sigmoid(reference)
            assert reference.shape[-1] == self.pts_dim

            outputs_class = self.cls_branches[lvl](hs[:, dn_num:,:])
            output_left_type = self.cls_left_type_branches[lvl](hs[:, dn_num:,:])
            output_right_type = self.cls_right_type_branches[lvl](hs[:, dn_num:,:])

            tmp = self.reg_branches[lvl](hs[:, dn_num:,:])
            bs, num_query, _ = tmp.shape
            tmp = tmp.view(bs, num_query, -1, self.pts_dim)
            tmp = tmp + reference
            tmp = tmp.sigmoid()

            # Map normalized [0, 1] coordinates back into the pc_range box.
            coord = tmp.clone()
            coord[..., 0] = coord[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0] # e.g. * (51.2 - -51.2) + -51.2
            coord[..., 1] = coord[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
            if self.pts_dim == 3:
                coord[..., 2] = coord[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
            centerline = coord.view(bs, num_query, -1).contiguous()

            # Left/right boundaries are the centerline shifted by +/- a predicted offset.
            offset = self.reg_branches_offset[lvl](hs[:, dn_num:,:])
            left_laneline = centerline + offset
            right_laneline = centerline - offset

            # segmentation head
            outputs_mask = self._forward_mask_head(hs[:, dn_num:,:], bev_feats_mask_input, lvl)

            outputs_classes.append(outputs_class)
            outputs_coord.append(torch.cat([centerline, left_laneline, right_laneline], axis=-1))
            outputs_masks.append(outputs_mask)
            output_left_types.append(output_left_type)
            output_right_types.append(output_right_type)


        if self.streaming_query:
            # The first layer and the remaining layers get separate loss calls.
            other_layer_outputs_classes = torch.stack(outputs_classes[1:])
            other_layer_outputs_coords = torch.stack(outputs_coord[1:])
            other_layer_outputs_masks = torch.stack(outputs_masks[1:])
            other_layer_output_left_types = torch.stack(output_left_types[1:])
            other_layer_output_right_types = torch.stack(output_right_types[1:])

            outs_first_layer = {
                'all_cls_scores': outputs_classes[0].unsqueeze(0),
                'all_lanes_preds': outputs_coord[0].unsqueeze(0),
                'all_mask_preds': outputs_masks[0].unsqueeze(0),
                'all_lanes_left_type': output_left_types[0].unsqueeze(0),
                'all_lanes_right_type': output_right_types[0].unsqueeze(0),
                'history_states': init_hs[:,:, dn_num:,:]
            }

            outs_other_layer = {
                'all_cls_scores': other_layer_outputs_classes,
                'all_lanes_preds': other_layer_outputs_coords,
                'all_mask_preds': other_layer_outputs_masks,
                'all_lanes_left_type': other_layer_output_left_types,
                'all_lanes_right_type': other_layer_output_right_types,
                'history_states': inter_hs[:,:, dn_num:,:]
            }

            first_layer_loss_inputs = [outs_first_layer, gt_lanes_3d, gt_lane_labels_3d, gt_instance_masks, gt_lane_left_type, gt_lane_right_type]


            first_layer_lane_losses, first_layer_lane_assign_result, first_layer_bbox_targets, first_layer_labels_targets, first_layer_labels_left_type_targets, first_layer_labels_right_type_targets, first_layer_mask_targets, first_layer_mask_targets_weights = self.loss(*first_layer_loss_inputs, img_metas=img_metas)



            # NOTE: doubles the caller's gt lists IN PLACE (list.extend with itself),
            # presumably to line up with the larger batch of the non-first layers
            # (per-frame + stream halves) - TODO confirm.
            gt_lanes_3d.extend(gt_lanes_3d)
            gt_lane_labels_3d.extend(gt_lane_labels_3d)
            gt_instance_masks.extend(gt_instance_masks)
            gt_lane_left_type.extend(gt_lane_left_type)
            gt_lane_right_type.extend(gt_lane_right_type)

            # The first-layer losses are passed on as loss_dict.
            other_layer_loss_inputs = [outs_other_layer, gt_lanes_3d, gt_lane_labels_3d, gt_instance_masks, gt_lane_left_type, gt_lane_right_type, first_layer_lane_losses]

            lane_losses, other_layer_lane_assign_result, other_layer_bbox_targets, other_layer_labels_targets, other_layer_labels_left_type_targets, other_layer_labels_right_type_targets, other_layer_mask_targets, other_layer_mask_targets_weights = self.loss(*other_layer_loss_inputs, img_metas=img_metas)

        else:
            outputs_classes = torch.stack(outputs_classes)
            outputs_coords = torch.stack(outputs_coord)
            outputs_masks = torch.stack(outputs_masks)
            output_left_types = torch.stack(output_left_types)
            output_right_types = torch.stack(output_right_types)

            outs = {
                'all_cls_scores': outputs_classes,
                'all_lanes_preds': outputs_coords,
                'all_mask_preds': outputs_masks,
                'all_lanes_left_type': output_left_types,
                'all_lanes_right_type': output_right_types,
                'history_states': inter_hs[:,:, dn_num:,:]
            }

            loss_inputs = [outs, gt_lanes_3d, gt_lane_labels_3d, gt_instance_masks, gt_lane_left_type, gt_lane_right_type]
            lane_losses, lane_assign_result, bbox_targets, labels_targets, labels_left_type_targets, labels_right_type_targets, mask_targets, mask_targets_weights = self.loss(*loss_inputs, img_metas=img_metas)



        if self.streaming_query:
            # NOTE(review): the per-sample top-k selection below only builds locals.
            # None of the lists declared here are appended to, and nothing computed
            # in the loop is stored or returned.
            query_list = []
            ref_pts_list = [] # centerline reference pts
            lane_ref_pts_list = [] # left/right line reference pts
            gt_targets_list = []
            gt_label_target_list = []
            gt_left_type_target_list = []
            gt_right_type_target_list = []
            gt_lane_id_list = []
            gt_lane_id_list_all = []
            bs = len(img_metas)


            # Entries [bs:] are the temporal (streaming) outputs; [:bs] are the per-frame ones.
            lines, scores = outs_other_layer['all_lanes_preds'][-1][bs:,:,:], outs_other_layer['all_cls_scores'][-1][bs:,:,:]
            gt_lines = other_layer_bbox_targets[-1] # take results from the last layer


            # 90 is presumably 3 lines x 10 points x 3 coords - TODO confirm.
            gt_lines = gt_lines[self.num_query*bs:,:].reshape(bs, self.num_query, 90)
            gt_label_target = other_layer_labels_targets[-1]
            gt_label_target = gt_label_target[self.num_query*bs:].reshape(bs, self.num_query)
            gt_left_type_target = other_layer_labels_left_type_targets[-1]
            gt_left_type_target = gt_left_type_target[self.num_query*bs:].reshape(bs, self.num_query)
            gt_right_type_target = other_layer_labels_right_type_targets[-1]
            gt_right_type_target = gt_right_type_target[self.num_query*bs:].reshape(bs, self.num_query)


            last_lane_assign_result = other_layer_lane_assign_result[-1]

            for i in range(bs):

                inter_queries = inter_hs[:, bs:, :, :]
                _lines = lines[i] # e.g. torch.Size([200, 90])
                _queries = inter_queries[-1][i][dn_num:] # e.g. (200, 256)
                _scores = scores[i] # e.g. torch.Size([200, 2])
                _gt_targets = gt_lines[i] # (num_q or num_q+topk, 20, 2)
                _gt_label_target = gt_label_target[i]
                _gt_left_type_target = gt_left_type_target[i]
                _gt_right_type_target = gt_right_type_target[i]


                _inter_reference = inter_references[:,bs:,:,:,:]
                _inter_reference = _inter_reference[:,i,dn_num:,:,:]
                _inter_lane_reference = inter_lanepts_references[:,bs:,:,:,:]
                _inter_lane_reference = _inter_lane_reference[:,i,dn_num:,:,:]
                _last_lane_assign_result_pos_inds= last_lane_assign_result['pos_inds'][bs:][i] # which of the queries were matched to a gt
                _last_lane_assign_result_pos_assigned_gt_inds= last_lane_assign_result['pos_assigned_gt_inds'][bs:][i] # gt index of each matched query; per the author, also covers pedestrian crossings, which should be removed
                _gt_lane_id = img_metas[i]['lane_id']
                assert len(_lines) == len(_queries)
                assert len(_lines) == len(_gt_targets)

                _scores, _ = _scores.max(-1)
                topk_score, topk_idx = _scores.topk(k=self.topk_query, dim=-1) # e.g. topk = 66




                _queries = _queries[topk_idx] # (topk, embed_dims)
                _lines = _lines[topk_idx] # (topk, 2*num_pts)

                _gt_lane_id2target_id = [_gt_lane_id[i] if i < len(_gt_lane_id) else None
                for i in _last_lane_assign_result_pos_assigned_gt_inds.cpu().tolist()
                ] # gt lane ids in the order given by the matching (None for out-of-range indices)

                _mask_pos_in_topk = (_last_lane_assign_result_pos_inds.unsqueeze(1) == topk_idx).any(dim=1)
                _pos_in_topk = _last_lane_assign_result_pos_inds[_mask_pos_in_topk]

                # For each element of _pos_in_topk, find its position inside topk_idx.
                indices_sort_after_topk = torch.tensor([torch.where(topk_idx == x)[0].item() for x in _pos_in_topk])

                # Sort _pos_in_topk by those positions (i.e. into top-k order).
                sorted_indices = torch.argsort(indices_sort_after_topk)
                _pos_in_topk_sort = _pos_in_topk[sorted_indices]

                indices_to_keep = torch.nonzero(_mask_pos_in_topk, as_tuple=True)[0].cpu().tolist()
                _last_lane_assign_result_pos_assigned_gt_inds_topk = _last_lane_assign_result_pos_assigned_gt_inds[_mask_pos_in_topk]
                _gt_lane_id2target_id = [_gt_lane_id2target_id[i] for i in indices_to_keep]
                _gt_lane_id2target_id = [_gt_lane_id2target_id[i] for i in sorted_indices.tolist()]
                #gt_lanes_3d

                _gt_targets = _gt_targets[topk_idx] # (topk, 20, 2)

                _gt_label_target = _gt_label_target[topk_idx]
                _gt_left_type_target = _gt_left_type_target[topk_idx]
                _gt_right_type_target = _gt_right_type_target[topk_idx]

                _inter_reference = _inter_reference[:,topk_idx,:,:]#layer, topk, 10, 3
                _inter_lane_reference = _inter_lane_reference[:,topk_idx,:,:]#layer, topk, 8, 3
                _all_reference =  _inter_reference



            lane_losses['trans_loss'] = trans_loss

        self.iter += 1

        if self.streaming_query:
            return outs_first_layer, outs_other_layer, lane_losses, first_layer_lane_assign_result, other_layer_lane_assign_result
        else:
            return outs,  lane_losses, lane_assign_result

 

    @auto_fp16(apply_to=('mlvl_feats'))
    def forward_test(self, mlvl_feats, bev_feats, img_metas):
        """Inference forward pass: decode lane segments from the last decoder layer.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream
                network, each is a 5D-tensor with shape
                (B, N, C, H, W).
            bev_feats (Tensor): BEV features. They are passed to the
                transformer and, when ``self.pred_mask`` is set, to the mask
                head.
            img_metas (list[dict]): Per-sample meta info. ``len(img_metas)`` is
                used as the batch size, and ``img_metas[0]['sample_idx']`` is
                copied into the output.

        Returns:
            dict: Outputs of the last decoder layer, with the keys:
                - ``all_cls_scores``
                - ``all_lanes_preds``: centerline, left and right lines
                  concatenated on the last dim.
                - ``all_mask_preds``: ``None`` unless ``self.pred_mask``.
                - ``all_lanes_left_type`` and ``all_lanes_right_type``
                - ``history_states`` and ``sample_idx``
                The hard-coded ``select_mode`` chooses what is returned from
                the decoder batch: entry 0, entry 1, or a top-k merge of both
                halves.
        """
        bs = len(img_metas)
        dtype = mlvl_feats[0].dtype
        # One copy of the learned query table per sample.
        object_query_embeds = self.query_embedding.weight.to(dtype)[None, ...].repeat(bs, 1, 1)

        if self.streaming_query:
            query_embedding, prop_query_embedding, init_reference_points, prop_ref_pts, prop_centerline_ref_pts, memory_query, is_first_frame_list = \
                self.propagate_worldmodel_pre(object_query_embeds, bev_feats, img_metas, return_loss=False)
        else:
            # Split along the channel dim (dim=2), as forward_train does. The
            # embeddings carry a leading batch dim, so splitting dim=1 would cut
            # the query axis instead and the two-way unpacking would fail.
            query_pos, _ = torch.split(object_query_embeds, self.embed_dims, dim=2)
            init_reference_points = self.reference_points_embed(query_pos)
            # Same initialization as LaneSegNet: the single predicted point is
            # repeated for all num_points points of a query.
            init_reference_points = init_reference_points.repeat(1, 1, self.num_points)
            init_reference_points = init_reference_points.sigmoid()
            bs, num_qeury, _ = init_reference_points.shape
            init_reference_points = init_reference_points.view(bs, num_qeury, self.num_points, self.pts_dim)
            prop_query_embedding = None
            prop_ref_pts = None
            prop_centerline_ref_pts = None
            is_first_frame_list = [True] * bs

        assert list(init_reference_points.shape) == [bs, self.num_queries, self.num_points, self.pts_dim]  # e.g. [1, 200, 10, 3]

        outputs = self.transformer(
            mlvl_feats,
            bev_feats,
            object_query_embeds,
            reference_points=init_reference_points,
            prop_query=prop_query_embedding,
            prop_reference_points=prop_ref_pts,
            prop_centerline_reference_points=prop_centerline_ref_pts,
            is_first_frame_list=is_first_frame_list,
            bev_h=self.bev_h,
            bev_w=self.bev_w,
            reg_branches=(self.reg_branches, self.reg_branches_offset) if self.with_box_refine else None,  # noqa:E501
            cls_branches=self.cls_branches,
            img_metas=img_metas
        )

        init_hs, inter_hs, init_reference, init_inter_reference_points, inter_references, init_inter_lane_reference_points, inter_lanepts_references = outputs
        hs = inter_hs.permute(0, 2, 1, 3)

        # Only the last decoder layer is decoded at test time. References are in
        # sigmoid space, so refine in logit space and squash back.
        reference = inverse_sigmoid(inter_references[-1])
        assert reference.shape[-1] == self.pts_dim

        last_hs = hs[-1]
        outputs_class = self.cls_branches[-1](last_hs)
        output_left_type = self.cls_left_type_branches[-1](last_hs)
        output_right_type = self.cls_right_type_branches[-1](last_hs)

        tmp = self.reg_branches[-1](last_hs)
        bs, num_query, _ = tmp.shape
        tmp = tmp.view(bs, num_query, -1, self.pts_dim)
        tmp = (tmp + reference).sigmoid()

        # Map normalized [0, 1] coordinates back into the pc_range box.
        coord = tmp.clone()
        coord[..., 0] = coord[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        coord[..., 1] = coord[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        if self.pts_dim == 3:
            coord[..., 2] = coord[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
        centerline = coord.view(bs, num_query, -1).contiguous()

        # Left/right boundaries are the centerline shifted by +/- a predicted offset.
        offset = self.reg_branches_offset[-1](last_hs)
        left_laneline = centerline + offset
        right_laneline = centerline - offset

        # segmentation head
        if self.pred_mask:
            outputs_mask = self._forward_mask_head(last_hs, bev_feats, -1)

        # 0 = per-frame results, 1 = stream results, 2 = top-k merge of both.
        # NOTE(review): modes 0/1 index the decoder batch dim directly. Mode 1
        # therefore assumes the decoder batch holds at least two entries. That
        # is presumably the case in streaming mode, but not in non-streaming
        # mode with bs == 1 - TODO confirm.
        select_mode = 1
        if select_mode == 0 or select_mode == 1:

            outputs_classes = torch.stack([outputs_class[select_mode].unsqueeze(0)])
            outputs_coords = torch.stack([torch.cat([centerline[select_mode].unsqueeze(0), left_laneline[select_mode].unsqueeze(0), right_laneline[select_mode].unsqueeze(0)], axis=-1)])
            output_left_types = torch.stack([output_left_type[select_mode].unsqueeze(0)])
            output_right_types = torch.stack([output_right_type[select_mode].unsqueeze(0)])
            hs_merge = hs[:, select_mode, :, :].unsqueeze(1)
            outs = {
                'all_cls_scores': outputs_classes,
                'all_lanes_preds': outputs_coords,
                'all_mask_preds': torch.stack([outputs_mask]) if self.pred_mask else None,
                'all_lanes_left_type': output_left_types,
                'all_lanes_right_type': output_right_types,
                'history_states': hs_merge,
                'sample_idx':img_metas[0]['sample_idx']
            }

        elif select_mode == 2:
            # Fold the two batch halves into one query axis, then keep the
            # num_query highest-scoring queries.
            bs_mode, num_q, cls_lane = outputs_class.shape
            outputs_class_merge = outputs_class.reshape(bs_mode//2, num_q * 2, cls_lane)

            bs_mode, num_q, cls_lane = output_left_type.shape
            output_left_type_merge = output_left_type.reshape(bs_mode//2, num_q * 2, cls_lane)
            bs_mode, num_q, cls_lane = output_right_type.shape
            output_right_type_merge = output_right_type.reshape(bs_mode//2, num_q * 2, cls_lane)

            bs_mode, num_q, pts_lane = centerline.shape
            centerline_merge = centerline.reshape(bs_mode//2, num_q * 2, pts_lane)
            bs_mode, num_q, pts_lane = left_laneline.shape
            left_laneline_merge = left_laneline.reshape(bs_mode//2, num_q * 2, pts_lane)
            bs_mode, num_q, pts_lane = right_laneline.shape
            right_laneline_merge = right_laneline.reshape(bs_mode//2, num_q * 2, pts_lane)

            select_outputs_class_merge, _ = outputs_class_merge[-1].max(-1)
            topk_s, topk_i = select_outputs_class_merge.topk(k=self.num_query, dim=-1)
            outputs_classes = torch.stack([outputs_class_merge])[:,:,topk_i,:]
            outputs_coords = torch.stack([torch.cat([centerline_merge, left_laneline_merge, right_laneline_merge], axis=-1)])[:,:,topk_i,:]
            output_left_types = torch.stack([output_left_type_merge])[:,:,topk_i,:]
            output_right_types = torch.stack([output_right_type_merge])[:,:,topk_i,:]
            layer_num, bs_mode, num_q, cdim = hs.shape
            hs_merge = hs.reshape(layer_num, bs_mode//2, num_q * 2, cdim)[:,:,topk_i,:]

            outs = {
                'all_cls_scores': outputs_classes,
                'all_lanes_preds': outputs_coords,
                'all_mask_preds': torch.stack([outputs_mask]) if self.pred_mask else None,
                'all_lanes_left_type': output_left_types,
                'all_lanes_right_type': output_right_types,
                'history_states': hs_merge,
                'sample_idx':img_metas[0]['sample_idx']
            }

        if self.streaming_query:
            # NOTE(review): the per-sample top-k selection below only builds
            # locals. Nothing is stored or returned, so ``outs`` is unaffected.
            query_list = []
            ref_pts_list = []  # centerline reference pts
            lane_ref_pts_list = []  # left/right line reference pts

            stream_select_mode = 1  # for propagation: 0 = per-frame query results, 1 = multi-frame query results
            # ``lines`` holds the lane coordinates; ``scores`` the class logits.
            lanes_preds = torch.cat([centerline, left_laneline, right_laneline], dim=-1)
            lines = lanes_preds[stream_select_mode].unsqueeze(0)
            scores = outputs_class[stream_select_mode].unsqueeze(0)

            bs = len(img_metas)
            for i in range(bs):
                inter_queries = hs[:, stream_select_mode, :, :].unsqueeze(1)

                _lines = lines[i]
                _queries = inter_queries[-1][i]
                _scores = scores[i]

                _inter_reference = inter_references[:, stream_select_mode, :, :, :].unsqueeze(1)[:, i, :, :, :]
                _inter_lane_reference = inter_lanepts_references[:, stream_select_mode, :, :, :].unsqueeze(1)[:, i, :, :, :]
                assert len(_lines) == len(_queries)

                _scores, _ = _scores.max(-1)
                topk_score, topk_idx = _scores.topk(k=self.topk_query, dim=-1)

                _queries = _queries[topk_idx]  # (topk, embed_dims)
                _lines = _lines[topk_idx]  # (topk, lane coords)

                _inter_reference = _inter_reference[:, topk_idx, :, :]  # (layer, topk, num_points, pts_dim)
                _inter_lane_reference = _inter_lane_reference[:, topk_idx, :, :]
                _all_reference = _inter_reference

        return outs


    def _get_target_single(self,
                           cls_score,
                           lanes_pred,
                           masks_pred,
                           lanes_left_type_preds,
                           lanes_right_type_preds,
                           gt_labels,
                           gt_lanes,
                           gt_instance_masks,
                           gt_lanes_left_type, 
                           gt_lanes_right_type,
                           gt_bboxes_ignore=None):

        pass

    def get_targets(self,
                    cls_scores_list,
                    lanes_preds_list,
                    masks_preds_list,
                    lanes_left_type_preds_list,
                    lanes_right_type_preds_list,
                    gt_lanes_list,
                    gt_labels_list,
                    gt_masks_list,
                    gt_lanes_left_type_list, 
                    gt_lanes_right_type_list,
                    gt_bboxes_ignore_list=None):

        pass
    
    def loss_single(self,
                    cls_scores,
                    lanes_preds,
                    masks_preds,
                    lanes_left_type, 
                    lanes_right_type,
                    gt_lanes_list,
                    gt_labels_list,
                    gt_masks_list,
                    gt_lanes_left_type, 
                    gt_lanes_right_type,
                    gt_bboxes_ignore_list=None):

        pass
    @force_fp32(apply_to=('preds_dicts'))
    def loss(self,
             preds_dicts,
             gt_lanes_3d,
             gt_labels_list,
             gt_instance_masks,
             gt_lane_left_type,
             gt_lane_right_type,
             loss_dict = None,
             gt_bboxes_ignore=None,
             img_metas=None):
        """"Loss function.
        Args:

            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels_list (list[Tensor]): Ground truth class indices for each
                image with shape (num_gts, ).
            gt_instance_masks (list[Tensor]): Ground truth instance masks for each lane segment 
                of map size with shape (num_gts, 100, 50)
            preds_dicts:
                all_cls_scores (Tensor): Classification score of all
                    decoder layers, has shape
                    [nb_dec, bs, num_query, cls_out_channels].
                all_lanes_preds (Tensor): Sigmoid regression
                    outputs of all decode layers. Each is a 4D-tensor with
                    normalized coordinate format (cx, cy, w, h) and shape
                    [nb_dec, bs, num_query, 4].
                all_masks_preds (Tensor): Bitwise instance segmentation outputs of 
                    all decoder layers. Each is a bitwise segmentation map with shape (100,50).
                enc_cls_scores (Tensor): Classification scores of
                    points on encode feature map , has shape
                    (N, h*w, num_classes). Only be passed when as_two_stage is
                    True, otherwise is None.
                enc_bbox_preds (Tensor): Regression results of each points
                    on the encode feature map, has shape (N, h*w, 4). Only be
                    passed when as_two_stage is True, otherwise is None.
            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
                which can be ignored for each image. Default None.
        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        pass
    
    @force_fp32(apply_to=('preds_dicts'))
    def get_lanes(self, preds_dicts, img_metas, rescale=False):
        """Generate bboxes from bbox head predictions.
        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.
            img_metas (list[dict]): Point cloud and image's meta info.
        Returns:
            list[dict]: Decoded bbox, scores and labels after nms.
        """

        preds_dicts = self.bbox_coder.decode(preds_dicts)

        num_samples = len(preds_dicts)
        ret_list = []
        for i in range(num_samples):
            preds = preds_dicts[i]
            lanes = preds['lane3d']
            scores = preds['scores']
            labels = preds['labels']
            result = [lanes, scores, labels]
            if 'left_type_scores' in preds:
                left_type_scores = preds['left_type_scores']
                left_type_labels = preds['left_type_labels']
                right_type_scores = preds['right_type_scores']
                right_type_labels = preds['right_type_labels']
                result.extend([left_type_scores, left_type_labels, right_type_scores, right_type_labels])
            ret_list.append(result)

        return ret_list

    def train(self, *args, **kwargs):
        """Set training mode on the module and on its stream memories.

        ``nn.Module.train`` is applied first. Then every ``StreamTensorMemory``
        held directly in ``self.__dict__`` receives the same arguments.
        Registered submodules live in ``_modules`` and are already handled by
        ``nn.Module``. NOTE(review): other memory types imported by this file
        (``StreamListMemory``, ``DNStreamTensorMemory``) are only covered if
        they subclass ``StreamTensorMemory`` - TODO confirm.

        Returns:
            StreamLaneSegHead: ``self``, as ``nn.Module.train`` returns, so
            ``model = model.train()`` keeps working.
        """
        super().train(*args, **kwargs)
        for value in self.__dict__.values():
            if isinstance(value, StreamTensorMemory):
                value.train(*args, **kwargs)
        return self
    
    def eval(self):
        """Set evaluation mode on the module and on its stream memories.

        Every ``StreamTensorMemory`` held directly in ``self.__dict__`` gets
        ``eval()`` called on it after ``nn.Module.eval``.

        Returns:
            StreamLaneSegHead: ``self``, matching the ``nn.Module.eval``
            contract.
        """
        super().eval()
        for value in self.__dict__.values():
            if isinstance(value, StreamTensorMemory):
                value.eval()
        return self

    def forward(self, *args, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_test(*args, **kwargs)

    def prepare_temporal_propagation(self, preds_dict, scene_name, local_idx, memory_bank=None, 
                        thr_track=0.1, thr_det=0.5):
        lines = preds_dict['all_lanes_preds'][0] # List[Tensor(num_queries, 2*num_points)]
        queries = preds_dict['history_states'][0]
        # masks = preds_dict['all_mask_preds']
    
        left_types = preds_dict['all_lanes_left_type'][0]
        right_types = preds_dict['all_lanes_right_type'][0]
        bs = len(lines)
        assert bs == 1, 'now only support bs=1 for temporal-evolving inference'
        scores = preds_dict['all_cls_scores'][0] # (bs, num_queries, 3)

        first_frame = local_idx == 0

        tmp_vectors = lines[0]
        tmp_queries = queries[0]
        # tmp_masks = masks[0]
        tmp_left_types = left_types[0]
        tmp_right_types = right_types[0]

        # focal loss
        if self.loss_cls.use_sigmoid:
            tmp_scores, tmp_labels = scores[0].max(-1)
            tmp_scores = tmp_scores.sigmoid()

            if not first_frame:
                
                last_track_num = self.prop_info['prop_frame_num_instance'] if self.prop_info['prop_frame_num_instance']<= self.topk_query else self.topk_query
                pos_track = tmp_scores[:last_track_num] > thr_track
                pos_det = tmp_scores[last_track_num:] > thr_det
                pos = torch.cat([pos_track, pos_det], dim=0)
                # print('pos_track:{},pos_det:{}'.format(torch.sum(pos_track),torch.sum(pos_det)))
            else:
                pos_det = tmp_scores > thr_det
                pos = pos_det
                # print('pos_track:{},pos_det:{}'.format('0',torch.sum(pos)))
        
        
        else:
            raise RuntimeError('The experiment uses sigmoid for cls outputs')

        if torch.sum(pos)> self.topk_query and not first_frame:
            _, valid_idx = torch.topk(tmp_scores[last_track_num:], k=self.topk_query - torch.sum(pos_track), dim=-1)
            pad_pos_det = torch.zeros_like(pos_det, dtype=torch.bool)
            pad_pos_det[valid_idx] = True
            pos_det = pad_pos_det
            pos = torch.cat([pos_track, pos_det], dim=0)
        elif torch.sum(pos)> self.topk_query and first_frame:
            _, valid_idx = torch.topk(tmp_scores, k=self.topk_query , dim=-1)
            pad_pos_det = torch.zeros_like(pos_det, dtype=torch.bool)
            pad_pos_det[valid_idx] = True
            pos_det = pad_pos_det
            pos = pos_det

        pos_vectors = tmp_vectors[pos]
        pos_labels = tmp_labels[pos]
        pos_queries = tmp_queries[pos]
        pos_scores = tmp_scores[pos]
        # pos_masks = tmp_masks[pos]
        _, pos_left_types = tmp_left_types[pos].sigmoid().max(-1)
        _, pos_right_types = tmp_right_types[pos].sigmoid().max(-1)

        if first_frame:
            global_ids = torch.arange(len(pos_vectors))
            num_instance = len(pos_vectors)

        else:
            
            prop_ids = self.prop_info['global_ids']
            prop_num_instance = self.prop_info['num_instance']
            prop_frame_num_instance = self.prop_info['prop_frame_num_instance']

            try:
                global_ids_track = prop_ids[pos_track[:len(prop_ids)]] ##出局掉多少上次的检测
            except:
                import pdb; pdb.set_trace()
            num_newborn = int(pos_det.sum()) ##放进来多少新的检测
            global_ids_newborn = torch.arange(num_newborn) + prop_num_instance
            # print('global_ids_track:{},new:{}'.format(len(global_ids_track), len(global_ids_newborn)))
            global_ids = torch.cat([global_ids_track, global_ids_newborn])
            # print('len(global_ids)',len(global_ids))
            num_instance = prop_num_instance + num_newborn

        if first_frame:
            self.prop_info = {
                'vectors': pos_vectors,
                'queries': pos_queries,
                'scores': pos_scores,
                'labels': pos_labels,
                # 'masks': pos_masks,
                'left_types': pos_left_types,
                'right_types': pos_right_types,
                'scene_name': scene_name,
                'local_idx': local_idx,
                'global_ids': global_ids,
                'num_instance': num_instance,
                'prop_frame_num_instance':num_instance,
            }
        else:
            self.prop_info = {
                'vectors': pos_vectors,
                'queries': pos_queries,
                'scores': pos_scores,
                'labels': pos_labels,
                # 'masks': pos_masks,
                'left_types': pos_left_types,
                'right_types': pos_right_types,
                'scene_name': scene_name,
                'local_idx': local_idx,
                'global_ids': global_ids,
                'num_instance': num_instance,
                'prop_frame_num_instance':len(pos_vectors),
            }

     
        if memory_bank is not None:
            if first_frame:
                num_tracks = 0
            else:
                num_tracks = self.prop_active_tracks
            pos_out_inds = torch.where(pos)[0]
            prev_out = {
                'history_states': queries,
                'all_cls_scores': scores,
            }
    
            memory_bank.update_memory(0, first_frame, pos_out_inds, prev_out, num_tracks, local_idx, memory_bank.curr_t)
            self.prop_active_tracks = len(pos_out_inds)
        
        save_pos_results = {
            'vectors': pos_vectors.cpu().numpy(),
            'scores': pos_scores.cpu().numpy(),
            'labels': pos_labels.cpu().numpy(),
            # 'masks': pos_masks.cpu().numpy(),
            'left_types': pos_left_types.cpu().numpy(),
            'right_types': pos_right_types.cpu().numpy(),
            'global_ids': global_ids.cpu().numpy(),
            'scene_name': scene_name,
            'local_idx': local_idx,
            'num_instance': num_instance,
        }

        return [save_pos_results]
