"""
                    NeurIPS 2025 submission
                      Anonymous Author(s)
We modified the code according to SSR and retain the copyright statement here.
====================================================================
 Copyright (c) Zhijia Technology. All rights reserved.
 
 Author: Peidong Li (lipeidong@smartxtruck.com / peidongl@outlook.com)
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
     http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ====================================================================
"""
import time
import copy

import torch
from einops import rearrange
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmcv.runner import force_fp32, auto_fp16
from scipy.optimize import linear_sum_assignment
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmdet3d.models.builder import build_loss
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from projects.mmdet3d_plugin.SSR.utils import build_mlp, RequiresGrad, ValueHead, ValueHeadTransformer
from projects.mmdet3d_plugin.SSR.planner.metric_stp3 import PlanningMetric
from projects.mmdet3d_plugin.SSR.reward import IntrinsicRewardModel, CriticRewardModel, ImitationRewardModel
from .tokenlearner import TokenFuser
from collections import OrderedDict
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn as nn

import os
import pickle
import numpy as np

@DETECTORS.register_module()
class SSR(MVXTwoStageDetector):
    """SSR model.
    """
    def __init__(self,
                 use_grid_mask=False,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 latent_world_model=None,
                 reward_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False,
                 fut_ts=6,
                 fut_mode=6,
                 loss_bev=None,
                 loss_latent_pos=None,
                 loss_latent_query=None,
                 world_model_temporal_augmentation_dict=None,
                 critic_head=None,
                 world_model_config=None,
                 actor_critic_config=None,
                 ):
        """Build the SSR detector together with its world-model, reward and RL parts.

        Detector arguments (voxel/backbone/neck/head/roi/rpn configs,
        ``train_cfg``, ``test_cfg``, ``pretrained``) are forwarded unchanged to
        ``MVXTwoStageDetector``.

        Args:
            use_grid_mask (bool): Apply ``GridMask`` to images in
                ``extract_img_feat``.
            pts_bbox_head (dict): Head config; its ``valid_fut_ts`` entry is read.
            latent_world_model (dict | None): Transformer-layer-sequence config.
                When given, two independent world models are built from it
                (``latent_world_model`` and ``latent_world_model_act_pos``)
                along with the BEV / latent losses.
            reward_head (dict): Holds the keyword-argument dicts ``intrinsic``,
                ``critic`` and ``imitation`` for the three reward models.
            video_test_mode (bool): When False, ``forward_test`` discards the
                previous frame's BEV.
            fut_ts (int), fut_mode (int): Stored as attributes.
            loss_bev, loss_latent_pos, loss_latent_query (dict): Loss configs,
                built only when ``latent_world_model`` is given.
            world_model_temporal_augmentation_dict (dict), world_model_config
                (dict), actor_critic_config (dict), critic_head: Used only when
                ``latent_world_model`` is given; the critic / RL actor are built
                only if ``actor_critic_config['enable_rl']``.
        """
        super(SSR, self).__init__(
            pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
            pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
            pts_bbox_head, img_roi_head, img_rpn_head,
            train_cfg, test_cfg, pretrained)

        # Image augmentation (applied only when use_grid_mask is True).
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # Planning horizon settings.
        self.fut_ts, self.fut_mode = fut_ts, fut_mode
        self.valid_fut_ts = pts_bbox_head['valid_fut_ts']

        # Temporal state carried between test frames.
        self.video_test_mode = video_test_mode
        self.prev_frame_info = dict(
            prev_bev=None,
            scene_token=None,
            prev_pos=0,
            prev_angle=0,
        )

        self.planning_metric = None
        self.embed_dims = 256
        self.latent_world_model = latent_world_model
        self.tokenfuser = TokenFuser(16, 256)

        # Reward models, each built from its own kwargs dict.
        self.intrinsic_reward_head = IntrinsicRewardModel(**reward_head['intrinsic'])
        self.critic_reward_head = CriticRewardModel(**reward_head['critic'])
        self.imitation_reward_head = ImitationRewardModel(**reward_head['imitation'])

        if latent_world_model is None:
            return

        # Copy the config before building so the second model gets a pristine one.
        act_pos_cfg = copy.deepcopy(latent_world_model)
        self.latent_world_model = build_transformer_layer_sequence(latent_world_model)
        self.latent_world_model_act_pos = build_transformer_layer_sequence(act_pos_cfg)
        # Xavier-init every weight matrix (biases / norms, i.e. dim <= 1, untouched).
        for world_model in (self.latent_world_model, self.latent_world_model_act_pos):
            for param in world_model.parameters():
                if param.dim() > 1:
                    torch.nn.init.xavier_uniform_(param)

        self.loss_bev = build_loss(loss_bev)
        self.loss_latent_pos = build_loss(loss_latent_pos)
        self.loss_latent_query = build_loss(loss_latent_query)

        self.world_model_config = world_model_config
        self.world_model_temporal_augmentation_dict = world_model_temporal_augmentation_dict
        self.actor_critic_config = actor_critic_config

        if not actor_critic_config['enable_rl']:
            return

        self.critic_head = ValueHeadTransformer(critic_head)
        if actor_critic_config['ema_regularization']:
            # Separate copy of the critic; presumably updated as a slow/EMA
            # target elsewhere, given the config flag name — confirm.
            self.slow_critic = copy.deepcopy(self.critic_head)

        # RL actor starts as a copy of the head's (imitation-trained) actor.
        self.actor_rl = copy.deepcopy(self.pts_bbox_head.actor)

        # Counters / scores consumed elsewhere (names suggest IL-vs-RL agent comparison).
        self.iter_counter = 0
        self.il_agent_score = None
        self.rl_agent_score = None
        self.num_il_agent_win = 0
        self.num_rl_agent_win = 0

        self.module_index_dict = self.get_modules_by_optimizer()

        self.finish_warmup_flag = False


    def extract_img_feat(self, img, img_metas, len_queue=None):
        """Extract multi-scale image features with the backbone (and neck).

        Args:
            img (torch.Tensor | None): Images, either 5-D ``(B, N, C, H, W)``
                or already flattened 4-D ``(B*N, C, H, W)``. ``None`` yields
                ``None``.
            img_metas: Unused here; kept for interface compatibility.
            len_queue (int, optional): When given, the leading batch dimension
                is ``bs * len_queue`` and is split back into
                ``(bs, len_queue)`` in the output.

        Returns:
            list[torch.Tensor] | None: One tensor per feature level, shaped
            ``(B, N, C', H', W')``, or ``(bs, len_queue, N, C', H', W')`` when
            ``len_queue`` is given.
        """
        # Check for None first: reading img.size(0) before this check made the
        # None branch unreachable (AttributeError instead of returning None).
        if img is None:
            return None

        B = img.size(0)
        if img.dim() == 5:
            if B == 1:
                # Drop only the batch axis. A bare squeeze would also collapse a
                # single-camera (N == 1) axis and break the backbone input.
                img = img.squeeze(0)
            else:
                B, N, C, H, W = img.size()
                img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            num_cams = BN // B
            if len_queue is not None:
                img_feats_reshaped.append(
                    img_feat.view(B // len_queue, len_queue, num_cams, C, H, W))
            else:
                img_feats_reshaped.append(img_feat.view(B, num_cams, C, H, W))
        return img_feats_reshaped

    @auto_fp16(apply_to=('img',), out_fp32=True)
    def extract_feat(self, img, img_metas=None, len_queue=None):
        """Extract image features (fp16-aware wrapper around ``extract_img_feat``).

        ``apply_to`` must be a tuple of argument names. The previous
        ``('img')`` is just the string ``'img'`` (parentheses alone do not make
        a tuple), so any name check against it degrades to a substring test.

        Args:
            img (torch.Tensor | None): Input images.
            img_metas (list[dict], optional): Forwarded unchanged.
            len_queue (int, optional): Temporal queue length, see
                ``extract_img_feat``.

        Returns:
            list[torch.Tensor] | None: Per-level image features.
        """
        return self.extract_img_feat(img, img_metas, len_queue=len_queue)

    def forward_pts_train(self,
                          pts_feats,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          map_gt_bboxes_3d,
                          map_gt_labels_3d,
                          img_metas,
                          gt_bboxes_ignore=None,
                          map_gt_bboxes_ignore=None,
                          prev_bev=None,
                          next_bev=None,
                          next_latent_pos=None,
                          next_latent_query=None,
                          ego_his_trajs=None,
                          ego_fut_trajs=None,
                          ego_fut_masks=None,
                          ego_fut_cmd=None,
                          ego_lcf_feat=None,
                          gt_attr_labels=None,
                          wm_img_metas=None,
                          wm_img_fut=None,
                          next_trajs=None,
                          wm_ego_fut_trajs=None,
                          ):
        """Compute planning-head, world-model and reward-model losses.

        Args:
            pts_feats (list[torch.Tensor]): Image features of the current
                frame, passed to ``pts_bbox_head``.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for the
                boxes of each sample.
            map_gt_bboxes_3d, map_gt_labels_3d: Map ground truth, forwarded to
                ``pts_bbox_head.loss``.
            img_metas (list[dict]): Meta information of samples.
            gt_bboxes_ignore, map_gt_bboxes_ignore: Unused here.
            prev_bev (torch.Tensor, optional): BEV features of previous frame.
            next_bev (torch.Tensor, optional): BEV of frame t+1; target of the
                world-model BEV prediction.
            next_latent_pos (torch.Tensor, optional): Target of the action-pos
                world model, used when
                ``world_model_config['use_future_pos_supervision']``.
            next_latent_query: Unused (its loss is currently disabled).
            ego_his_trajs, ego_fut_trajs, ego_fut_masks, ego_fut_cmd,
            ego_lcf_feat, gt_attr_labels: Ego / agent inputs and targets.
            wm_img_metas (list[dict]): Per-sample metas of the extra
                world-model frames, indexed by rollout step.
            wm_img_fut (torch.Tensor): Extra future images, rollout step on
                dim 1 (may have length 0).
            next_trajs (torch.Tensor): GT ego trajectory of frame t+1, the
                first action of the recurrent rollout.
            wm_ego_fut_trajs (torch.Tensor): GT ego trajectories from frame
                t+2 onward, indexed on dim 1.

        Returns:
            tuple: ``(losses, (scene_query, scene_pos))`` — the loss dict and
            the head's scene tokens, returned as the initial state for the
            actor-critic stage.
        """
        outs = self.pts_bbox_head(pts_feats, img_metas, prev_bev,
                                  ego_his_trajs=ego_his_trajs, ego_lcf_feat=ego_lcf_feat, cmd=ego_fut_cmd)
        loss_inputs = [
            gt_bboxes_3d, gt_labels_3d, map_gt_bboxes_3d, map_gt_labels_3d,
            outs, ego_fut_trajs, ego_fut_masks, ego_fut_cmd, gt_attr_labels
        ]
        losses, rewards = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)

        if self.latent_world_model is not None:
            act_query = outs['act_query']
            act_pos = outs['act_pos']
            bev_embed = outs['bev_embed']

            # Predict the positional embedding of the next latent state.
            pred_pos = self.latent_world_model_act_pos(
                query=act_pos,
                key=act_pos,
                value=act_pos
            )
            if self.world_model_config['use_future_pos_supervision']:
                loss_latent_pos = self.loss_latent_pos(pred_pos, next_latent_pos.detach())
                losses.update(loss_latent_pos=loss_latent_pos)

            # Predict the next latent state and decode it into a BEV map.
            pred_latent = self.latent_world_model(
                query=act_query,  # e.g. (16, 1, 256)
                key=act_query,
                value=act_query,
                query_pos=pred_pos,
                key_pos=pred_pos
            )
            pred_bev = self.tokenfuser(pred_latent.permute(1, 0, 2), bev_embed)
            loss_bev = self.loss_bev(pred_bev, next_bev.detach())
            losses.update(loss_bev=loss_bev)

            # Optionally train the world model recurrently on extra future frames.
            if self.world_model_temporal_augmentation_dict['use_world_model_temporal_augmentation']:
                loss_wm_recurrent_prediction = 0
                beta = self.world_model_temporal_augmentation_dict['beta']
                alpha = self.world_model_temporal_augmentation_dict['alpha']
                if self.world_model_temporal_augmentation_dict['use_ssr_initialization']:
                    initial_wm_pred_latent_state = pred_latent
                    initial_wm_pred_bev = pred_bev
                else:
                    # First rollout step as in SSR, but driven by the GT trajectory.
                    # NOTE(review): reshape(-1) flattens across the batch, so this
                    # appears to assume batch size 1 — confirm.
                    initial_wm_wp_vector = ego_fut_trajs.reshape(-1)
                    initial_wm_wp_vector = initial_wm_wp_vector.unsqueeze(0).unsqueeze(0)  # (1, 1, 6*2)

                    initial_wm_act_query = self.pts_bbox_head.action_mln(outs['scene_query'], initial_wm_wp_vector)
                    initial_wm_pred_latent_state = self.latent_world_model(
                        query=initial_wm_act_query,
                        key=initial_wm_act_query,
                        value=initial_wm_act_query)
                    initial_wm_pred_bev = self.tokenfuser(initial_wm_pred_latent_state.permute(1, 0, 2), bev_embed)
                    # Both latents are deterministic tensors, so an L2 (MSE)
                    # consistency term is used rather than a KL divergence.
                    loss_reinforce_action_wm = alpha * F.mse_loss(
                        initial_wm_pred_latent_state, pred_latent.detach(), reduction='mean')
                    losses.update(loss_reinforce_action_wm=loss_reinforce_action_wm)

                loss_wm_recurrent_prediction += self.loss_bev(initial_wm_pred_bev, next_bev.detach())

                wm_len_queue = wm_img_fut.size(1)
                for i in range(wm_len_queue):
                    # Target: BEV of the i-th extra frame, computed without gradients.
                    wm_fut_img_i = wm_img_fut[:, i, ...]  # (batch_size, 6, 3, 384, 640)
                    wm_fut_img_metas_i = [each[i] for each in wm_img_metas]
                    # obtain_next_bev takes a command argument and returns
                    # (bev, latent_pos, latent_query); no command is available
                    # for these frames and only the BEV is the target.
                    # NOTE(review): confirm pts_bbox_head accepts cmd=None with only_bev=True.
                    wm_fut_bev_i, _, _ = self.obtain_next_bev(
                        wm_fut_img_i, wm_fut_img_metas_i, next_cmd=None)

                    # Prediction: step the world model from the previous state.
                    if i == 0:
                        wm_cur_action = next_trajs
                        wm_cur_latent_state = initial_wm_pred_latent_state
                        wm_cur_bev = initial_wm_pred_bev
                    else:
                        wm_cur_action = wm_ego_fut_trajs[:, i-1, ...]
                        wm_cur_latent_state = wm_pred_latent_state
                        wm_cur_bev = wm_pred_bev

                    wm_cur_action = wm_cur_action.reshape(-1)
                    wm_cur_action = wm_cur_action.unsqueeze(0).unsqueeze(0)  # (1, 1, 6*2)
                    wm_cur_act_query = self.pts_bbox_head.action_mln(wm_cur_latent_state, wm_cur_action)
                    wm_pred_latent_state = self.latent_world_model(
                        query=wm_cur_act_query,
                        key=wm_cur_act_query,
                        value=wm_cur_act_query)
                    wm_pred_bev = self.tokenfuser(wm_pred_latent_state.permute(1, 0, 2), wm_cur_bev)

                    # Step i of the rollout is weighted by beta**i.
                    loss_wm_recurrent_prediction += beta**i * self.loss_bev(wm_pred_bev, wm_fut_bev_i.detach())

                losses.update(loss_wm_recurrent_prediction=loss_wm_recurrent_prediction)
            else:
                wm_len_queue = wm_img_fut.size(1)
                assert wm_len_queue == 0, 'when not using recurrent world model, the length of world model should be 0'

            # Reward modelling: each head returns a distribution fitted by
            # negative log-likelihood to the rewards from pts_bbox_head.loss.
            # All inputs are detached, so these losses only train the reward heads.
            imitation_reward = self.imitation_reward_head(state_tokens=outs["scene_query"].detach(), action=outs["ego_fut_preds_sample"].detach())
            critic_reward = self.critic_reward_head(state_tokens=outs["scene_query"].detach(), action=outs["ego_fut_preds_sample"].detach(), next_state_tokens=pred_latent.detach())
            intrinsic_reward = self.intrinsic_reward_head(outs['scene_query'].detach())

            imitation_reward_gt = rewards['imitation_reward']
            critic_reward_gt = rewards['critic_reward']
            intrinsic_reward_gt = rewards['intrinsic_reward']

            loss_reward_imitation = 0.1 * -imitation_reward.log_prob(imitation_reward_gt.detach())
            loss_reward_critic = 0.1 * -critic_reward.log_prob(critic_reward_gt.detach())
            loss_reward_intrinsic = 0.05 * -intrinsic_reward.log_prob(intrinsic_reward_gt.detach())

            losses.update(loss_reward_imitation=loss_reward_imitation.mean())
            losses.update(loss_reward_critic=loss_reward_critic.mean())
            losses.update(loss_reward_intrinsic=loss_reward_intrinsic.mean())

            if "critic_reward_augment" in rewards:
                # Extra reward-head supervision on exploratory trajectories.
                # The world models only produce (no-grad) inputs here.
                with torch.no_grad():
                    act_query_reward_aug = outs['act_query_reward_aug'].detach()
                    act_pos_reward_aug = outs['act_pos_reward_aug'].detach()

                    pred_pos_reward_aug = self.latent_world_model_act_pos(
                        query=act_pos_reward_aug,
                        key=act_pos_reward_aug,
                        value=act_pos_reward_aug
                    )
                    pred_latent_reward_aug = self.latent_world_model(
                        query=act_query_reward_aug,
                        key=act_query_reward_aug,
                        value=act_query_reward_aug,
                        query_pos=pred_pos_reward_aug,
                        key_pos=pred_pos_reward_aug
                    )

                # [b, n_p, t, xy] => [b*n_p, t, xy]
                b, n_p, _, _ = outs["explore_trajectory"].shape
                action_reward_aug = rearrange(outs["explore_trajectory"], 'b n_p t xy -> (b n_p) t xy', b=b, n_p=n_p)
                # [n_t, b, d] => [n_t, b, n_p, d] => [n_t, b*n_p, d]
                scene_query_expand = outs["scene_query"].unsqueeze(2).expand(-1, -1, n_p, -1)
                scene_query_expand = rearrange(scene_query_expand, 'n_t b n_p d -> n_t (b n_p) d', b=b, n_p=n_p)

                imitation_reward_aug = self.imitation_reward_head(state_tokens=scene_query_expand.detach(), action=action_reward_aug.detach())
                imitation_reward_aug_gt = rewards["imitation_reward_augment_gt"]

                critic_reward_aug = self.critic_reward_head(state_tokens=scene_query_expand.detach(), action=action_reward_aug.detach(), next_state_tokens=pred_latent_reward_aug.detach())
                critic_reward_aug_gt = rewards['critic_reward_augment']

                loss_reward_imitation_aug = 0.1 * -imitation_reward_aug.log_prob(imitation_reward_aug_gt.detach())
                loss_reward_critic_aug = 0.1 * -critic_reward_aug.log_prob(critic_reward_aug_gt.detach())

                losses.update(loss_reward_imitation_aug=loss_reward_imitation_aug.mean())
                losses.update(loss_reward_critic_aug=loss_reward_critic_aug.mean())

        # The scene tokens are returned as the initial state the actor-critic
        # stage imagines from.
        return losses, (outs['scene_query'], outs['scene_pos'])

    def forward_dummy(self, img):
        """Dummy forward pass — not supported by SSR.

        The previous implementation called
        ``forward_test(img=img, img_metas=[[None]])``. ``forward_test`` requires
        ``gt_bboxes_3d``, ``gt_labels_3d``, ``map_gt_bboxes_3d`` and
        ``map_gt_labels_3d`` positionally, asserts ``return_wm_bev`` is given,
        and reads ``scene_token`` / ``can_bus`` from ``img_metas[0][0]``. That
        call could therefore only fail with a TypeError, so an explicit error
        is raised instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            'SSR.forward_dummy is not supported: forward_test requires '
            'ground-truth inputs, return_wm_bev and full img_metas '
            '(scene_token, can_bus).')

    def forward(self, return_loss=True, **kwargs):
        """Route the call to ``forward_train`` or ``forward_test``.

        With ``return_loss=True`` the keyword arguments go to
        ``forward_train``, which expects single-nested inputs (``img`` a
        ``torch.Tensor``, ``img_metas`` a ``list[dict]``). With
        ``return_loss=False`` they go to ``forward_test``, which expects
        double-nested inputs (``list[torch.Tensor]``, ``list[list[dict]]``)
        whose outer list indexes test-time augmentations. Keyword arguments
        are passed through unchanged.
        """
        handler = self.forward_train if return_loss else self.forward_test
        return handler(**kwargs)
    
    def obtain_history_bev(self, imgs_queue, img_metas_list, prev_cmd):
        """Encode past frames into a BEV feature, one frame after another.

        Each frame's BEV is fed to the next frame as ``prev_bev``. To save GPU
        memory this runs without gradients and with the model in eval mode.
        The caller's train/eval mode is restored afterwards, even if encoding
        raises.

        Args:
            imgs_queue (torch.Tensor): ``(bs, len_queue, num_cams, C, H, W)``.
            img_metas_list (list[dict]): Per-sample metas indexed by frame.
            prev_cmd (torch.Tensor): Ego commands, indexed ``[:, i]`` per frame.

        Returns:
            BEV output of ``pts_bbox_head`` for the last frame in the queue.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                prev_bev = None
                bs, len_queue, num_cams, C, H, W = imgs_queue.shape
                imgs_queue = imgs_queue.reshape(bs * len_queue, num_cams, C, H, W)
                img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
                for i in range(len_queue):
                    img_metas = [each[i] for each in img_metas_list]
                    img_feats = [each_scale[:, i] for each_scale in img_feats_list]
                    cmd = prev_cmd[:, i, ...]
                    prev_bev = self.pts_bbox_head(
                        img_feats, img_metas, prev_bev=prev_bev, only_bev=True, cmd=cmd)
                return prev_bev
        finally:
            # Restore the previous mode; unconditionally calling train() would
            # also flip a model that was deliberately in eval mode.
            self.train(was_training)
    
    def obtain_next_bev(self, img, img_metas, next_cmd=None):
        """Encode one future frame into BEV and latent-state targets.

        Runs without gradients and in eval mode, then restores the caller's
        train/eval mode, even if encoding raises.

        Args:
            img (torch.Tensor): Images of a single frame.
            img_metas (list[dict]): Metas of that frame, one per sample.
            next_cmd (torch.Tensor, optional): Ego command passed to the head
                as ``cmd``. Defaults to ``None`` so the method can be called
                when no command is available.
                NOTE(review): confirm ``pts_bbox_head`` accepts ``cmd=None``.

        Returns:
            tuple: ``(next_bev, next_latent_pos, next_latent_query)`` as
            returned by ``pts_bbox_head`` with ``also_latent_state=True``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                img_feats = self.extract_feat(img=img, img_metas=img_metas)
                next_bev, next_latent_pos, next_latent_query = self.pts_bbox_head(
                    img_feats, img_metas, only_bev=True, also_latent_state=True, cmd=next_cmd)
            return next_bev, next_latent_pos, next_latent_query
        finally:
            self.train(was_training)

    @force_fp32(apply_to=('img','points','prev_bev'))
    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      map_gt_bboxes_3d=None,
                      map_gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      map_gt_bboxes_ignore=None,
                      img_depth=None,
                      img_mask=None,
                      ego_his_trajs=None,
                      ego_fut_trajs=None,
                      ego_fut_masks=None,
                      ego_fut_cmd=None,
                      ego_lcf_feat=None,
                      gt_attr_labels=None,
                      **kwargs
                      ):
        """Forward training function.

        ``img`` carries a temporal queue of at least 4 frames on dim 1:
        frames 0..1 are history (encoded into ``prev_bev``), frame 2 is the
        supervised current frame, frame 3 is the next frame (target of the
        world model), and frames 4.. are extra future frames for the
        recurrent world-model rollout.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None. Unused.
            img_metas (list[dict], optional): Per-sample metas, each indexed
                by frame. Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels, gt_bboxes, proposals, img_depth, img_mask: Unused.
            img (torch.Tensor): Image queue ``(B, n_frames, num_cams, C, H, W)``.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.
            ego_fut_trajs (torch.Tensor): Per-frame GT ego trajectories on dim 1.
            ego_fut_cmd, ego_fut_masks (torch.Tensor): Per-frame commands /
                masks on dim 1; only the first 4 frames are used.

        Returns:
            tuple: ``(losses, initial_state_for_world_model)`` where the
            second item is the ``(scene_query, scene_pos)`` pair returned by
            ``forward_pts_train``.
        """
        # Frames of the original SSR queue: 2 history + current + next.
        num_ssr_frames = 4
        n_frames = img.size(1)
        assert n_frames >= num_ssr_frames, 'the length of data should be at least 4'

        # Split off the extra future frames used by the recurrent world model.
        wm_img_fut = img[:, num_ssr_frames:, ...]
        img = img[:, :num_ssr_frames, ...]  # (batch_size, 4, num_cams, C, H, W)

        wm_img_metas = []
        img_metas_temp = []
        for img_meta in img_metas:
            # Re-key the world-model frames as rollout steps 0..K-1, taking the
            # metas of the frames they came from (indices 4..n_frames-1). The
            # previous code copied img_meta[0..K-1], i.e. the history frames.
            wm_img_metas.append({
                wm_img_id: copy.deepcopy(img_meta[num_ssr_frames + wm_img_id])
                for wm_img_id in range(n_frames - num_ssr_frames)
            })
            img_metas_temp.append({i: img_meta[i] for i in range(num_ssr_frames)})
        img_metas = img_metas_temp

        # The world-model frames use only images, metas and trajectories;
        # commands and masks are kept for the SSR frames only.
        ego_fut_cmd = ego_fut_cmd[:, :num_ssr_frames, ...]
        wm_ego_fut_trajs = ego_fut_trajs[:, num_ssr_frames:, ...]
        ego_fut_trajs = ego_fut_trajs[:, :num_ssr_frames, ...]
        ego_fut_masks = ego_fut_masks[:, :num_ssr_frames, ...]

        len_queue = img.size(1)
        prev_img = img[:, :-2, ...]  # history: (batch_size, len_queue - 2, num_cams, C, H, W)
        next_img = img[:, -1, ...]
        img = img[:, -2, ...]
        prev_cmd = ego_fut_cmd[:, :-2, ...]
        next_cmd = ego_fut_cmd[:, -1, ...]
        ego_fut_cmd = ego_fut_cmd[:, -2, ...]

        # GT trajectory of frame t+1: first action of the recurrent rollout.
        next_trajs = ego_fut_trajs[:, -1, ...]
        ego_fut_trajs = ego_fut_trajs[:, -2, ...]
        ego_fut_masks = ego_fut_masks[:, -2, ...]

        prev_img_metas = copy.deepcopy(img_metas)
        next_img_metas = [each[len_queue - 1] for each in img_metas]

        # History BEVs are built recurrently; the next frame's BEV is encoded
        # on its own (no prev_bev is passed to obtain_next_bev).
        prev_bev = self.obtain_history_bev(prev_img, prev_img_metas, prev_cmd) if len_queue > 1 else None
        next_bev, next_latent_pos, next_latent_query = self.obtain_next_bev(next_img, next_img_metas, next_cmd)

        img_metas = [each[len_queue - 2] for each in img_metas]
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        losses = dict()
        losses_pts, initial_state_for_world_model = self.forward_pts_train(
            img_feats, gt_bboxes_3d, gt_labels_3d,
            map_gt_bboxes_3d, map_gt_labels_3d, img_metas,
            gt_bboxes_ignore, map_gt_bboxes_ignore, prev_bev, next_bev, next_latent_pos, next_latent_query,
            ego_his_trajs=ego_his_trajs, ego_fut_trajs=ego_fut_trajs,
            ego_fut_masks=ego_fut_masks, ego_fut_cmd=ego_fut_cmd,
            ego_lcf_feat=ego_lcf_feat, gt_attr_labels=gt_attr_labels,
            wm_img_metas=wm_img_metas, wm_img_fut=wm_img_fut,
            next_trajs=next_trajs,  # GT action of frame t+1
            wm_ego_fut_trajs=wm_ego_fut_trajs)  # GT actions from frame t+2 on
        losses.update(losses_pts)
        return losses, initial_state_for_world_model

    def forward_test(
        self,
        img_metas,
        gt_bboxes_3d,
        gt_labels_3d,
        map_gt_bboxes_3d,
        map_gt_labels_3d,
        img=None,
        ego_his_trajs=None,
        ego_fut_trajs=None,
        ego_fut_cmd=None,
        ego_lcf_feat=None,
        gt_attr_labels=None,
        return_wm_bev=None,
        frame_idx=None,
        **kwargs
    ):
        assert return_wm_bev is not None, 'return_wm_bev is None'
        for var, name in [(img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))
        img = [img] if img is None else img

        if img_metas[0][0]['scene_token'] != self.prev_frame_info['scene_token']:
            # the first sample of each scene is truncated
            self.prev_frame_info['prev_bev'] = None
        # update idx
        self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']

        # do not use temporal information
        if not self.video_test_mode:
            self.prev_frame_info['prev_bev'] = None

        # Get the delta of ego position and angle between two timestamps.
        tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
        tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
        if self.prev_frame_info['prev_bev'] is not None:
            img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
            img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
        else:
            img_metas[0][0]['can_bus'][-1] = 0
            img_metas[0][0]['can_bus'][:3] = 0

        new_prev_bev, bbox_results = self.simple_test(
            img_metas=img_metas[0],
            img=img[0],
            prev_bev=self.prev_frame_info['prev_bev'],
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            map_gt_bboxes_3d=map_gt_bboxes_3d,
            map_gt_labels_3d=map_gt_labels_3d,
            ego_his_trajs=ego_his_trajs[0],
            ego_fut_trajs=ego_fut_trajs[0],
            ego_fut_cmd=ego_fut_cmd[0],
            ego_lcf_feat=ego_lcf_feat[0],
            gt_attr_labels=gt_attr_labels,
            return_wm_bev=return_wm_bev,
            **kwargs
        )
        # During inference, we save the BEV features and ego motion of each timestamp.
        self.prev_frame_info['prev_pos'] = tmp_pos
        if not return_wm_bev[0]:
            self.prev_frame_info['prev_bev'] = new_prev_bev
        else:
            self.prev_frame_info['prev_bev'] = new_prev_bev[0]
            assert len(bbox_results) == 1, 'batch_size is not 1'
            # Now, we do not save these information into result.pkl
            # we save them into a directory
            scene_token = img_metas[0][0]['scene_token']
            scene_frame_idx = frame_idx[0].item()
            save_dir = "/opt/nvme1/<user-name>/projects/world-model-RL-store/SSR/vis_explore"
            file_name = f"scene_{scene_token}+frame_{scene_frame_idx}.pkl"

            # Convert LineString objects to lists of coordinates
            map_gt_lines_serializable = [
                np.array(line.coords) for line in map_gt_bboxes_3d[0].instance_list
            ]

            result = {
                'current_bev_embed': new_prev_bev[0].cpu().numpy(),                             # [1, 100*100, 256]
                'world_model_pred_bev_embed': new_prev_bev[1].cpu().numpy(),                    # [1, 100*100, 256]
                'current_selected_attention_map': new_prev_bev[2].cpu().numpy(),                # [1, 16, 100*100]
                'predict_selected_attention_map': new_prev_bev[3].cpu().numpy(),                # [1, 16, 100*100]
                'explore_trajs': new_prev_bev[4].cpu().numpy(),                                 # [B, num_of_exporation, T, 2]
                'predict_explore_bev_embed': new_prev_bev[5].cpu().numpy(),                     # [B, num_of_exporation, 100*100, 256]
                'predict_explore_selected_attention_map': new_prev_bev[6].cpu().numpy(),        # [B, num_of_exporation, 16, 100*100]
                'agent_gt_bboxes': gt_bboxes_3d[0][0].tensor.cpu().numpy(),                     # [num_of_agents, 9] 9=x,y,z,yaw....
                'agent_gt_labels': gt_labels_3d[0][0].cpu().numpy(),                            # [num_of_agents]
                'map_gt_lines': map_gt_lines_serializable,                                      # List[np.ndarray] [num_of_points_in_line, 2]
                'map_gt_labels': map_gt_labels_3d[0].cpu().numpy(),                             # [num_of_lines]
                'pred_trajectory': bbox_results[0]['pts_bbox']['ego_fut_preds'].cpu().numpy(),  # [3, 6, 2]
                'gt_command': bbox_results[0]['pts_bbox']['ego_fut_cmd'][0,0,0].cpu().numpy(),  # [3,] Right | Left | Straight
                'gt_trajectory': ego_fut_trajs[0][0,0].cpu().numpy(),                           # [6, 2]
            }
            # bbox_results[0]['pts_bbox']['current_bev_embed'] = new_prev_bev[0]
            # bbox_results[0]['pts_bbox']['world_model_pred_bev_embed'] = new_prev_bev[1]
            # bbox_results[0]['pts_bbox']['current_selected_attention_map'] = new_prev_bev[2]
            os.makedirs(save_dir, exist_ok=True)
            # Construct the full file path
            file_path = os.path.join(save_dir, file_name)

            # Save the result dictionary to a pickle file
            with open(file_path, 'wb') as f:
                pickle.dump(result, f)

        self.prev_frame_info['prev_angle'] = tmp_angle

        return bbox_results

    def simple_test(
        self,
        img_metas,
        gt_bboxes_3d,
        gt_labels_3d,
        map_gt_bboxes_3d,
        map_gt_labels_3d,
        img=None,
        prev_bev=None,
        points=None,
        fut_valid_flag=None,
        rescale=False,
        ego_his_trajs=None,
        ego_fut_trajs=None,
        ego_fut_cmd=None,
        ego_lcf_feat=None,
        gt_attr_labels=None,
        return_wm_bev=None,
        **kwargs
    ):
        """Single-pass inference (no test-time augmentation).

        Image features are extracted with ``extract_feat`` and passed to
        ``simple_test_pts``. Each per-sample prediction is wrapped as
        ``{'pts_bbox': ..., 'metric_results': ...}``. ``points`` and any extra
        ``**kwargs`` are accepted but not used.

        Returns:
            tuple: ``(new_prev_bev, bbox_list)``.
                ``new_prev_bev`` is the first value returned by
                ``simple_test_pts``. ``bbox_list`` has one dict per entry of
                ``img_metas``.
        """
        features = self.extract_feat(img=img, img_metas=img_metas)
        per_sample = [{} for _ in img_metas]
        new_prev_bev, predictions, metrics = self.simple_test_pts(
            features,
            img_metas,
            gt_bboxes_3d,
            gt_labels_3d,
            map_gt_bboxes_3d,
            map_gt_labels_3d,
            prev_bev,
            fut_valid_flag=fut_valid_flag,
            rescale=rescale,
            start=None,
            ego_his_trajs=ego_his_trajs,
            ego_fut_trajs=ego_fut_trajs,
            ego_fut_cmd=ego_fut_cmd,
            ego_lcf_feat=ego_lcf_feat,
            gt_attr_labels=gt_attr_labels,
            return_wm_bev=return_wm_bev
        )
        # zip stops at the shorter sequence: samples without a prediction keep an empty dict.
        for entry, pts_bbox in zip(per_sample, predictions):
            entry['pts_bbox'] = pts_bbox
            entry['metric_results'] = metrics

        return new_prev_bev, per_sample

    def simple_test_pts(
        self,
        x,
        img_metas,
        gt_bboxes_3d,
        gt_labels_3d,
        map_gt_bboxes_3d,
        map_gt_labels_3d,
        prev_bev=None,
        fut_valid_flag=None,
        rescale=False,
        start=None,
        ego_his_trajs=None,
        ego_fut_trajs=None,
        ego_fut_cmd=None,
        ego_lcf_feat=None,
        gt_attr_labels=None,
        return_wm_bev=None,
    ):
        """Run the planning head on image features and compute STP3 planning metrics.

        Only batch size 1 is supported (enforced by the asserts below).

        Args:
            x: Image features from ``extract_feat``.
            img_metas: Meta dicts forwarded to ``pts_bbox_head``.
            gt_bboxes_3d, gt_labels_3d, gt_attr_labels: Agent ground truth;
                element ``[0][0]`` is used.
            map_gt_bboxes_3d, map_gt_labels_3d: Map ground truth; element
                ``[0]`` is used.
            prev_bev: Previous-frame BEV features (or None), forwarded to the head.
            fut_valid_flag: Indexed as ``[0][0]`` and cast to bool, so it must
                not be None.
            rescale, start: Accepted but unused here.
            ego_his_trajs, ego_lcf_feat: Forwarded to ``pts_bbox_head``.
            ego_fut_trajs: GT future ego trajectory. ``[0, 0]`` is cumsum-ed
                over time before the metric is computed.
            ego_fut_cmd: Command tensor. The first non-zero index of
                ``[0, 0, 0]`` selects which predicted trajectory is evaluated.
            return_wm_bev: Indexed as ``[0]``. When truthy, world-model
                predictions are also returned.

        Returns:
            tuple: ``(bev_out, bbox_results, metric_dict)``.
                ``bev_out`` is ``outs['bev_embed']`` when ``return_wm_bev[0]`` is
                falsy. Otherwise it is the 7-tuple ``(bev_embed, pred_bev,
                selected, pred_selected, explore_trajs, pred_explore_bev,
                pred_explore_selected)``.
        """
        # NOTE(review): `mapped_class_names` is not used in this method.
        mapped_class_names = [
            'car', 'truck', 'construction_vehicle', 'bus',
            'trailer', 'barrier', 'motorcycle', 'bicycle', 
            'pedestrian', 'traffic_cone'
        ]

        # Run the planning head; the dict `outs` supplies every prediction used below.
        outs = self.pts_bbox_head(x, img_metas, prev_bev=prev_bev, cmd=ego_fut_cmd,
                                  ego_his_trajs=ego_his_trajs, ego_lcf_feat=ego_lcf_feat)

        # One result dict per sample, with the predictions and the command moved to CPU.
        bbox_results = []
        for i in range(len(outs['ego_fut_preds'])):
            bbox_result=dict()
            bbox_result['ego_fut_preds'] = outs['ego_fut_preds'][i].cpu()
            bbox_result['ego_fut_cmd'] = ego_fut_cmd.cpu()
            bbox_results.append(bbox_result)

        # use sampled policy instead of mean vector
        # for i in range(len(outs['ego_fut_preds'])):
        #     bbox_result=dict()
        #     bbox_result['ego_fut_preds'] = outs['ego_fut_preds_sample'][i].cpu()
        #     bbox_result['ego_fut_cmd'] = ego_fut_cmd.cpu()
        #     bbox_results.append(bbox_result)

        assert len(bbox_results) == 1, 'only support batch_size=1 now'
        # score_threshold = 0.6
        with torch.no_grad():
            gt_bbox = gt_bboxes_3d[0][0]
            gt_map_bbox = map_gt_bboxes_3d[0]
            # NOTE(review): `gt_label` is not used below.
            gt_label = gt_labels_3d[0][0].to('cpu')
            gt_map_label = map_gt_labels_3d[0].to('cpu')
            gt_attr_label = gt_attr_labels[0][0].to('cpu')
            fut_valid_flag = bool(fut_valid_flag[0][0])

            metric_dict={}
            # ego planning metric
            assert ego_fut_trajs.shape[0] == 1, 'only support batch_size=1 for testing'
            # `bbox_result` is the last dict built by the loop above; with the
            # batch-size-1 assert it is the only one.
            ego_fut_preds = bbox_result['ego_fut_preds']
            ego_fut_trajs = ego_fut_trajs[0, 0]

            # The first non-zero entry of the command vector selects the
            # predicted trajectory to evaluate.
            ego_fut_cmd = ego_fut_cmd[0, 0, 0]
            ego_fut_cmd_idx = torch.nonzero(ego_fut_cmd)[0, 0]

            ego_fut_pred = ego_fut_preds[ego_fut_cmd_idx]
            # Accumulate along the time axis: both trajectories are treated as
            # per-step displacements and turned into waypoints.
            ego_fut_pred = ego_fut_pred.cumsum(dim=-2)
            ego_fut_trajs = ego_fut_trajs.cumsum(dim=-2)

            metric_dict_planner_stp3 = self.compute_planner_metric_stp3(
                pred_ego_fut_trajs = ego_fut_pred[None],
                gt_ego_fut_trajs = ego_fut_trajs[None],
                gt_agent_boxes = gt_bbox,
                gt_agent_feats = gt_attr_label.unsqueeze(0),
                gt_map_boxes = gt_map_bbox,
                gt_map_labels = gt_map_label,
                fut_valid_flag = fut_valid_flag
            )
            metric_dict.update(metric_dict_planner_stp3)
        if not return_wm_bev[0]:
            return outs['bev_embed'], bbox_results, metric_dict
        else:
            # World-model branch, run outside torch.no_grad(). It uses the same
            # two calls as imagine_fut_state to predict the next latent from the
            # action-conditioned queries.
            act_query = outs['act_query']
            act_pos = outs['act_pos']
            bev_embed = outs['bev_embed']

            pred_pos = self.latent_world_model_act_pos(
                query=act_pos,
                key=act_pos,
                value=act_pos
            )

            pred_latent = self.latent_world_model(
                    query=act_query, # shape: (16, 1, 256)
                    key=act_query,
                    value=act_query,
                    query_pos=pred_pos,
                    key_pos=pred_pos    
                )
            
            # Combine the predicted latent tokens with the current BEV via the TokenFuser.
            pred_bev = self.tokenfuser(pred_latent.permute(1, 0, 2), bev_embed)
            pred_selected = self.pts_bbox_head.get_bev_selected(pred_bev, ego_fut_cmd_idx)

            # Exploration: the head turns its exploration policy into candidate
            # trajectories and action queries, which are then rolled out with
            # the same world model.
            policy = outs['explore_policy']
            latent_query = outs['scene_query']
            latent_pos = outs['scene_pos']
            explore_trajs, explore_act_query, explore_act_pos = self.pts_bbox_head.get_exploration_data(latent_query, latent_pos, policy)

            pred_explore_pos = self.latent_world_model_act_pos(
                query=explore_act_pos,
                key=explore_act_pos,
                value=explore_act_pos
            )

            pred_explore_latent = self.latent_world_model(
                    query=explore_act_query, # shape: (16, B*N, 256)
                    key=explore_act_query,
                    value=explore_act_query,
                    query_pos=pred_explore_pos,
                    key_pos=pred_explore_pos    
            )
            # The exploration latents stack B*N sequences along dim 1.
            B = bev_embed.size(0)
            N = pred_explore_latent.size(1) // B

            # NOTE(review): expand(N, -1, -1) on dim 0 requires bev_embed.size(0) == 1
            # (or == N), which holds under the batch-size-1 assumption.
            pred_explore_bev = self.tokenfuser(pred_explore_latent.permute(1, 0, 2), bev_embed.expand(N, -1, -1))
            pred_explore_selected = self.pts_bbox_head.get_bev_selected(pred_explore_bev, ego_fut_cmd_idx)

            # Split the flattened (B*N) axis back into separate batch and exploration axes.
            pred_explore_bev = rearrange(pred_explore_bev, '(b n) bev_grid bev_feat -> b n bev_grid bev_feat', b=B, n=N)
            pred_explore_selected = rearrange(pred_explore_selected, '(b n) n_map bev_grid -> b n n_map bev_grid', b=B, n=N)

            return (outs['bev_embed'], pred_bev, outs['selected'], pred_selected, explore_trajs, pred_explore_bev, pred_explore_selected), bbox_results, metric_dict

    def map_pred2result(self, bboxes, scores, labels, pts, attrs=None):
        """Convert map predictions to a dict of CPU tensors.

        Args:
            bboxes (torch.Tensor): Bounding boxes of the map elements.
            scores (torch.Tensor): Scores with shape of (n, ).
            labels (torch.Tensor): Labels with shape of (n, ).
            pts (torch.Tensor): Points of each map element.
            attrs (torch.Tensor, optional): Attributes with shape of (n, ).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor]: Map results in CPU mode.

                - map_boxes_3d: ``bboxes``.
                - map_scores_3d: ``scores``.
                - map_labels_3d: ``labels``.
                - map_pts_3d: ``pts``.
                - map_attrs_3d: ``attrs`` (present only when ``attrs`` is given).
        """
        result_dict = dict(
            map_boxes_3d=bboxes.to('cpu'),
            map_scores_3d=scores.cpu(),
            map_labels_3d=labels.cpu(),
            map_pts_3d=pts.to('cpu'))

        if attrs is not None:
            result_dict['map_attrs_3d'] = attrs.cpu()

        return result_dict

    ### same planning metric as stp3
    def compute_planner_metric_stp3(
        self,
        pred_ego_fut_trajs,
        gt_ego_fut_trajs,
        gt_agent_boxes,
        gt_agent_feats,
        gt_map_boxes,
        gt_map_labels,
        fut_valid_flag
    ):
        """Compute planner metrics for one sample, following ST-P3.

        For a horizon of k seconds (k = 1, 2, 3), the first ``2 * k`` waypoints
        are evaluated:

        - ``plan_L2_{k}s`` / ``plan_L2_stp3_{k}s``: ``PlanningMetric.compute_L2``
          and ``compute_L2_stp3`` on those waypoints.
        - ``plan_obj_col_{k}s`` / ``plan_obj_box_col_{k}s``: per-step values from
          ``evaluate_coll``, averaged over the evaluated steps.
        - ``plan_obj_col_stp3_{k}s`` / ``plan_obj_box_col_stp3_{k}s``: the value
          at the last evaluated step.

        Args:
            pred_ego_fut_trajs (Tensor): Predicted ego waypoints. The first dim
                must be 1.
            gt_ego_fut_trajs (Tensor): Ground-truth ego waypoints, same layout.
            gt_agent_boxes, gt_agent_feats, gt_map_boxes, gt_map_labels: Ground
                truth forwarded to ``PlanningMetric.get_label``. It builds the
                occupancy used for collision checks.
            fut_valid_flag (bool): When False, every metric is reported as 0.0.

        Returns:
            dict: Metric name -> value, plus ``'fut_valid_flag'``.
                Valid and invalid samples produce the same set of keys, so
                results can be aggregated key by key.
        """
        metric_dict = {
            'plan_L2_1s':0,
            'plan_L2_2s':0,
            'plan_L2_3s':0,
            'plan_obj_col_1s':0,
            'plan_obj_col_2s':0,
            'plan_obj_col_3s':0,
            'plan_obj_box_col_1s':0,
            'plan_obj_box_col_2s':0,
            'plan_obj_box_col_3s':0,
        }
        metric_dict['fut_valid_flag'] = fut_valid_flag
        future_second = 3
        assert pred_ego_fut_trajs.shape[0] == 1, 'only support bs=1'
        if self.planning_metric is None:
            # Created lazily on first use.
            self.planning_metric = PlanningMetric()
        # `segmentation_plus` is returned by get_label but is not evaluated here.
        segmentation, pedestrian, segmentation_plus = self.planning_metric.get_label(
            gt_agent_boxes, gt_agent_feats, gt_map_boxes, gt_map_labels)
        occupancy = torch.logical_or(segmentation, pedestrian)

        for i in range(future_second):
            if fut_valid_flag:
                # Two waypoints per reported second.
                cur_time = (i + 1) * 2
                traj_L2 = self.planning_metric.compute_L2(
                    pred_ego_fut_trajs[0, :cur_time].detach().to(gt_ego_fut_trajs.device),
                    gt_ego_fut_trajs[0, :cur_time]
                )
                traj_L2_stp3 = self.planning_metric.compute_L2_stp3(
                    pred_ego_fut_trajs[0, :cur_time].detach().to(gt_ego_fut_trajs.device),
                    gt_ego_fut_trajs[0, :cur_time]
                )
                obj_coll, obj_box_coll = self.planning_metric.evaluate_coll(
                    pred_ego_fut_trajs[:, :cur_time].detach(),
                    gt_ego_fut_trajs[:, :cur_time],
                    occupancy)
                metric_dict['plan_L2_{}s'.format(i + 1)] = traj_L2
                metric_dict['plan_obj_col_{}s'.format(i + 1)] = obj_coll.mean().item()
                metric_dict['plan_obj_box_col_{}s'.format(i + 1)] = obj_box_coll.mean().item()
                metric_dict['plan_L2_stp3_{}s'.format(i + 1)] = traj_L2_stp3
                # ST-P3 reports the collision at the final step of the horizon.
                metric_dict['plan_obj_col_stp3_{}s'.format(i + 1)] = obj_coll[-1].item()
                metric_dict['plan_obj_box_col_stp3_{}s'.format(i + 1)] = obj_box_coll[-1].item()
            else:
                # Emit every key of the valid branch so invalid samples do not
                # leave gaps when results are aggregated per key.
                metric_dict['plan_L2_{}s'.format(i + 1)] = 0.0
                metric_dict['plan_obj_col_{}s'.format(i + 1)] = 0.0
                metric_dict['plan_obj_box_col_{}s'.format(i + 1)] = 0.0
                metric_dict['plan_L2_stp3_{}s'.format(i + 1)] = 0.0
                metric_dict['plan_obj_col_stp3_{}s'.format(i + 1)] = 0.0
                metric_dict['plan_obj_box_col_stp3_{}s'.format(i + 1)] = 0.0

        return metric_dict

    def set_epoch(self, epoch):
        """Store the current training epoch on the planning head."""
        head = self.pts_bbox_head
        head.epoch = epoch

# =======
    def imagine_fut_state(self, action_query, action_pos):
        """Predict the next latent state with the latent world model.

        ``latent_world_model_act_pos`` is first called with
        query = key = value = ``action_pos`` to produce the predicted positional
        embedding. That embedding is then used as ``query_pos`` / ``key_pos``
        when ``latent_world_model`` is applied to ``action_query``.

        Returns:
            tuple: ``(pred_latent, pred_pos)``.
        """
        next_pos = self.latent_world_model_act_pos(
            query=action_pos, key=action_pos, value=action_pos)
        next_latent = self.latent_world_model(
            query=action_query,  # e.g. (16, 1, 256) or (16, B*N, 256)
            key=action_query,
            value=action_query,
            query_pos=next_pos,
            key_pos=next_pos,
        )
        return next_latent, next_pos
    
    def get_intrinsic_reward(self, state):
        """Return the mean of the intrinsic-reward head's output distribution for ``state``."""
        reward_dist = self.intrinsic_reward_head(state)
        return reward_dist.mean  # [B,]

    def get_critic_reward(self, root_node):
        """Score the three child branches of ``root_node`` with the critic reward head.

        For each child (``'left_node'``, ``'straight_node'``, ``'right_node'``):

        - ``root_node[child]['state'][0]`` is detached and passed to
          ``critic_reward_head``.
        - The mean of the returned distribution is taken.

        The three means are stacked on the last dim in that order and stored
        in ``root_node['critic_reward']``. Nothing is returned.

        NOTE(review): ``onestep_get_critic_reward`` calls ``critic_reward_head``
        with ``state_tokens`` / ``action`` / ``next_state_tokens`` keywords, while
        this method passes a single positional state. Confirm this tree path
        still matches the head's signature.
        """
        # TODO check device
        branch_means = [
            self.critic_reward_head(root_node[child]['state'][0].detach()).mean
            for child in ('left_node', 'straight_node', 'right_node')
        ]
        # three [B,] tensors -> [B, 3]
        root_node['critic_reward'] = torch.stack(branch_means, dim=-1)

    def get_value(self, state):
        """Return the expected value (distribution mean) predicted by ``critic_head``.

        ``state`` is used as given; it is not detached here.
        """
        value_dist = self.critic_head(state)
        return value_dist.mean

    def get_slow_target(self, state):
        """Return the mean value predicted by the EMA-updated ``slow_critic``."""
        slow_dist = self.slow_critic(state)
        return slow_dist.mean

    def get_value_distribution(self, state):
        """Return the full value distribution produced by ``critic_head`` for ``state``."""
        distribution = self.critic_head(state)
        return distribution

    def reinforcement_learning_actor_planning(self, latent_state, latent_pos):
        """Plan with the RL actor (``actor_rl.planning``) from a latent state and its positions."""
        rl_actor = self.actor_rl
        return rl_actor.planning(latent_state, latent_pos)

    def imitation_learning_actor_planning(self, latent_state, latent_pos):
        """Plan with the imitation-learning actor (``pts_bbox_head.actor.planning``)."""
        il_actor = self.pts_bbox_head.actor
        return il_actor.planning(latent_state, latent_pos)

    def update_slow_critic(self):
        """EMA-update ``slow_critic`` towards ``critic_head``.

        ``slow = decay * slow + (1 - decay) * critic``, with ``decay`` read from
        ``actor_critic_config['ema_decay']``. Parameters are paired by position;
        ``zip`` stops at the shorter parameter list. Writes go through ``.data``,
        so autograd does not track the update.
        """
        decay = self.actor_critic_config['ema_decay']
        paired = zip(self.critic_head.parameters(), self.slow_critic.parameters())
        for online_param, target_param in paired:
            blended = decay * target_param.data + (1 - decay) * online_param.data.detach()
            target_param.data.copy_(blended)

# ======= onestep imagine ========#
    def propose_policy_for_one_step_imagine(self, initial_state, initial_pos, initial_cmd):
        """Sample candidate ego trajectories from the RL actor and encode them as action queries.

        Args:
            initial_state (Tensor): Scene latent queries, e.g. (16, B, 256).
            initial_pos (Tensor): Matching positional embeddings.
            initial_cmd (Tensor): Command tensor. The first non-zero index of
                ``initial_cmd[0]`` is passed to the actor.

        Returns:
            tuple: ``(policy, sample_action, action_query, action_pos)``.
                ``sample_action`` is [B, N, T, 2] with N =
                ``actor_critic_config["sample_batch"]``. The action tensors
                are [16, B*N, 256].
        """
        command_idx = torch.nonzero(initial_cmd[0])[0, 0]
        # The actor sees detached inputs, so no gradient flows back into them.
        _, policy = self.actor_rl(initial_state.detach(), initial_pos.detach(), command_idx)

        num_samples = self.actor_critic_config["sample_batch"]
        drawn = policy.rsample((num_samples,))  # [N, B, T, 2]
        n, b, t, xy = drawn.shape
        sample_action = rearrange(drawn, 'n b t xy -> b n t xy')  # [B, N, T, 2]

        # Flatten each trajectory into one waypoint vector: [B, N, T, 2] -> [1, B*N, T*2]
        wp_vector = rearrange(sample_action, 'b n t xy -> 1 (b n) (t xy)', b=b, n=n, t=t, xy=xy)
        detached_wp = wp_vector.detach()

        head = self.pts_bbox_head
        action_query = head.action_mln(initial_state.detach(), detached_wp)  # [16, B*N, 256]
        action_pos = head.pos_mln(initial_pos.detach(), detached_wp)  # [16, B*N, 256]

        return policy, sample_action, action_query, action_pos

    def onestep_get_critic_reward(self, initial_state, sample_action, pred_state):
        '''Critic reward for each (state, sampled action, predicted next state) transition.

        Each of the B initial states is repeated once per sampled action, so the
        critic reward head scores all B*N transitions in a single call.

        initial_state (Tensor): [16, B, 256]
        sample_action (Tensor): [B, N, T, 2]
        pred_state (Tensor):    [16, B*N, 256]

        Returns:
            Tensor: Mean of the head's output distribution, reshaped to [B, N].
        '''
        b, n, t, xy = sample_action.shape
        flat_action = rearrange(sample_action, 'b n t xy -> (b n) t xy', b=b, n=n, t=t, xy=xy)

        repeated_state = rearrange(
            initial_state.unsqueeze(2).expand(-1, -1, n, -1),
            'n_t b n d -> n_t (b n) d', b=b, n=n,
        )

        reward_dist = self.critic_reward_head(
            state_tokens=repeated_state, action=flat_action, next_state_tokens=pred_state)
        return rearrange(reward_dist.mean, '(b n) -> b n', b=b, n=n)  # [B, N]

    def onestep_get_imitation_reward(self, initial_state, sample_action):
        '''Imitation reward for each (state, sampled action) pair.

        Each of the B initial states is repeated once per sampled action, so the
        imitation reward head scores all B*N pairs in a single call.

        initial_state (Tensor): [16, B, 256]
        sample_action (Tensor): [B, N, T, 2]

        Returns:
            Tensor: Mean of the head's output distribution, reshaped to [B, N].
        '''
        b, n, t, xy = sample_action.shape
        flat_action = rearrange(sample_action, 'b n t xy -> (b n) t xy', b=b, n=n, t=t, xy=xy)

        repeated_state = rearrange(
            initial_state.unsqueeze(2).expand(-1, -1, n, -1),
            'n_t b n d -> n_t (b n) d', b=b, n=n,
        )

        reward_dist = self.imitation_reward_head(state_tokens=repeated_state, action=flat_action)
        return rearrange(reward_dist.mean, '(b n) -> b n', b=b, n=n)  # [B, N]


# ------
    def onestep_imagine(self, initial_state, initial_pos, initial_cmd):
        """One imagination step: sample actions, then predict the resulting latent state.

        Returns:
            tuple: ``(policy, sample_action, pred_latent_query, pred_latent_pos)``.
                ``sample_action`` is [B, N, T, 2] and the predicted latent
                query is [16, B*N, 256].
        """
        proposal = self.propose_policy_for_one_step_imagine(initial_state, initial_pos, initial_cmd)
        policy, sample_action, action_query, action_pos = proposal
        next_query, next_pos = self.imagine_fut_state(action_query, action_pos)
        return policy, sample_action, next_query, next_pos
    
    def onestep_obtain_target(self, initial_state, initial_pos, sample_action, pred_state, pred_state_pos):
        '''Build the one-step return target for every sampled action.

            reward = intrinsic(s) + critic_reward(s, a, s') + imitation_reward(s, a)
            return = reward + gamma * ((1 - lambda) * V(s) + lambda * V(s'))

        ``initial_pos`` and ``pred_state_pos`` are accepted but unused. V(s') is
        evaluated on a detached ``pred_state``.

        initial_state (Tensor): [16, B, 256]
        sample_action (Tensor): [B, N, T, 2]
        pred_state (Tensor):    [16, B*N, 256]

        Returns:
            tuple: ``(return_value, current_value)``, both [B, N].
        '''
        intrinsic = self.get_intrinsic_reward(initial_state)  # [B,]
        critic = self.onestep_get_critic_reward(initial_state, sample_action, pred_state)  # [B, N]
        imitation = self.onestep_get_imitation_reward(initial_state, sample_action)  # [B, N]
        reward = intrinsic.unsqueeze(1) + critic + imitation  # [B, N]

        b, n = reward.shape
        current_value = self.get_value(initial_state).unsqueeze(1).expand(-1, n)  # [B, 1] -> [B, N]
        future_value = rearrange(
            self.get_value(pred_state.detach()), '(b n) -> b n', b=b, n=n)  # [B*N] -> [B, N]

        gamma = self.actor_critic_config["gamma_value"]
        lam = self.actor_critic_config["lambda_value"]
        mixed_value = (1 - lam) * current_value + lam * future_value
        return_value = reward + gamma * mixed_value  # [B, N]

        return return_value, current_value

    def onestep_compute_actor_loss(self, policy, sample_action, return_value, current_value):
        '''REINFORCE-style actor loss with per-sample normalised advantages.

        Args:
            policy (torch.distributions.Distribution): Action distribution with
                batch shape [B, T] and event shape [2] (MultivariateNormal).
            sample_action (Tensor): Sampled trajectories, [B, N, T, 2].
            return_value (Tensor): One-step return targets, [B, N].
            current_value (Tensor): Value baseline, [B, N].

        Returns:
            Tensor: Actor loss per batch element, [B].

        Note:
            The advantage std is computed over the N samples; with N == 1 it is NaN.
        '''
        advantage = return_value - current_value  # [B, N]
        # Normalise over the N samples of each batch element. keepdim=True keeps
        # the [B, 1] statistics aligned with the batch axis. A bare [B] vector
        # would broadcast against the sample axis N instead.
        advantage = (advantage - advantage.mean(dim=-1, keepdim=True)) / (
            advantage.std(dim=-1, keepdim=True) + 1e-6)

        # Evaluate log-probs in the layout the samples were drawn in
        # ([N, B, T, 2], sample dims first), so they broadcast against the
        # policy's [B, T] batch shape for any B. Then restore [B, N].
        samples_first = sample_action.detach().transpose(0, 1)  # [N, B, T, 2]
        log_prob = policy.log_prob(samples_first).sum(dim=-1)   # [N, B, T] -> [N, B]
        log_prob = log_prob.transpose(0, 1)                     # [B, N]

        # Advantages are constants for the policy gradient.
        actor_loss = -(advantage.detach() * log_prob).mean(dim=-1)  # [B,]

        return actor_loss
    
    def onestep_compute_critic_loss(self, imagine_dict):
        '''Negative log-likelihood of the one-step return targets under the critic.

        Args:
            imagine_dict (dict): Must contain ``'initial_state'`` ([16, B, 256])
                and ``'return_value'`` ([B, N]), as built by ``actor_learning``
                in one-step mode.

        Returns:
            Tensor: Critic loss per batch element, [B].
        '''
        initial_state = imagine_dict['initial_state']
        return_value = imagine_dict['return_value']

        # One value distribution per batch element (its mean is [B]; see get_value).
        value_distribution = self.get_value_distribution(initial_state.detach())
        # Put the sample axis first ([N, B]) so the targets broadcast against the
        # [B] batch shape. A [B, N] target only lines up when B == 1 or B == N.
        targets = return_value.detach().transpose(0, 1)  # [N, B]
        critic_loss = -value_distribution.log_prob(targets)  # [N, B]
        critic_loss = critic_loss.mean(dim=0)  # [B,]

        return critic_loss

# =======
    def critic_learning(self, imgagine_tree_dict, initial_cmd):
        """Optionally EMA-update the slow critic, then return the scalar critic loss.

        ``onestep_compute_critic_loss`` is used when ``one_step_imagine`` is set;
        otherwise ``compute_critic_loss`` is used. The per-sample loss is averaged.
        """
        config = self.actor_critic_config
        if config['ema_regularization']:
            self.update_slow_critic()
        if config["one_step_imagine"]:
            per_sample = self.onestep_compute_critic_loss(imgagine_tree_dict)
        else:
            per_sample = self.compute_critic_loss(imgagine_tree_dict, initial_cmd)
        return per_sample.mean()

    def actor_learning(self, initial_state, initial_cmd):
        """Imagine rollouts from ``initial_state`` and compute the actor loss.

        Args:
            initial_state (tuple): ``(scene_query, scene_pos)``.
            initial_cmd (Tensor): Driving command used during imagination.

        Returns:
            One-step mode: ``(actor_loss_mean, imagine_dict, mean_return)``.
                ``imagine_dict`` holds ``'initial_state'``,
                ``'initial_state_pos'`` and ``'return_value'``.
            Tree mode: ``(actor_loss_mean, imagine_tree_dict)``.

        NOTE(review): the two modes return tuples of different length, while
        ``train_step`` unpacks three values. Confirm the tree path is still in use.
        """
        config = self.actor_critic_config
        if not config['one_step_imagine']:
            # Tree-based multi-step imagination.
            max_length = config['max_imagine_length']
            tree = self.imagine(initial_state, max_layer=max_length)
            tree = self.obtain_target(tree, two_reward=config['two_reward'])
            # TODO maybe average the loss with the node number
            tree_loss = self.compute_actor_loss(tree, initial_cmd)
            return tree_loss.mean(), tree

        scene_query, scene_pos = initial_state[0], initial_state[1]
        policy, sample_action, pred_state, pred_state_pos = self.onestep_imagine(
            scene_query, scene_pos, initial_cmd)
        return_value, current_value = self.onestep_obtain_target(
            scene_query, scene_pos, sample_action, pred_state, pred_state_pos)
        actor_loss = self.onestep_compute_actor_loss(policy, sample_action, return_value, current_value)

        imagine_dict = {
            'initial_state': scene_query,
            'initial_state_pos': scene_pos,
            'return_value': return_value,
        }
        return actor_loss.mean(), imagine_dict, return_value.mean()

    def expert_teaching(self):
        """Pull the RL actor halfway towards the imitation-learning actor.

        Each parameter of ``actor_rl`` becomes ``0.5 * rl + 0.5 * il``, where
        ``il`` is the parameter at the same position in ``pts_bbox_head.actor``.
        Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            rl_params = self.actor_rl.parameters()
            il_params = self.pts_bbox_head.actor.parameters()
            for rl_param, il_param in zip(rl_params, il_params):
                rl_param.data.copy_(0.5 * rl_param.data + 0.5 * il_param.data.detach())

    def explore_learning(self):
        """Blend 20% of the RL actor into the imitation-learning actor.

        Each parameter of ``pts_bbox_head.actor`` becomes ``0.8 * il + 0.2 * rl``,
        where ``rl`` is the parameter at the same position in ``actor_rl``.
        Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            il_params = self.pts_bbox_head.actor.parameters()
            rl_params = self.actor_rl.parameters()
            for il_param, rl_param in zip(il_params, rl_params):
                il_param.data.copy_(0.8 * il_param.data + 0.2 * rl_param.data.detach())

    def copy_experience(self):
        """Overwrite the RL actor with the imitation-learning actor's parameters.

        This is a hard copy, not an EMA. Parameters are paired by position and
        copied under ``torch.no_grad()``.
        """
        with torch.no_grad():
            pairs = zip(self.actor_rl.parameters(), self.pts_bbox_head.actor.parameters())
            for dst, src in pairs:
                dst.data.copy_(src.data.detach())

    def eval_action(self, proposed_action, gt_action, fut_masks, fut_cmd):
        """Score a proposed ego trajectory against the ground truth.

        The head's ``obtain_imitation_reward_gt`` and ``obtain_critic_reward_gt``
        are called, in that order, with the same arguments. The mean of their
        sum is returned.
        """
        head = self.pts_bbox_head
        shared_kwargs = dict(
            ego_fut_preds=proposed_action,
            ego_fut_gt=gt_action,
            ego_fut_masks=fut_masks,
            ego_fut_cmd=fut_cmd,
        )
        imitation = head.obtain_imitation_reward_gt(**shared_kwargs)
        critic = head.obtain_critic_reward_gt(**shared_kwargs)
        total = imitation + critic
        return total.mean()
# =======
    # overload this method in base class
    def _parse_losses(self, losses, grad_dict):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contain
                losses and other necessary infomation.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if loss_name in ["actor_learning_loss", "critic_learning_loss"]:
                continue
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss_e2e'] = loss
        log_vars['grad_norm_e2e'] = grad_dict['grad_norm_e2e']

        # warm up stage do not have these items
        if 'actor_learning_loss' in losses.keys():
            log_vars['loss_actor_learning'] = losses['actor_learning_loss']
            log_vars['grad_norm_actor'] = grad_dict['grad_norm_actor']
        if 'critic_learning_loss' in losses.keys():
            log_vars['loss_critic_learning'] = losses['critic_learning_loss']
            log_vars['grad_norm_critic'] = grad_dict['grad_norm_critic']

        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def train_step(self, data, optimizers, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a \
                weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                logger.
                - ``num_samples`` indicates the batch size (when the model is \
                DDP, it means the batch size on each GPU), which is used for \
                averaging the logs.
        """
        def clip_grad(optimizer, max_norm, norm_type=2):
            nn.utils.clip_grad_norm_(
                [p for group in optimizer.param_groups for p in group['params']], 
                max_norm=max_norm, 
                norm_type=norm_type
            )
        # we first close the require_grad for each module
        # then open then accordingly
        if self.actor_critic_config["enable_rl"]:
            initial_cmd = data['ego_fut_cmd'][0,2,0] # [B, 3]
            gt_ego_action = data['ego_fut_trajs'][0,2] # (B, 6, 2)
            ego_fut_masks = data['ego_fut_masks'][0,2,0] # (B, 6)
            B = initial_cmd.size(0)
            log_vars_update = {}
            if self.iter_counter == 0:
                self.il_agent_score = torch.zeros(B, device=initial_cmd.device)
                self.rl_agent_score = torch.zeros(B, device=initial_cmd.device)
        self.requires_grad_(requires_grad=False)
        # we first conduct the end 2end learning stage
        with RequiresGrad(self.module_index_dict["end2end_learning"]):
            losses, initial_state = self(**data)

            # Compute total loss
            end2end_learning_loss = sum(v for v in losses.values())

            # Backpropagation and optimization for optimizer_1
            optimizers['optimizer_1'].zero_grad()
            end2end_learning_loss.backward()

            # Gradient clipping
            clip_grad(optimizer=optimizers['optimizer_1'], max_norm=35, norm_type=2)

            optimizers['optimizer_1'].step()
            # we have to zero grad the actor module, as in actor-critic learning stage
            # the grad of actor will be updated
            grad_dict = {
                "grad_norm_e2e": torch.sqrt(sum(p.grad.norm(2)**2 for p in self.module_index_dict["end2end_learning"].parameters() if p.grad is not None)),
            }

        if self.actor_critic_config['enable_rl']:
            if self.iter_counter > self.actor_critic_config['cold_star_iter_num']:
                # iteration 1000-2000, we only train the value
                if self.iter_counter <= self.actor_critic_config['warmup_iter_num']:
                    with RequiresGrad(self.module_index_dict["critic_learning"]):
                        initial_state = (initial_state[0].detach(), initial_state[1].detach())
                        _, imagine_tree_dict, return_value = self.actor_learning(initial_state, initial_cmd)
                        critic_loss = self.critic_learning(imagine_tree_dict, initial_cmd)
                    
                    # optimize the parameters
                    with RequiresGrad(self.module_index_dict["critic_learning"]):
                        # then optimize the critic
                        optimizers["optimizer_2_critic"].zero_grad()
                        critic_loss.backward()
                        # Gradient clipping
                        clip_grad(optimizer=optimizers["optimizer_2_critic"], max_norm=35, norm_type=2)
                        optimizers["optimizer_2_critic"].step()

                    losses["critic_learning_loss"] = critic_loss
                    grad_dict["grad_norm_critic"] = torch.sqrt(sum(p.grad.norm(2)**2 for p in self.module_index_dict["critic_learning"].parameters() if p.grad is not None))
                else:
                    if not self.finish_warmup_flag:
                        self.copy_experience()
                        self.finish_warmup_flag = True
                    
                        # actor learning stage
                    with RequiresGrad(self.module_index_dict["actor_learning"]):
                        # attention, the initial_state is detached from the original computation graph
                        # so backward propagation do not affect the perception module
                        # scene_query (latent_query) scene_pos (latent_pos)
                        initial_state = (initial_state[0].detach(), initial_state[1].detach())
                        actor_loss, imagine_tree_dict, return_value = self.actor_learning(initial_state, initial_cmd)
                    
                    # critic learning stage
                    with RequiresGrad(self.module_index_dict["critic_learning"]):
                        critic_loss = self.critic_learning(imagine_tree_dict, initial_cmd)
                    
                    # optimize the parameters
                    with RequiresGrad([self.module_index_dict["actor_learning"], self.module_index_dict["critic_learning"]]):
                        # first optimize the actor
                        optimizers["optimizer_2_actor"].zero_grad()
                        actor_loss.backward()
                        # Gradient clipping
                        clip_grad(optimizer=optimizers['optimizer_2_actor'], max_norm=35, norm_type=2)
                        optimizers["optimizer_2_actor"].step()

                        # then optimize the critic
                        optimizers["optimizer_2_critic"].zero_grad()
                        critic_loss.backward()
                        # Gradient clipping
                        clip_grad(optimizer=optimizers["optimizer_2_critic"], max_norm=35, norm_type=2)
                        optimizers["optimizer_2_critic"].step()

                    losses["actor_learning_loss"] = actor_loss
                    losses["critic_learning_loss"] = critic_loss

                    log_vars_update["return"] = return_value.item()

                    grad_dict["grad_norm_actor"] = torch.sqrt(sum(p.grad.norm(2)**2 for p in self.module_index_dict["actor_learning"].parameters() if p.grad is not None))
                    grad_dict["grad_norm_critic"] = torch.sqrt(sum(p.grad.norm(2)**2 for p in self.module_index_dict["critic_learning"].parameters() if p.grad is not None))

                    with torch.no_grad():
                        # now judge which one is better: actor_il or actor_rl
                        il_agent_action = self.imitation_learning_actor_planning(initial_state[0], initial_state[1]) # [B, 3, T, 2]
                        rl_agent_action = self.reinforcement_learning_actor_planning(initial_state[0], initial_state[1]) # [B, 3, T, 2]

                        self.il_agent_score += self.eval_action(il_agent_action, gt_ego_action, ego_fut_masks, initial_cmd)
                        self.rl_agent_score += self.eval_action(rl_agent_action, gt_ego_action, ego_fut_masks, initial_cmd)

                        if self.iter_counter % self.actor_critic_config["competetion_interval"] == 0:
                            self.il_agent_score = self.il_agent_score.mean()
                            self.rl_agent_score = self.rl_agent_score.mean()

                            # reduce the score when distributed training
                            if dist.is_available() and dist.is_initialized():
                                dist.all_reduce(self.il_agent_score)
                                dist.all_reduce(self.rl_agent_score)
                                self.il_agent_score /= dist.get_world_size()
                                self.rl_agent_score /= dist.get_world_size()
                            

                            diff = abs((self.il_agent_score - self.rl_agent_score).item())
                            if diff > self.actor_critic_config["score_threshold"]:
                                if self.il_agent_score >= self.rl_agent_score:
                                    self.num_il_agent_win += 1
                                    self.expert_teaching()
                                else:
                                    self.num_rl_agent_win += 1
                                    self.explore_learning()
                                    self.copy_experience()
                            else:
                                self.num_il_agent_win += 1
                                # self.copy_experience()
                                self.explore_learning()
                            
                            log_vars_update["il_win"] = self.num_il_agent_win
                            log_vars_update["rl_win"] = self.num_rl_agent_win
                            log_vars_update["il_score"] = self.il_agent_score.item()
                            log_vars_update["rl_score"] = self.rl_agent_score.item()

                            self.il_agent_score = torch.zeros(B, device=initial_cmd.device)
                            self.rl_agent_score = torch.zeros(B, device=initial_cmd.device)
        

        # parse the loss
        loss, log_vars = self._parse_losses(losses, grad_dict)
        if self.actor_critic_config["enable_rl"] and log_vars_update:
            log_vars.update(log_vars_update)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

        if self.actor_critic_config["enable_rl"]:
            self.iter_counter += 1

        return outputs

    def get_modules_by_optimizer(self):
        """Partition this model's direct children by the optimizer that owns them.

        ``actor_rl`` and ``critic_head`` get their own groups. ``slow_critic``
        (only if the attribute exists) belongs to no group. Every remaining
        direct child is re-registered, under its original name, on a fresh
        ``nn.Module`` container. The container shares the same submodule
        objects, so no parameters are copied.

        Returns:
            dict: A dictionary with keys:
                - 'actor_learning': ``self.actor_rl``
                - 'critic_learning': ``self.critic_head``
                - 'end2end_learning': container holding all other direct children
        """
        # Children that must not be optimized by the end-to-end optimizer.
        held_out = [self.actor_rl, self.critic_head]
        if hasattr(self, 'slow_critic'):
            held_out.append(self.slow_critic)

        end2end_container = nn.Module()
        for child_name, child in self.named_children():
            # Compare by identity: the same module object must be skipped.
            is_held_out = any(child is other for other in held_out)
            if child is not None and not is_held_out:
                end2end_container.add_module(child_name, child)

        return {
            'actor_learning': self.actor_rl,
            'critic_learning': self.critic_head,
            'end2end_learning': end2end_container,
        }
