"""
                    NeurIPS 2025 submission
                      Anonymous Author(s)
This code is modified from SSR; the original copyright statement is retained below.
====================================================================================
 Copyright (c) Zhijia Technology. All rights reserved.
 
 Author: Peidong Li (lipeidong@smartxtruck.com / peidongl@outlook.com)
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
     http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 =================================================================================
"""
import copy
from math import pi, cos, sin
import os
import torch
import numpy as np
import torch.nn as nn
from einops import rearrange
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, kl_divergence
from mmdet.models import HEADS, build_loss 
from mmdet.models.dense_heads import DETRHead
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import TORCH_VERSION, digit_version
from mmdet.core import build_assigner, build_sampler
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh
from mmcv.cnn import Linear, bias_init_with_prob, xavier_init
from mmdet.core import (multi_apply, multi_apply, reduce_mean)
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence

from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
from projects.mmdet3d_plugin.SSR.utils.traj_lr_warmup import get_traj_warmup_loss_weight
from projects.mmdet3d_plugin.SSR.utils.map_utils import (
    normalize_2d_pts, normalize_2d_bbox, denormalize_2d_pts, denormalize_2d_bbox
)
from projects.mmdet3d_plugin.SSR.utils.plan_loss import CriticImitationConstrain, CriticEndPointConstrain, CriticCollisionConstrain, IntrinsicMapBoundConstrain, CriticMapBoundConstrain, IntrinsicCollisionConstrain
from .tokenlearner import *
from mmdet.models.utils import LearnedPositionalEncoding
#from projects.mmdet3d_plugin.VAD.VAD_head import get_targets

class MLN(nn.Module):
    '''
    Conditional scale-and-shift of (optionally layer-normalized) features.

    from "https://github.com/exiawsh/StreamPETR"

    The latent code ``c`` is projected to ``f_dim`` and mapped to a per-channel
    scale (gamma) and shift (beta) that are applied to ``x``.

    Args:
        c_dim (int): dimension of latent code c
        f_dim (int): feature dimension
        use_ln (bool): apply a non-affine LayerNorm to ``x`` before modulation
    '''

    def __init__(self, c_dim, f_dim=256, use_ln=True):
        super().__init__()
        self.c_dim, self.f_dim, self.use_ln = c_dim, f_dim, use_ln

        # Projects the latent code into feature space.
        self.reduce = nn.Sequential(nn.Linear(c_dim, f_dim), nn.ReLU())
        # Heads producing the multiplicative and additive modulation terms.
        self.gamma = nn.Linear(f_dim, f_dim)
        self.beta = nn.Linear(f_dim, f_dim)
        if use_ln:
            self.ln = nn.LayerNorm(f_dim, elementwise_affine=False)
        self.init_weight()

    def init_weight(self):
        """Start as identity modulation: gamma outputs 1 and beta outputs 0."""
        for head, bias_init in ((self.gamma, nn.init.ones_),
                                (self.beta, nn.init.zeros_)):
            nn.init.zeros_(head.weight)
            bias_init(head.bias)

    def forward(self, x, c):
        """Return ``gamma(c) * norm(x) + beta(c)``.

        Args:
            x (Tensor): features whose last dimension is ``f_dim``.
            c (Tensor): latent code whose last dimension is ``c_dim``.
        """
        feats = self.ln(x) if self.use_ln else x
        code = self.reduce(c)
        scale = self.gamma(code)
        shift = self.beta(code)
        return scale * feats + shift

class SELayer(nn.Module):
    """Gating layer: scales ``x`` by a gate computed from a second tensor.

    Args:
        channels (int): width of both inputs (the two Linear layers are
            ``channels -> channels``).
        act_layer (Callable): factory for the activation between the two
            Linear layers.
        gate_layer (Callable): factory for the gate non-linearity.
    """

    def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
        super().__init__()
        self.mlp_reduce = nn.Linear(channels, channels)
        self.act1 = act_layer()
        self.mlp_expand = nn.Linear(channels, channels)
        self.gate = gate_layer()

    def forward(self, x, x_se):
        """Return ``x * gate(mlp_expand(act1(mlp_reduce(x_se))))``."""
        excitation = self.mlp_expand(self.act1(self.mlp_reduce(x_se)))
        return x * self.gate(excitation)

class Actor(nn.Module):
    """Policy head that turns scene tokens into per-mode trajectory distributions.

    Learned way-point queries (one per mode and future step) cross-attend to
    the latent scene queries through ``way_decoder``, are mixed along the time
    axis by a masked self-attention layer, and are finally decoded by three
    MLPs into the mean, standard deviation and correlation of a 2-D Gaussian
    for every future step.

    Args:
        ego_fut_mode (int): number of trajectory modes.
        fut_ts (int): number of future steps per mode.
        embed_dims (int): channel width of the way-point features.
        way_decoder (dict): config passed to ``build_transformer_layer_sequence``.
        ego_lcf_feat_idx (Sequence | None): when not None, its length is added
            to the input width of the three MLPs.
        mlp_layers (int): number of hidden Linear+ReLU pairs in each MLP.
        policy_config (dict): must provide ``x_min_std``, ``x_max_std``,
            ``y_min_std``, ``y_max_std`` and ``explore_ratio``.
    """

    # Correlation is clamped to this magnitude to avoid perfect correlations
    # (which would make the covariance matrix singular).
    _MAX_ABS_ROU = 0.75
    # Added to both variances for numerical stability.
    _VAR_EPS = 1e-6

    def __init__(self, ego_fut_mode, fut_ts, embed_dims, way_decoder, ego_lcf_feat_idx, mlp_layers, policy_config):
        super().__init__()
        self.ego_fut_mode = ego_fut_mode
        self.fut_ts = fut_ts
        self.embed_dims = embed_dims
        self.ego_lcf_feat_idx = ego_lcf_feat_idx
        self.mlp_layers = mlp_layers
        self.policy_config = policy_config

        # Each row holds [positional part | content part], split in the forward pass.
        self.way_point = nn.Embedding(self.ego_fut_mode * self.fut_ts, self.embed_dims * 2)
        self.way_decoder = build_transformer_layer_sequence(way_decoder)

        # Boolean attention mask: True entries are NOT allowed to attend.
        # The strictly lower triangle is True, so step t attends to itself and
        # to later steps only.
        self.causal_mask = torch.tril(torch.ones(self.fut_ts, self.fut_ts), diagonal=-1).bool()
        self.auto_regression_attention = nn.MultiheadAttention(
            embed_dim=self.embed_dims, num_heads=8, batch_first=False)

        # NOTE(review): the decoders are fed features of width ``embed_dims``,
        # so a non-empty ``ego_lcf_feat_idx`` looks like it would cause a width
        # mismatch in the first Linear layer -- confirm before enabling it.
        ego_fut_dec_in_dim = self.embed_dims + len(self.ego_lcf_feat_idx) \
            if self.ego_lcf_feat_idx is not None else self.embed_dims

        # The trajectory is modelled as a 2-D normal distribution per step:
        # mean (2 values), std (2 values, Softplus keeps them positive) and
        # correlation "rou" (1 value, Tanh keeps it in [-1, 1]).
        self.ego_fut_decoder = self._build_mlp(ego_fut_dec_in_dim, 2, self.mlp_layers)
        self.ego_fut_std_decoder = self._build_mlp(
            ego_fut_dec_in_dim, 2, self.mlp_layers, final_act=nn.Softplus())
        self.ego_fut_rou_decoder = self._build_mlp(
            ego_fut_dec_in_dim, 1, self.mlp_layers, final_act=nn.Tanh())

    @staticmethod
    def _build_mlp(in_dim, out_dim, num_hidden, final_act=None):
        """Stack ``num_hidden`` Linear+ReLU pairs, an output Linear and an optional activation."""
        layers = []
        for _ in range(num_hidden):
            layers.append(Linear(in_dim, in_dim))
            layers.append(nn.ReLU())
        layers.append(Linear(in_dim, out_dim))
        if final_act is not None:
            layers.append(final_act)
        return nn.Sequential(*layers)

    def init_weights(self):
        """Xavier-initialize every matrix-shaped parameter of the two attention modules."""
        for module in (self.way_decoder, self.auto_regression_attention):
            if module is None:
                continue
            for p in module.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    def _build_distribution(self, mu, sigma, rou, x_std_range, y_std_range):
        """Build a batched 2-D Gaussian from means, stds and correlations.

        Works for any number of leading dimensions, so the batch size, mode
        count and horizon are not baked into the covariance shape.

        Args:
            mu (Tensor): (..., 2) means.
            sigma (Tensor): (..., 2) standard deviations for x and y.
            rou (Tensor): (..., 1) correlation coefficients.
            x_std_range (tuple): (min, max) clamp for the x std.
            y_std_range (tuple): (min, max) clamp for the y std.

        Returns:
            MultivariateNormal: distribution with covariance of shape (..., 2, 2).
        """
        # Clamp out-of-place to keep the stds inside the configured bounds.
        sigma_x = torch.clamp(sigma[..., 0], min=x_std_range[0], max=x_std_range[1])
        sigma_y = torch.clamp(sigma[..., 1], min=y_std_range[0], max=y_std_range[1])
        rou = torch.clamp(rou[..., 0], min=-self._MAX_ABS_ROU, max=self._MAX_ABS_ROU)

        var_x = sigma_x ** 2 + self._VAR_EPS
        var_y = sigma_y ** 2 + self._VAR_EPS
        cov_xy = rou * sigma_x * sigma_y

        # Assemble [[var_x, cov_xy], [cov_xy, var_y]] along two new trailing axes.
        covariance_matrix = torch.stack([
            torch.stack([var_x, cov_xy], dim=-1),
            torch.stack([cov_xy, var_y], dim=-1),
        ], dim=-2).to(mu.dtype)

        return MultivariateNormal(mu, covariance_matrix)

    def get_distribution(self, mu, sigma, rou):
        """
        Distribution over all modes, using the std bounds from ``policy_config``.

        Args:
            mu: (B, M, T, 2), e.g. (1, 3, 6, 2)
            sigma: (B, M, T, 2)
            rou: (B, M, T, 1)
        Returns:
            distribution (MultivariateNormal): batched distribution whose
                covariance matrices have shape (B, M, T, 2, 2).
        """
        # TODO try dreamerV3's method https://github.com/NM512/dreamerv3-torch/blob/7433d1e87747ff574cdeffb5820f2a2647d212bd/networks.py#L693-L696
        # TODO experiment with different min/max sigma values
        cfg = self.policy_config
        return self._build_distribution(
            mu, sigma, rou,
            x_std_range=(cfg['x_min_std'], cfg['x_max_std']),
            y_std_range=(cfg['y_min_std'], cfg['y_max_std']))

    def get_single_model_distribution(self, mu, sigma, rou, enhance_exploration=False):
        """
        Distribution of a single mode, with std bounds widened for exploration.

        All four std bounds are multiplied by ``policy_config['explore_ratio']``
        and additionally by 2 when ``enhance_exploration`` is True.

        Args:
            mu: (B, T, 2), e.g. (1, 6, 2)
            sigma: (B, T, 2)
            rou: (B, T, 1)
            enhance_exploration (bool): double the std bounds once more.
        Returns:
            distribution (MultivariateNormal): batched distribution whose
                covariance matrices have shape (B, T, 2, 2).
        """
        assert mu.dim() == 3, "this method can only obtain distribution of a single mode"
        # TODO try dreamerV3's method https://github.com/NM512/dreamerv3-torch/blob/7433d1e87747ff574cdeffb5820f2a2647d212bd/networks.py#L693-L696
        # TODO experiment with different min/max sigma values
        cfg = self.policy_config
        enhance_ratio = 2.0 if enhance_exploration else 1.0
        explore_ratio = cfg['explore_ratio']
        x_min_std = cfg['x_min_std'] * explore_ratio * enhance_ratio
        x_max_std = cfg['x_max_std'] * explore_ratio * enhance_ratio
        y_min_std = cfg['y_min_std'] * explore_ratio * enhance_ratio
        y_max_std = cfg['y_max_std'] * explore_ratio * enhance_ratio
        return self._build_distribution(
            mu, sigma, rou,
            x_std_range=(x_min_std, x_max_std),
            y_std_range=(y_min_std, y_max_std))

    def _decode_way_points(self, latent_query, latent_pos):
        """Run way-point queries through the decoder and the masked temporal attention.

        Args:
            latent_query (Tensor): scene tokens with batch on dim 1, e.g. [16, B, 256].
            latent_pos (Tensor): positional encoding passed as ``key_pos``.

        Returns:
            Tensor: way-point features of shape (B, ego_fut_mode, fut_ts, C).
        """
        bs = latent_query.size(1)
        wp_pos, way_point = torch.split(
            self.way_point.weight.to(latent_query.dtype), self.embed_dims, dim=1)

        # (M*T, C) -> (M*T, B, C): the same learned queries are shared by the batch.
        wp_pos = wp_pos.unsqueeze(1).expand(-1, bs, -1)
        way_point = way_point.unsqueeze(1).expand(-1, bs, -1)

        way_point = self.way_decoder(
            query=way_point,
            key=latent_query,
            value=latent_query,
            query_pos=wp_pos,
            key_pos=latent_pos)

        # Put time on the sequence axis and fold modes into the batch axis so
        # the attention mixes steps within each mode independently.
        way_point = rearrange(way_point, '(d t) b c -> t (d b) c',
                              d=self.ego_fut_mode, t=self.fut_ts)
        # Use a local copy on the right device instead of mutating module state.
        causal_mask = self.causal_mask.to(way_point.device)
        way_point, _ = self.auto_regression_attention(
            query=way_point,
            key=way_point,
            value=way_point,
            attn_mask=causal_mask,
            need_weights=False)
        return rearrange(way_point, 't (d b) c -> b d t c', d=self.ego_fut_mode, b=bs)

    def forward(self, latent_query, latent_pos, cmd_idx, require_explore_policy=False):
        """
        Predict the trajectory distribution of every mode and the policy of one mode.

        Args:
            latent_query (Tensor): [16, B, 256]
            latent_pos (Tensor): [16, B, 256]
            cmd_idx: index of the mode used as the policy. The selected slice
                must be 3-D, so this is presumably a scalar index -- TODO
                confirm against the caller.
            require_explore_policy (bool): also return the policy built with
                ``enhance_exploration=True``.

        Returns:
            tuple: ``(trajectory_distribution, policy)`` or, when
            ``require_explore_policy`` is True,
            ``(trajectory_distribution, policy, explore_policy)``.
        """
        way_point = self._decode_way_points(latent_query, latent_pos)

        # Per-step Gaussian parameters for all modes: (B, M, T, 2/2/1).
        outputs_ego_trajs = self.ego_fut_decoder(way_point)
        outputs_ego_trajs_std = self.ego_fut_std_decoder(way_point)
        outputs_ego_trajs_rou = self.ego_fut_rou_decoder(way_point)

        trajectory_distribution = self.get_distribution(
            outputs_ego_trajs, outputs_ego_trajs_std, outputs_ego_trajs_rou)

        # Parameters of the commanded mode only: (B, T, 2/2/1).
        selected = dict(
            mu=outputs_ego_trajs[:, cmd_idx, ...],
            sigma=outputs_ego_trajs_std[:, cmd_idx, ...],
            rou=outputs_ego_trajs_rou[:, cmd_idx, ...])
        policy = self.get_single_model_distribution(**selected)

        if require_explore_policy:
            explore_policy = self.get_single_model_distribution(
                enhance_exploration=True, **selected)
            return trajectory_distribution, policy, explore_policy
        return trajectory_distribution, policy

    def planning(self, latent_query, latent_pos):
        """Return only the mean way-points of every mode, shape (B, M, T, 2).

        Args:
            latent_query (Tensor): scene tokens with batch on dim 1.
            latent_pos (Tensor): positional encoding passed as ``key_pos``.
        """
        way_point = self._decode_way_points(latent_query, latent_pos)
        return self.ego_fut_decoder(way_point)

@HEADS.register_module()
class SSRHead(DETRHead):
    """Head of SSR model.
    Args:
        with_box_refine (bool): Whether to refine the reference points
            in the decoder. Defaults to False.
        as_two_stage (bool) : Whether to generate the proposal from
            the outputs of encoder.
        transformer (obj:`ConfigDict`): ConfigDict is used for building
            the Encoder and Decoder.
        bev_h, bev_w (int): spatial shape of BEV queries.
    """
    def __init__(self,
                 *args,
                 with_box_refine=False,
                 as_two_stage=False,
                 transformer=None,
                 bbox_coder=None,
                 num_cls_fcs=2,
                 code_weights=None,
                 bev_h=30,
                 bev_w=30,
                 fut_ts=6,
                 fut_mode=6,
                 loss_traj=dict(type='L1Loss', loss_weight=0.25),
                 loss_traj_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=0.8),
                 map_bbox_coder=None,
                 map_num_query=900,
                 map_num_classes=3,
                 map_num_vec=20,
                 map_num_pts_per_vec=2,
                 map_num_pts_per_gt_vec=2,
                 map_query_embed_type='all_pts',
                 map_transform_method='minmax',
                 map_gt_shift_pts_pattern='v0',
                 map_dir_interval=1,
                 map_code_size=None,
                 map_code_weights=None,
                 loss_map_cls=dict(
                     type='CrossEntropyLoss',
                     bg_cls_weight=0.1,
                     use_sigmoid=False,
                     loss_weight=1.0,
                     class_weight=1.0),
                 loss_map_bbox=dict(type='L1Loss', loss_weight=5.0),
                 loss_map_iou=dict(type='GIoULoss', loss_weight=2.0),
                 loss_map_pts=dict(
                    type='ChamferDistance',loss_src_weight=1.0,loss_dst_weight=1.0
                 ),
                 loss_map_dir=dict(type='PtsDirCosLoss', loss_weight=2.0),
                 num_scenes=16,
                 latent_decoder=None,
                 way_decoder=None,
                 ego_fut_mode=3,
                 loss_plan_reg=dict(type='L1Loss', loss_weight=0.25),
                 critic_reward_bound=None,
                 critic_reward_col=None,
                 intrinsic_reward_bound=None,
                 intrinsic_reward_col=None,
                 ego_lcf_feat_idx=None,
                 valid_fut_ts=6,
                 reward_model_augmentation=False,
                 policy_config=None,
                 **kwargs):
        """Store the head configuration, build the parent head, then the actor and reward ops.

        Only the arguments with non-obvious handling are listed here.

        Args:
            transformer (dict): transformer config; ``as_two_stage`` is written
                into it when ``as_two_stage`` is truthy.
            bbox_coder (dict): config for ``build_bbox_coder``; the coder's
                ``pc_range`` defines ``real_w`` / ``real_h``.
            code_weights, map_code_weights (list[float] | None): default to
                eight 1.0 entries followed by two 0.2 entries.
            map_code_size (int | None): defaults to 10.
            map_num_query (int): size of the map query embedding created by
                ``_init_layers`` when ``map_query_embed_type == 'all_pts'``.
            latent_decoder, way_decoder (dict): transformer layer sequence
                configs for the latent decoder and for the actor.
            critic_reward_bound, critic_reward_col, intrinsic_reward_bound,
            intrinsic_reward_col (dict): keyword arguments for the
                corresponding constraint classes. They are unpacked with
                ``**``, so the ``None`` defaults are not usable and a mapping
                must be supplied.
            policy_config (dict): forwarded to ``Actor``.
            **kwargs: forwarded to the parent constructor; ``code_size`` is
                also read here (default 10).

        Several map/loss arguments (e.g. ``loss_traj``, ``loss_map_bbox``) are
        accepted but not referenced in this constructor.
        """
        # NOTE: everything below is assigned before ``super().__init__`` --
        # presumably because the parent constructor calls ``_init_layers``,
        # which reads these attributes (TODO confirm against DETRHead).
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.fp16_enabled = False
        self.fut_ts = fut_ts
        self.fut_mode = fut_mode

        self.latent_decoder = latent_decoder
        self.way_decoder = way_decoder

        self.ego_fut_mode = ego_fut_mode
        self.ego_lcf_feat_idx = ego_lcf_feat_idx
        self.valid_fut_ts = valid_fut_ts
        self.num_scenes = num_scenes

        # One logit with a sigmoid classifier, two otherwise.
        self.traj_num_cls = 1 if loss_traj_cls['use_sigmoid'] == True else 2  # noqa: E712

        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
            transformer['as_two_stage'] = self.as_two_stage

        self.code_size = kwargs.get('code_size', 10)
        self.code_weights = code_weights if code_weights is not None else [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
        self.map_code_size = map_code_size if map_code_size is not None else 10
        self.map_code_weights = map_code_weights if map_code_weights is not None else [
            1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]

        self.map_query_embed_type = map_query_embed_type
        # Fix: ``_init_layers`` reads ``self.map_num_query`` for the 'all_pts'
        # embedding, but the argument was never stored on the instance.
        self.map_num_query = map_num_query
        self.map_num_vec = map_num_vec
        self.map_num_pts_per_vec = map_num_pts_per_vec

        # A non-sigmoid classifier gets one extra output channel.
        if loss_map_cls['use_sigmoid'] == True:  # noqa: E712
            self.map_cls_out_channels = map_num_classes
        else:
            self.map_cls_out_channels = map_num_classes + 1

        super(SSRHead, self).__init__(*args, transformer=transformer, **kwargs)

        self.policy_config = policy_config

        self.actor = Actor(self.ego_fut_mode, self.fut_ts, self.embed_dims, self.way_decoder,
                           self.ego_lcf_feat_idx, self.num_reg_fcs, self.policy_config)

        # Stored as frozen parameters (no gradient).
        self.code_weights = nn.Parameter(torch.tensor(
            self.code_weights, requires_grad=False), requires_grad=False)
        self.map_code_weights = nn.Parameter(torch.tensor(
            self.map_code_weights, requires_grad=False), requires_grad=False)

        self.loss_plan_reg = build_loss(loss_plan_reg)

        # Reward terms used by the critic ...
        self.critic_reward_imitation_op = CriticImitationConstrain()
        self.critic_reward_bound_op = CriticMapBoundConstrain(**critic_reward_bound)
        self.critic_reward_col_op = CriticCollisionConstrain(**critic_reward_col)

        # ... and their intrinsic counterparts.
        self.intrinsic_reward_bound_op = IntrinsicMapBoundConstrain(**intrinsic_reward_bound)
        self.intrinsic_reward_col_op = IntrinsicCollisionConstrain(**intrinsic_reward_col)

        # Only the flag is stored here; no trajectory-perturbation table is
        # built in this constructor.
        self.reward_model_augmentation = reward_model_augmentation


    def _init_layers(self):
        """Initialize classification branch and regression branch of head.

        Also creates the BEV/query embeddings (single-stage only), the ego and
        navigation embeddings, the token learner, the latent decoder and the
        two MLN modules. The way-point decoding modules (masked attention and
        the mean/std/rou MLPs) are owned by ``Actor``, not by this method.
        """
        dims = self.embed_dims
        depth = self.num_reg_fcs

        def _norm_mlp(width, out_channels):
            # ``depth`` x (Linear -> LayerNorm -> ReLU) followed by an output Linear.
            layers = []
            for _ in range(depth):
                layers += [Linear(width, width), nn.LayerNorm(width), nn.ReLU(inplace=True)]
            layers.append(Linear(width, out_channels))
            return nn.Sequential(*layers)

        def _relu_mlp(width, out_channels):
            # ``depth`` x (Linear -> ReLU) followed by an output Linear.
            layers = []
            for _ in range(depth):
                layers += [Linear(width, width), nn.ReLU()]
            layers.append(Linear(width, out_channels))
            return nn.Sequential(*layers)

        # Prototype branches; creation order is kept so that weight
        # initialization consumes the RNG in the same sequence.
        cls_branch = _norm_mlp(dims, self.cls_out_channels)
        reg_branch = _relu_mlp(dims, self.code_size)
        traj_branch = _relu_mlp(dims * 2, self.fut_ts * 2)
        traj_cls_branch = _norm_mlp(dims * 2, self.traj_num_cls)
        map_cls_branch = _norm_mlp(dims, self.map_cls_out_channels)
        map_reg_branch = _relu_mlp(dims, self.map_code_size)

        # last reg_branch is used to generate proposal from
        # encode feature map when as_two_stage is True.
        decoder = self.transformer.decoder
        map_decoder = self.transformer.map_decoder
        num_decoder_layers = decoder.num_layers if decoder is not None else 1
        num_map_decoder_layers = map_decoder.num_layers if map_decoder is not None else 1
        num_motion_decoder_layers = 1

        if self.as_two_stage:
            num_pred = num_decoder_layers + 1
            motion_num_pred = num_motion_decoder_layers + 1
            map_num_pred = num_map_decoder_layers + 1
        else:
            num_pred = num_decoder_layers
            motion_num_pred = num_motion_decoder_layers
            map_num_pred = num_map_decoder_layers

        if self.with_box_refine:
            # Independent weights per prediction layer.
            def _replicate(module, count):
                return nn.ModuleList([copy.deepcopy(module) for _ in range(count)])
        else:
            # The same module instance is shared by every prediction layer.
            def _replicate(module, count):
                return nn.ModuleList([module for _ in range(count)])

        self.cls_branches = _replicate(cls_branch, num_pred)
        self.reg_branches = _replicate(reg_branch, num_pred)
        self.traj_branches = _replicate(traj_branch, motion_num_pred)
        self.traj_cls_branches = _replicate(traj_cls_branch, motion_num_pred)
        self.map_cls_branches = _replicate(map_cls_branch, map_num_pred)
        self.map_reg_branches = _replicate(map_reg_branch, map_num_pred)

        if not self.as_two_stage:
            self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, dims)
            self.query_embedding = nn.Embedding(self.num_query, dims * 2)
            embed_type = self.map_query_embed_type
            if embed_type == 'all_pts':
                self.map_query_embedding = nn.Embedding(self.map_num_query, dims * 2)
            elif embed_type == 'instance_pts':
                self.map_query_embedding = None
                self.map_instance_embedding = nn.Embedding(self.map_num_vec, dims * 2)
                self.map_pts_embedding = nn.Embedding(self.map_num_pts_per_vec, dims * 2)

        self.ego_query = nn.Embedding(1, dims)

        # Navigation-command embedding (3 entries) and its gating layer.
        self.navi_embedding = nn.Embedding(3, dims)
        self.navi_se = SELayer(dims)

        self.tokenlearner = TokenLearnerV11(self.num_scenes, dims * 2)

        # ``self.latent_decoder`` holds the config until this point and the
        # built module afterwards.
        self.latent_decoder = build_transformer_layer_sequence(self.latent_decoder)

        # Both MLNs take a latent code of width ``fut_ts * 2``.
        self.action_mln = MLN(self.fut_ts * 2)
        self.pos_mln = MLN(self.fut_ts * 2)

    def init_weights(self):
        """Initialise the BEV transformer, the latent decoder and the actor.

        ``self.transformer`` and ``self.actor`` run their own
        ``init_weights``. In between, every latent-decoder parameter with
        more than one dimension is re-drawn with Xavier-uniform; parameters
        with at most one dimension are left as they are. Nothing is done for
        the latent decoder when it is ``None``.
        """
        self.transformer.init_weights()

        decoder = self.latent_decoder
        if decoder is not None:
            # Lazy filter: parameters are visited (and re-drawn) in the same
            # order ``parameters()`` yields them.
            multi_dim_params = (w for w in decoder.parameters() if w.dim() > 1)
            for weight in multi_dim_params:
                nn.init.xavier_uniform_(weight)

        self.actor.init_weights()

    # @auto_fp16(apply_to=('mlvl_feats'))
    @force_fp32(apply_to=('mlvl_feats', 'prev_bev'))
    def forward(self,
                mlvl_feats,
                img_metas,
                prev_bev=None,
                only_bev=False,
                also_latent_state=False,
                ego_his_trajs=None,
                ego_lcf_feat=None,
                cmd=None,
            ):
        """Encode the scene, plan the ego trajectory and build action queries.

        Pipeline: BEV encoding -> navigation-conditioned BEV features ->
        token learner -> latent decoder -> actor (trajectory distribution) ->
        MLN modulation of the scene tokens by the sampled trajectory.

        Args:
            mlvl_feats (tuple[Tensor]): Features from the upstream network;
                the first level must be a 5D-tensor (B, N, C, H, W).
            img_metas: Forwarded to ``self.transformer.get_bev_features``.
            prev_bev: Previous BEV features, forwarded to the BEV encoder.
            only_bev (bool): Stop early. With ``also_latent_state=False`` only
                the BEV features are computed.
            also_latent_state (bool): Together with ``only_bev=True``, also
                run the token learner and latent decoder before returning.
            ego_his_trajs: Unused by this method.
            ego_lcf_feat: Unused by this method.
            cmd (Tensor): Navigation command. Only ``cmd[0, 0, 0]`` is read
                and its first non-zero index selects both the navigation
                embedding and the planning mode, so one command is applied to
                the whole batch. Required unless only BEV features are
                requested (presumably a one-hot vector over 3 commands —
                TODO confirm against the caller).

        Returns:
            Tensor: ``bev_embed`` when ``only_bev and not also_latent_state``.
            tuple: ``(bev_embed, latent_pos, latent_query)`` when
                ``only_bev and also_latent_state``.
            dict: otherwise, with keys ``selected``, ``bev_embed``,
                ``scene_query``, ``scene_pos``, ``act_query``, ``act_pos``,
                ``ego_fut_preds`` (distribution mean), ``ego_fut_preds_sample``
                (reparameterised sample for the selected command),
                ``trajectory_distribution`` and ``policy``. When
                ``self.reward_model_augmentation`` is set the dict also holds
                ``explore_policy``, ``explore_trajectory``,
                ``act_query_reward_aug`` and ``act_pos_reward_aug``.
        """
        # Unpacking five sizes keeps the implicit check that the first feature
        # level is 5-D; only the batch size is used below.
        bs, _, _, _, _ = mlvl_feats[0].shape
        dtype = mlvl_feats[0].dtype

        bev_queries = self.bev_embedding.weight.to(dtype)

        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device).to(dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)

        bev_embed = self.transformer.get_bev_features(
                mlvl_feats,
                bev_queries,
                self.bev_h,
                self.bev_w,
                grid_length=(self.real_h / self.bev_h,
                             self.real_w / self.bev_w),
                bev_pos=bev_pos,
                img_metas=img_metas,
                prev_bev=prev_bev,
            )
        if only_bev and not also_latent_state:
            return bev_embed

        pos_embd = bev_pos.flatten(2).permute(0, 2, 1)
        # Only the first command entry is read; the index of its first
        # non-zero element picks the navigation embedding / planning mode.
        cmd = cmd[0, 0, 0]
        cmd_idx = torch.nonzero(cmd)[0, 0]

        navi_embed = self.navi_embedding.weight[cmd_idx][None, None]
        bev_navi_embed = self.navi_se(bev_embed, navi_embed)

        # Content and positional halves are concatenated so the token learner
        # pools them together; they are split apart again right after.
        bev_query = torch.cat((bev_navi_embed, pos_embd), -1)

        learned_latent_query, selected = self.tokenlearner(bev_query)

        learned_latent_query = learned_latent_query.permute(1, 0, 2)
        latent_query, latent_pos = torch.split(
            learned_latent_query, self.embed_dims, dim=2)

        latent_query = self.latent_decoder(
                query=latent_query,
                key=latent_query,
                value=latent_query,
                query_pos=latent_pos,
                key_pos=latent_pos
        )

        if only_bev and also_latent_state:
            return bev_embed, latent_pos, latent_query

        if self.reward_model_augmentation:
            trajectory_distribution, policy, explore_policy = self.actor(
                latent_query, latent_pos, cmd_idx, require_explore_policy=True)
        else:
            trajectory_distribution, policy = self.actor(latent_query, latent_pos, cmd_idx)

        outputs_ego_trajs = trajectory_distribution.mean

        # Reparameterised sample, so gradients can flow through the waypoints
        # that condition the action queries below.
        # Author's shape notes: (1, 3, 6, 2) -> (1, 6, 2) after mode selection.
        outputs_ego_trajs_sample = trajectory_distribution.rsample()
        outputs_ego_trajs_sample_fut = outputs_ego_trajs_sample[:, cmd_idx, ...]

        # The sampled trajectory (not the mean) conditions the scene tokens.
        # reshape(-1) also folds the batch axis into the waypoint vector, so
        # this path effectively assumes a batch size of 1.
        wp_vector = outputs_ego_trajs_sample_fut.reshape(-1)
        wp_vector = wp_vector.unsqueeze(0).unsqueeze(0)

        # Author's shape notes: latent_query / latent_pos (16, 1, 256),
        # wp_vector (1, 1, 6*2).
        act_query = self.action_mln(latent_query, wp_vector)
        act_pos = self.pos_mln(latent_pos, wp_vector)

        outs = {
            'selected': selected,
            'bev_embed': bev_embed,
            'scene_query': latent_query,
            'scene_pos': latent_pos,
            'act_query': act_query,
            'act_pos': act_pos,
            'ego_fut_preds': outputs_ego_trajs,
            'ego_fut_preds_sample': outputs_ego_trajs_sample_fut,
            'trajectory_distribution': trajectory_distribution,
            'policy': policy
        }

        if self.reward_model_augmentation:
            # Number of exploration trajectories drawn per sample.
            n_explore = 24
            with torch.no_grad():
                B = latent_query.size(1)

                explore_trajectory = explore_policy.sample((n_explore,))  # [n, B, T, 2]
                explore_trajectory = rearrange(
                    explore_trajectory, 'n b t xy -> b n t xy', n=n_explore, b=B)

                outs['explore_policy'] = explore_policy
                outs['explore_trajectory'] = explore_trajectory

                # [16, B, D] => [16, B, n, D] => [16, B*n, D]; same (b n)
                # ordering as the flattened trajectories below.
                latent_query_expanded = latent_query.unsqueeze(2).repeat(
                    1, 1, n_explore, 1).flatten(1, 2)

                # The horizon is tied to self.fut_ts (the MLNs are built with
                # fut_ts * 2 conditioning features) instead of a literal 6.
                explore_trajectory = rearrange(
                    explore_trajectory, 'b n t xy -> 1 (b n) (t xy)',
                    b=B, t=self.fut_ts)  # [1, B*n, T*2]

                act_query_reward_aug = self.action_mln(
                    latent_query_expanded, explore_trajectory)  # [16, B*n, D]
                outs['act_query_reward_aug'] = act_query_reward_aug

                latent_pos_expanded = latent_pos.unsqueeze(2).repeat(
                    1, 1, n_explore, 1).flatten(1, 2)  # [16, B*n, D]
                act_pos_reward_aug = self.pos_mln(latent_pos_expanded, explore_trajectory)
                outs['act_pos_reward_aug'] = act_pos_reward_aug

        return outs

    def get_exploration_data(self, latent_query, latent_pos, policy, n_sample=16):
        """Sample exploration trajectories and build their action queries.

        Args:
            latent_query (Tensor): scene tokens, [16, B, 256] per the
                original notes.
            latent_pos (Tensor): positional part of the scene tokens, same
                layout as ``latent_query``.
            policy: distribution whose ``sample((N,))`` returns [N, B, T, 2].
            n_sample (int): number of trajectories N drawn per batch element.

        Returns:
            explore_trajs (Tensor): [B, N, T, 2]
            explore_act_query (Tensor): [16, B*N, 256]
            explore_act_pos (Tensor): [16, B*N, 256]
        """
        explore_trajs = policy.sample((n_sample,))  # [N, B, T, 2]
        n, b, t, xy = explore_trajs.shape

        # [N, B, T, 2] -> [B, N, T, 2]; the flattened view keeps (b n) order.
        explore_trajs = explore_trajs.transpose(0, 1).contiguous()
        wp_vector = explore_trajs.reshape(1, b * n, t * xy)  # [1, B*N, T*2]

        # For B > 1 the per-sample tokens must be tiled N times so that token
        # column b*N + i lines up with trajectory i of sample b (the same
        # layout forward() uses for reward augmentation). For B == 1 the
        # tokens are passed through unchanged, as before.
        if b > 1 and latent_query.size(1) == b:
            latent_query = latent_query.repeat_interleave(n, dim=1)
        if b > 1 and latent_pos.size(1) == b:
            latent_pos = latent_pos.repeat_interleave(n, dim=1)

        explore_act_query = self.action_mln(latent_query, wp_vector)
        explore_act_pos = self.pos_mln(latent_pos, wp_vector)

        return explore_trajs, explore_act_query, explore_act_pos

    def get_bev_selected(self, bev_embed, ego_fut_cmd_idx):
        """Return the token learner's selection output for given BEV features.

        Rebuilds the token-learner input (navigation-conditioned BEV features
        concatenated with the flattened BEV positional encoding) and returns
        only the second output of ``self.tokenlearner``.

        Args:
            bev_embed (Tensor): BEV features; the batch size is read from
                dim 0 and the working dtype from the tensor itself.
            ego_fut_cmd_idx: index into ``self.navi_embedding.weight`` that
                picks the navigation command embedding.

        Returns:
            The ``selected`` output of ``self.tokenlearner``.
        """
        batch_size = bev_embed.size(0)
        work_dtype = bev_embed.dtype
        target_device = self.bev_embedding.weight.device

        # An all-zero mask is what the positional encoding is evaluated on.
        empty_mask = torch.zeros(
            (batch_size, self.bev_h, self.bev_w),
            dtype=work_dtype,
            device=target_device,
        )
        pos_encoding = self.positional_encoding(empty_mask).to(work_dtype)
        # Flatten the spatial dims and move them in front of the channels.
        pos_tokens = pos_encoding.flatten(2).transpose(1, 2)

        command_embed = self.navi_embedding.weight[ego_fut_cmd_idx]
        command_embed = command_embed.unsqueeze(0).unsqueeze(0)
        conditioned_bev = self.navi_se(bev_embed, command_embed)

        learner_input = torch.cat([conditioned_bev, pos_tokens], dim=-1)
        _, selected = self.tokenlearner(learner_input)
        return selected


    @force_fp32(apply_to=('preds_dicts'))
    def loss(self,
             gt_bboxes_list,  # gt_bboxes_3d
             gt_labels_list,  # gt_labels_3d
             map_gt_bboxes_list,  # map_gt_bboxes_3d
             map_gt_labels_list,  # map_gt_labels_3d
             preds_dicts,  # outs
             ego_fut_gt,  # ego_fut_trajs
             ego_fut_masks,  # ego_fut_masks
             ego_fut_cmd,  # ego_fut_cmd
             gt_attr_labels,  # gt_attr_labels
             gt_bboxes_ignore=None,
             map_gt_bboxes_ignore=None,
             img_metas=None):
        """Compute the planning losses and the ground-truth rewards.

        Side effects: stores ``self.map_gt``, ``self.map_type_gt``,
        ``self.agent_cur_position``, ``self.agent_type``,
        ``self.agent_fut_position`` and ``self.agent_fut_mask`` for the reward
        operators, and, when reward augmentation is enabled and every future
        step is valid, overwrites ``explore_trajectory``,
        ``act_query_reward_aug`` and ``act_pos_reward_aug`` in ``preds_dicts``
        with versions that have the ground-truth trajectory appended.

        Args:
            gt_bboxes_list (list): per-sample boxes, forwarded to
                ``obtain_curent_traffic_object_information``.
            gt_labels_list (list): per-sample box labels, forwarded to the
                same helper.
            map_gt_bboxes_list (list): per-sample map instances, forwarded to
                ``obtain_map_information``.
            map_gt_labels_list (list): per-sample map labels, forwarded to the
                same helper.
            preds_dicts (dict): head outputs. Keys read here:
                ``ego_fut_preds``, ``trajectory_distribution``,
                ``ego_fut_preds_sample`` and, with reward augmentation,
                ``explore_trajectory``, ``scene_query``, ``scene_pos``,
                ``act_query_reward_aug``, ``act_pos_reward_aug``.
            ego_fut_gt (Tensor): ground-truth ego trajectory; dim 1 is
                squeezed away, giving (B, T, 2) per the original notes.
            ego_fut_masks (Tensor): validity of each future step; dim 1 is
                squeezed twice, giving (B, T) per the original notes.
            ego_fut_cmd (Tensor): driving command; dim 1 is squeezed twice
                (presumably one-hot over modes, (B, mode) — TODO confirm).
            gt_attr_labels (list[Tensor]): per-agent attributes, forwarded to
                ``obtain_future_traffic_object_information`` and
                ``obtain_intrinsic_reward_gt``.
            gt_bboxes_ignore: Unused.
            map_gt_bboxes_ignore: Unused.
            img_metas: Unused.

        Returns:
            tuple[dict, dict]: ``(loss_dict, reward_dict)``. ``loss_dict``
            holds ``loss_plan_reg`` and ``loss_imitation_nll``.
            ``reward_dict`` holds ``imitation_reward``, ``critic_reward`` and
            ``intrinsic_reward``, plus ``imitation_reward_augment_gt`` and
            ``critic_reward_augment`` when reward augmentation is enabled.
        """

        ego_fut_preds = preds_dicts['ego_fut_preds']
        trajectory_distribution = preds_dicts['trajectory_distribution']
        loss_dict = dict()
        reward_dict = dict()

        # Planning Loss
        ego_fut_gt = ego_fut_gt.squeeze(1) # (B, T, 2)
        ego_fut_masks = ego_fut_masks.squeeze(1).squeeze(1) #(B, T)
        ego_fut_cmd = ego_fut_cmd.squeeze(1).squeeze(1)

        # Tile the ground truth over the planning modes so it lines up with
        # the per-mode predictions.
        ego_fut_gt = ego_fut_gt.unsqueeze(1).repeat(1, self.ego_fut_mode, 1, 1) # (B, mode, T, 2)
        
        # Regression weight: command entry times step validity, duplicated for
        # the x and y coordinates.
        loss_plan_l1_weight = ego_fut_cmd[..., None, None] * ego_fut_masks[:, None, :, None]
        loss_plan_l1_weight = loss_plan_l1_weight.repeat(1, 1, 1, 2)

        # Same weighting without the coordinate axis; it matches the
        # (B, mode, T) log-probabilities of the trajectory distribution
        # (shape per the original notes).
        valid_mask = ego_fut_cmd[..., None] * ego_fut_masks[:, None, :]

        loss_plan_l1 = self.loss_plan_reg(
            ego_fut_preds,
            ego_fut_gt,
            loss_plan_l1_weight
        )
        loss_dict['loss_plan_reg'] = loss_plan_l1

        # Number of valid (mode, step) entries per sample; clamped so a fully
        # masked sample does not divide by zero.
        valid_count = valid_mask.sum(dim=(1, 2)).clamp(min=1)  # shape: [B]

        # Negative log-likelihood of the ground truth under the predicted
        # distribution, masked to the valid entries.
        imitation_nll = -trajectory_distribution.log_prob(ego_fut_gt.detach()) # [B, 3, T]
        imitation_nll = imitation_nll * valid_mask
        # Sum over the mode and time axes: (B, mode, T) => (B,)
        imitation_nll = imitation_nll.sum(dim=(1,2))

        # Average over the valid entries of each sample, then over the batch.
        imitation_nll = imitation_nll / valid_count
        imitation_nll = imitation_nll.mean()

        loss_dict['loss_imitation_nll'] = 0.5 * imitation_nll

        # Ground-truth scene information for the reward operators.
        device = ego_fut_preds.device
        # Original note: map_classes=['divider','ped_crossing','boundary'].
        # The scene information is stored on the instance because the reward
        # methods below read it from ``self`` instead of taking arguments.
        self.map_gt, self.map_type_gt = obtain_map_information(map_gt_bboxes_list, map_gt_labels_list, num_classes=3, device=device)
        self.agent_cur_position, self.agent_type = obtain_curent_traffic_object_information(gt_bboxes_list, gt_labels_list, device=device) # (B, N, 2), (B, N)
        # NOTE(review): fut_ts is hard-coded to 6 here — confirm it matches the
        # agent-future layout of gt_attr_labels.
        self.agent_fut_position, self.agent_fut_mask = obtain_future_traffic_object_information(gt_attr_labels, fut_ts=6, device=device) # (B, N, T, 2), (B, N, T)

        # Sampled ego trajectory for the active command; presumably (B, T, 2)
        # — TODO confirm against forward().
        ego_fut_preds_sample = preds_dicts['ego_fut_preds_sample']

        # Rewards are evaluated on the sampled trajectory, not on the mean.
        imitation_reward_gt = self.obtain_imitation_reward_gt(
            ego_fut_preds_sample,
            ego_fut_gt,
            ego_fut_masks,
            ego_fut_cmd,            
        )

        # Map and agent ground truth is read from the attributes set above.
        critic_reward_gt = self.obtain_critic_reward_gt(
            ego_fut_preds_sample,
            ego_fut_gt,
            ego_fut_masks,
            ego_fut_cmd,
        )
        
        intrinsic_reward_gt = self.obtain_intrinsic_reward_gt(
            gt_attr_labels,
        )
        
        reward_dict['imitation_reward'] = imitation_reward_gt # (B,)
        reward_dict['critic_reward'] = critic_reward_gt # (B,)
        reward_dict['intrinsic_reward'] = intrinsic_reward_gt # (B,)

        if self.reward_model_augmentation:
            with torch.no_grad():
                explore_trajectory = preds_dicts['explore_trajectory'].detach() # [B, 24, T, 2]
                # When every future step is valid, the ground-truth trajectory
                # is appended as one more candidate so the reward model also
                # sees the expert trajectory.
                if ego_fut_masks.all():
                    explore_trajectory = torch.cat([explore_trajectory, ego_fut_gt[ego_fut_cmd==1].unsqueeze(1)], dim=1) # [B, 24, T, 2] => [B, 25, T, 2]
                    preds_dicts["explore_trajectory"] = explore_trajectory
                    latent_query = preds_dicts['scene_query'].detach() # [16, B, D]
                    latent_pos = preds_dicts['scene_pos'].detach() # [16, B, D]

                    # flatten() also folds the batch axis, so this assumes a
                    # single sample ([1, 1, 12] per the original notes).
                    gt_wp_vector = ego_fut_gt[ego_fut_cmd==1].flatten().unsqueeze(0).unsqueeze(0) # [1, 1, 12]

                    act_query_gt_reward_aug = self.action_mln(latent_query, gt_wp_vector) # [16, B, D]
                    act_pos_gt_reward_aug = self.pos_mln(latent_pos, gt_wp_vector) # [16, B, D]
                    # [16, B*24, D] => [16, B*25, D]
                    preds_dicts["act_query_reward_aug"] = torch.cat([preds_dicts["act_query_reward_aug"], act_query_gt_reward_aug], dim=1)
                    preds_dicts["act_pos_reward_aug"] = torch.cat([preds_dicts["act_pos_reward_aug"], act_pos_gt_reward_aug], dim=1)

                imitation_reward_augment_gt = self.obtain_imitation_reward_gt(
                    explore_trajectory, # [B, 24 or 25, T, 2]
                    ego_fut_gt,                 # [B, mode, T, 2] (tiled above)
                    ego_fut_masks,              # [B, T]
                    ego_fut_cmd,
                    reward_augment=True
                )
                critic_reward_augment_gt = self.obtain_critic_reward_gt(
                    explore_trajectory, # [B, 24 or 25, T, 2]
                    ego_fut_gt,                 # [B, mode, T, 2] (tiled above)
                    ego_fut_masks,              # [B, T]
                    ego_fut_cmd,                # [B, 3]
                    reward_augment=True
                )

                reward_dict['imitation_reward_augment_gt'] = imitation_reward_augment_gt
                reward_dict['critic_reward_augment'] = critic_reward_augment_gt

        return loss_dict, reward_dict

    def obtain_imitation_reward_gt(
            self,
            ego_fut_preds,
            ego_fut_gt,
            ego_fut_masks,
            ego_fut_cmd,
            reward_augment=False
        ):
        """Score predicted ego trajectories against the ground truth.

        Args:
            ego_fut_preds (Tensor): predicted trajectories. Without
                augmentation a 4-D input is reduced to the entries where
                ``ego_fut_cmd == 1``; with augmentation it must be
                ``[B, n_p, T, 2]`` and the first two axes are merged.
            ego_fut_gt (Tensor): ground-truth trajectory; a 4-D input is
                reduced to the entries where ``ego_fut_cmd == 1``.
            ego_fut_masks (Tensor): validity mask, [B, T] per the original
                notes; tiled ``n_p`` times along dim 0 in augmentation mode.
            ego_fut_cmd (Tensor): command mask used for the selections above.
            reward_augment (bool): treat dim 1 of ``ego_fut_preds`` as a set
                of candidate trajectories for a single sample.

        Returns:
            The output of ``self.critic_reward_imitation_op``.
        """
        if ego_fut_gt.dim() == 4:
            ego_fut_gt = ego_fut_gt[ego_fut_cmd == 1]

        if reward_augment:
            assert ego_fut_gt.size(0) == 1, "for critic reward, only batch_size = 1 is supported"
            batch = ego_fut_preds.size(0)
            n_candidates = ego_fut_preds.size(1)
            # Fold the candidate trajectories into the batch axis and tile
            # the mask so that every candidate gets its own copy.
            ego_fut_preds = rearrange(ego_fut_preds, 'b n_p t xy -> (b n_p) t xy', b=batch)
            ego_fut_masks = ego_fut_masks.repeat(n_candidates, 1)  # [B, T] => [n_p, T]
        elif ego_fut_preds.dim() == 4:
            # Predictions still carry the mode axis: keep the active mode.
            ego_fut_preds = ego_fut_preds[ego_fut_cmd == 1]

        return self.critic_reward_imitation_op(ego_fut_preds, ego_fut_gt, ego_fut_masks)

    def obtain_critic_reward_gt(
            self,
            ego_fut_preds,
            ego_fut_gt,
            ego_fut_masks,
            ego_fut_cmd,
            reward_augment=False
        ):
        """Compute the constraint (map boundary + collision) reward.

        The map and agent ground truth is not passed in: it is read from
        ``self.map_gt``, ``self.map_type_gt``, ``self.agent_cur_position``,
        ``self.agent_type``, ``self.agent_fut_position`` and
        ``self.agent_fut_mask``, which ``loss`` sets before calling this.

        Args:
            ego_fut_preds (Tensor): predicted trajectories. Without
                augmentation a 4-D ``[B, ego_fut_mode, fut_ts, 2]`` input is
                reduced to the entries where ``ego_fut_cmd == 1``; with
                augmentation it must be ``[B, n_p, fut_ts, 2]`` and the first
                two axes are merged.
            ego_fut_gt (Tensor): ground-truth trajectory. Only used to check
                the batch size in augmentation mode; a 4-D input is first
                reduced with ``ego_fut_cmd == 1``.
            ego_fut_masks (Tensor): ``[B, fut_ts]``. Not used by the
                constraint rewards; kept for signature compatibility.
            ego_fut_cmd (Tensor): ``[B, ego_fut_mode]`` command mask.
            reward_augment (bool): treat dim 1 of ``ego_fut_preds`` as a set
                of candidate trajectories for a single sample.

        Returns:
            ``0.5 * critic_reward_bound + 0.5 * critic_reward_col``.

        Raises:
            AssertionError: in augmentation mode when the ground truth does
                not have batch size 1.
            ValueError: in augmentation mode when ``ego_fut_preds`` is not
                4-D.
        """
        if ego_fut_gt.dim() == 4:
            ego_fut_gt = ego_fut_gt[ego_fut_cmd == 1]

        if reward_augment:
            assert ego_fut_gt.size(0) == 1, "for critic reward, only batch_size = 1 is supported"
            if ego_fut_preds.dim() != 4:
                raise ValueError(
                    "reward_augment expects ego_fut_preds of shape [B, n_p, T, 2], "
                    f"got {tuple(ego_fut_preds.shape)}")
            # Fold the candidate trajectories into the batch axis:
            # [B, n_p, T, 2] => [B * n_p, T, 2]. The mask is not tiled here
            # because neither reward operator below consumes it.
            ego_fut_preds = ego_fut_preds.flatten(0, 1)
        elif ego_fut_preds.dim() == 4:
            # Predictions still carry the mode axis: keep the active mode.
            ego_fut_preds = ego_fut_preds[ego_fut_cmd == 1]

        critic_reward_bound = self.critic_reward_bound_op(
            ego_fut_preds,
            self.map_gt,
            self.map_type_gt,
            reward_augment=reward_augment
        )

        critic_reward_col = self.critic_reward_col_op(
            ego_fut_preds,
            self.agent_cur_position,
            self.agent_type,
            self.agent_fut_position,
            self.agent_fut_mask
        )

        # Both constraints contribute equally.
        constrain_reward = 0.5 * critic_reward_bound + 0.5 * critic_reward_col

        return constrain_reward

    def obtain_intrinsic_reward_gt(
            self,
            gt_attr_labels,
        ):
        """Compute the intrinsic reward from the stored scene ground truth.

        The reward is the mean of a collision term, computed by
        ``self.intrinsic_reward_col_op`` from ``self.agent_cur_position`` and
        ``self.agent_type``, and a boundary term, computed by
        ``self.intrinsic_reward_bound_op`` from ``self.map_gt`` and
        ``self.map_type_gt``. Those attributes must be set beforehand
        (``loss`` does so).

        Args:
            gt_attr_labels: Not used by the computation; kept so existing
                callers keep working.

        Returns:
            ``(intrinsic_reward_col + intrinsic_reward_bound) / 2``.
        """
        intrinsic_reward_col = self.intrinsic_reward_col_op(self.agent_cur_position, self.agent_type)
        intrinsic_reward_bound = self.intrinsic_reward_bound_op(self.map_gt, self.map_type_gt)

        return (intrinsic_reward_col + intrinsic_reward_bound) / 2


def obtain_map_information(map_gt_bboxes_list, map_gt_labels_list, num_classes, device):
    '''
    Pack the per-sample map polylines and their labels into dense tensors.

    Args:
        map_gt_bboxes_list: one map-instance object per sample; its
            ``fixed_num_sampled_points`` is read as [n_lines, n_points, 2].
        map_gt_labels_list: one label tensor [n_lines] per sample.
        num_classes: not used by this function.
        device: device of the returned tensors.

    Returns:
        map_gt      (batch_size, n_lines, n_points, 2 [x,y]), float32
        map_type_gt (batch_size, n_lines), the raw label indices stored as
                    floats (no one-hot encoding is applied)

    Only batch_size = 1 is supported; anything else trips the assertion.
    '''
    sampled_pts = [inst.fixed_num_sampled_points for inst in map_gt_bboxes_list]

    B = len(sampled_pts)
    assert B == 1, "obtain_map_information now can only support batch_size = 1"

    n_lines = max(pts.shape[0] for pts in sampled_pts)
    n_points = sampled_pts[0].shape[1]

    map_gt = torch.zeros((B, n_lines, n_points, 2), dtype=torch.float32, device=device)
    map_type_gt = torch.zeros((B, n_lines), device=device)

    for idx, pts in enumerate(sampled_pts):
        line_labels = map_gt_labels_list[idx]  # [n_lines of this sample]
        count = pts.shape[0]
        # Fill the leading slots; anything beyond ``count`` stays zero.
        map_gt[idx, :count] = pts
        map_type_gt[idx, :count] = line_labels

    return map_gt, map_type_gt

def obtain_curent_traffic_object_information(gt_bboxes_list, gt_labels_list, device):
    '''
    Pack the current agent positions and class labels into dense tensors.

    Samples with fewer agents than the largest sample in the batch are
    zero-padded at the end of the agent axis.

    Args:
        gt_bboxes_list: one box container per sample; ``.tensor[:, 0:2]`` is
            read as the (x, y) position of every agent.
        gt_labels_list: one label tensor [num_agent_b] per sample (original
            note: indices into ['car', 'truck', 'construction_vehicle',
            'bus', 'trailer', 'barrier', 'motorcycle', 'bicycle',
            'pedestrian', 'traffic_cone']).
        device: device of the returned tensors.

    Returns:
        agent_cur_position (batch_size, num_agent, 2 [x, y]), float32
        agent_type         (batch_size, num_agent), raw label indices as
                           floats. NOTE: padded slots hold 0, which cannot be
                           told apart from a real label of 0.
    '''
    B = len(gt_bboxes_list)
    # ``default`` keeps an empty batch from raising inside max().
    num_agent = max((bboxes.tensor.shape[0] for bboxes in gt_bboxes_list), default=0)

    agent_cur_position = torch.zeros((B, num_agent, 2), dtype=torch.float32, device=device)
    agent_type = torch.zeros((B, num_agent), device=device)

    for b, bboxes in enumerate(gt_bboxes_list):
        agent_pos = bboxes.tensor[:, 0:2]
        labels = gt_labels_list[b]  # [num_agent_b]
        n = agent_pos.shape[0]

        # Write only this sample's own agents. Assigning to the whole row
        # would broadcast a single agent over every slot or fail on a size
        # mismatch when the batch is ragged.
        agent_cur_position[b, :n] = agent_pos
        agent_type[b, :n] = labels

    return agent_cur_position, agent_type

def obtain_future_traffic_object_information(gt_attr_labels, fut_ts, device):
    '''
    Pack the future agent trajectories and their validity masks.

    Samples with fewer agents than the largest sample in the batch are
    zero-padded at the end of the agent axis; a sample with no agents yields
    an all-zero row.

    Args:
        gt_attr_labels (List[Tensor]): one [N, C] tensor per sample. The
            first ``fut_ts * 2`` columns are read as the future trajectory
            and the next ``fut_ts`` columns as the future mask (original
            note: 34 = 12 (fut_traj) + 6 (fut_masks) + 1 (fut_goal)
            + 9 (fut_feat) + 6 (fut_yaw)).
        fut_ts (int): number of future time steps T.
        device: device of the returned tensors.

    Returns:
        agent_fut_position (B, N, T, 2), float32
        agent_fut_mask     (B, N, T)
    '''
    B = len(gt_attr_labels)
    # Size the agent axis by the largest sample, not by the first one.
    num_agent = max((attrs.shape[0] for attrs in gt_attr_labels), default=0)

    agent_fut_position = torch.zeros((B, num_agent, fut_ts, 2), dtype=torch.float32, device=device)
    agent_fut_mask = torch.zeros((B, num_agent, fut_ts), device=device)

    for b, attrs in enumerate(gt_attr_labels):
        n = attrs.shape[0]
        # An explicit agent count (rather than -1) keeps reshape valid when
        # the sample has no agents.
        agent_fut_position[b, :n] = attrs[:, :fut_ts * 2].reshape(n, fut_ts, 2)
        agent_fut_mask[b, :n] = attrs[:, fut_ts * 2:fut_ts * 3]

    return agent_fut_position, agent_fut_mask
