# ================ base config ===================
# Dataset split that sizes the schedule below; switch to 'mini' for quick runs.
# version = 'mini'
version = 'trainval'
# Per-split sizes used to derive iterations per epoch (presumably sample counts).
# Previous values: {'trainval': 28130, 'mini': 323}
length = {'trainval': 120000, 'mini': 323}

plugin = True
plugin_dir = "projects/mmdet3d_plugin/"
dist_params = dict(backend="nccl")
log_level = "INFO"
work_dir = None

# Global batch size across all GPUs (earlier experiments tried values from 10 to 120).
total_batch_size = 4
num_gpus = 1  # earlier runs used 2, 6 or 8
batch_size = total_batch_size // num_gpus  # per-GPU batch size
# All operands are ints, so floor division already yields an int (no int() needed).
num_iters_per_epoch = length[version] // (num_gpus * batch_size)
num_epochs = 100
checkpoint_epoch_interval = 50  # previously 20

# Checkpoint interval is expressed in iterations: epochs * iterations-per-epoch.
checkpoint_config = dict(
    interval=num_iters_per_epoch * checkpoint_epoch_interval
)
# Logging hooks: plain-text log (iteration based, by_epoch=False) plus TensorBoard.
log_config = {
    "interval": 51,
    "hooks": [
        {"type": "TextLoggerHook", "by_epoch": False},
        {"type": "TensorboardLoggerHook"},
    ],
}

# No weights are preloaded. A stage-1 checkpoint was used previously:
#   "/mnt/private-user-data/ed/ATDRIVE_stage1.pth"
load_from = None
resume_from = None

# Previous setting: workflow = [("train", 1)]
workflow = [
    ("train", 3),
    ("train", 3),
]

# Mixed precision with a fixed loss scale.
fp16 = {"loss_scale": 32.0}

# Network input size; presumably (width, height) -- TODO confirm against the
# image-resize transform. Alternative tried before: (1600, 928).
input_shape = (704, 256)

# Spatial extent as [x_min, y_min, z_min, x_max, y_max, z_max] (presumably metres).
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]
patch_size = [102.4, 102.4]

# RGB mean/std normalisation.
# Previous BGR variant:
#   dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
img_norm_cfg = {
    "mean": [123.675, 116.28, 103.53],
    "std": [58.395, 57.12, 57.375],
    "to_rgb": True,
}

# Feature widths, all derived from the base width `_dim_`.
_dim_ = 256
_pos_dim_ = _dim_ // 2
_ffn_dim_ = _dim_ * 2
_feed_dim_ = _ffn_dim_
_dim_half_ = _pos_dim_
_num_levels_ = 4

# BEV grid resolution and the matching canvas.
bev_h_ = 200
bev_w_ = 200
canvas_size = (bev_h_, bev_w_)

# Each sequence contains `queue_length` frames.
queue_length = 3

# Maps raw simulator blueprint / static-mesh identifiers to detection class names.
#
# Every key appears exactly once. The original literal repeated three keys
# ("vehicle.audi.etron", "vehicle.ford.mustang" and the parked VolkswagenT2
# mesh); because the last entry of a dict literal wins, the VolkswagenT2 mesh
# effectively resolved to "van", which is what is kept here.
NameMapping = {
    # ================= vehicle =================
    # bicycle
    "vehicle.bh.crossbike": "bicycle",
    "vehicle.diamondback.century": "bicycle",
    "vehicle.gazelle.omafiets": "bicycle",
    # car
    "vehicle.audi.etron": "car",
    "vehicle.chevrolet.impala": "car",
    "vehicle.dodge.charger_2020": "car",
    "vehicle.dodge.charger_police": "car",
    "vehicle.dodge.charger_police_2020": "car",
    "vehicle.lincoln.mkz_2017": "car",
    "vehicle.lincoln.mkz_2020": "car",
    "vehicle.mini.cooper_s_2021": "car",
    "vehicle.mercedes.coupe_2020": "car",
    "vehicle.ford.mustang": "car",
    "vehicle.nissan.patrol_2021": "car",
    "vehicle.audi.tt": "car",
    "vehicle.ford.crown": "car",
    "vehicle.tesla.model3": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/FordCrown/SM_FordCrown_parked.SM_FordCrown_parked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Charger/SM_ChargerParked.SM_ChargerParked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Lincoln/SM_LincolnParked.SM_LincolnParked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/MercedesCCC/SM_MercedesCCC_Parked.SM_MercedesCCC_Parked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/Mini2021/SM_Mini2021_parked.SM_Mini2021_parked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/NissanPatrol2021/SM_NissanPatrol2021_parked.SM_NissanPatrol2021_parked": "car",
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/TeslaM3/SM_TeslaM3_parked.SM_TeslaM3_parked": "car",
    # bus (no entries)
    # van
    "/Game/Carla/Static/Car/4Wheeled/ParkedVehicles/VolkswagenT2/SM_VolkswagenT2_2021_Parked.SM_VolkswagenT2_2021_Parked": "van",
    "vehicle.ford.ambulance": "van",
    # truck
    "vehicle.carlamotors.firetruck": "truck",

    # ================= traffic sign ============
    # speed limits
    "traffic.speed_limit.30": "traffic_sign",
    "traffic.speed_limit.40": "traffic_sign",
    "traffic.speed_limit.50": "traffic_sign",
    "traffic.speed_limit.60": "traffic_sign",
    "traffic.speed_limit.90": "traffic_sign",
    "traffic.speed_limit.120": "traffic_sign",
    # other signs and lights
    "traffic.stop": "traffic_sign",
    "traffic.yield": "traffic_sign",
    "traffic.traffic_light": "traffic_light",

    # ================= construction ============
    # warning props are folded into the traffic_cone class
    "static.prop.warningconstruction": "traffic_cone",
    "static.prop.warningaccident": "traffic_cone",
    "static.prop.trafficwarning": "traffic_cone",
    # cones
    "static.prop.constructioncone": "traffic_cone",

    # ================= pedestrian ==============
    "walker.pedestrian.0001": "pedestrian",
    "walker.pedestrian.0003": "pedestrian",
    "walker.pedestrian.0004": "pedestrian",
    "walker.pedestrian.0005": "pedestrian",
    "walker.pedestrian.0007": "pedestrian",
    "walker.pedestrian.0010": "pedestrian",
    "walker.pedestrian.0013": "pedestrian",
    "walker.pedestrian.0014": "pedestrian",
    "walker.pedestrian.0015": "pedestrian",
    "walker.pedestrian.0016": "pedestrian",
    "walker.pedestrian.0017": "pedestrian",
    "walker.pedestrian.0018": "pedestrian",
    "walker.pedestrian.0019": "pedestrian",
    "walker.pedestrian.0020": "pedestrian",
    "walker.pedestrian.0021": "pedestrian",
    "walker.pedestrian.0022": "pedestrian",
    "walker.pedestrian.0025": "pedestrian",
    "walker.pedestrian.0027": "pedestrian",
    "walker.pedestrian.0030": "pedestrian",
    "walker.pedestrian.0031": "pedestrian",
    "walker.pedestrian.0032": "pedestrian",
    "walker.pedestrian.0034": "pedestrian",
    "walker.pedestrian.0035": "pedestrian",
    "walker.pedestrian.0041": "pedestrian",
    "walker.pedestrian.0042": "pedestrian",
    "walker.pedestrian.0046": "pedestrian",
    "walker.pedestrian.0047": "pedestrian",

    # ================= others ==================
    "static.prop.dirtdebris01": "others",
    "static.prop.dirtdebris02": "others",
}

# ================== model ========================
# Previous detection class list, kept for reference:
#   car, truck, construction_vehicle, bus, trailer, barrier,
#   motorcycle, bicycle, pedestrian, traffic_cone

# Detection classes.
# NOTE: these must be updated consistently everywhere class lists are used.
class_names = [
    "car",
    "van",
    "truck",
    "bicycle",
    "traffic_sign",
    "traffic_cone",
    "traffic_light",
    "pedestrian",
    "others",
]

# Previous map class list, kept for reference:
#   ped_crossing, divider, boundary

# Map element classes.
# NOTE: these must be updated consistently everywhere class lists are used.
map_class_names = [
    "Broken",
    "Solid",
    "SolidSolid",
    "Other",
    "NONE",
    "Center",
]

# Class counts are derived from the lists so they cannot drift apart.
num_classes = len(class_names)
num_map_classes = len(map_class_names)

roi_size = (30, 60)

# ---- sampling and prediction horizons ----
num_sample = 20
fut_ts, fut_mode = 12, 6
ego_fut_ts, ego_fut_mode = 6, 6
queue_length = 4  # history + current

# ---- decoder layout ----
embed_dims = 256
num_groups = 8
num_decoder = 6
num_single_frame_decoder = 1
num_single_frame_decoder_map = 1
# NOTE: mmdet3d_plugin/ops/setup.py needs to be executed when this is enabled.
use_deformable_func = True
strides = [4, 8, 16, 32]
num_levels = len(strides)  # one feature level per stride
num_depth_layers = 3
drop_out = 0.1

# ---- feature switches ----
temporal = True
temporal_map = True
decouple_attn = True
decouple_attn_map = True  # previously False
decouple_attn_motion = True
with_quality_estimation = True

# Which task heads are enabled.
task_config = {
    "with_det": True,
    "with_map": True,
    "with_motion_plan": False,
}

# ---- trajectory prediction args ----
predict_steps = 12
predict_modes = 6
fut_steps = 4
past_steps = 4
use_nonlinear_optimizer = True

# ---- occflow settings ----
occ_n_future = 4
occ_n_future_plan = 6
occ_n_future_max = max(occ_n_future, occ_n_future_plan)

# ---- occupancy grid args ----
# Each bound is a three-element list; presumably [min, max, step] -- TODO confirm.
occflow_grid_conf = dict(
    xbound=[-50.0, 50.0, 0.5],
    ybound=[-50.0, 50.0, 0.5],
    zbound=[-10.0, 10.0, 20.0],
)

# Full model definition. Every nested dict carries a "type" naming the component
# to build (presumably resolved through the plugin registry -- confirm against
# projects/mmdet3d_plugin). Values reference the module-level hyper-parameters
# defined earlier in this file.
model = dict(
    type="ATDRIVE",
    use_grid_mask=True,
    use_deformable_func=use_deformable_func,
    # Image backbone: ResNet-50 with all four stage outputs exposed.
    img_backbone=dict(
        type="ResNet",
        depth=50,
        num_stages=4,
        frozen_stages=-1,
        norm_eval=False,
        style="pytorch",
        with_cp=True,
        out_indices=(0, 1, 2, 3),
        norm_cfg=dict(type="BN", requires_grad=True),
        # Previously: pretrained="ckpt/resnet50-19c8e357.pth"
        pretrained="resnet50-0676ba61.pth",
    ),
    # Neck: FPN producing `num_levels` maps of `embed_dims` channels each.
    img_neck=dict(
        type="FPN",
        num_outs=num_levels,
        start_level=0,
        out_channels=embed_dims,
        add_extra_convs="on_output",
        relu_before_extra_convs=True,
        in_channels=[256, 512, 1024, 2048],
    ),
    depth_branch=dict(  # for auxiliary supervision only
        type="DenseDepthNet",
        embed_dims=embed_dims,
        num_depth_layers=num_depth_layers,
        loss_weight=0.2,
    ),
    head=dict(
        type="ATDRIVEHead",
        task_config=task_config,
        # Joint detection + map head: every `*_map` key below configures the
        # map-element counterpart of the detection setting next to it.
        det_head=dict(
            type="Sparse4DHead",
            cls_threshold_to_reg=0.05,
            decouple_attn=decouple_attn,
            instance_bank=dict(
                type="InstanceBank",
                num_anchor=900,
                num_anchor_map=100,
                embed_dims=embed_dims,
                anchor="data/kmeans/kmeans_det_900.npy",
                anchor_map="data/kmeans/kmeans_map_100.npy",
                anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
                anchor_handler_map=dict(type="SparsePoint3DKeyPointsGenerator"),
                # -1 is used when the corresponding temporal switch is off.
                num_temp_instances=600 if temporal else -1,
                num_temp_instances_map=50 if temporal_map else -1,
                # Alternative tried before (no temporal instances kept):
                # num_temp_instances=0 if temporal else -1,
                # num_temp_instances_map=0 if temporal_map else -1,
                confidence_decay=0.6,
                feat_grad=False,  # both True and False have been tried here
            ),
            anchor_encoder=dict(
                type="SparseBox3DEncoder",
                vel_dims=3,
                # With decoupled attention the four sub-embeddings are
                # concatenated (128 + 32 + 32 + 64 = 256); otherwise a single
                # 256-wide embedding is added and passed through an output FC.
                embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
                mode="cat" if decouple_attn else "add",
                output_fc=not decouple_attn,
                in_loops=1,
                out_loops=4 if decouple_attn else 2,
            ),
            anchor_encoder_map=dict(
                type="SparsePoint3DEncoder",
                embed_dims=embed_dims,
                num_sample=num_sample,
            ),
            num_single_frame_decoder=num_single_frame_decoder,
            # Layer schedule: `num_single_frame_decoder` single-frame layers
            # followed by temporal layers (which additionally start with
            # "temp_gnn"). The trailing [2:] drops the first two operations
            # ("gnn", "norm") of the very first layer.
            operation_order=(
                [
                    "gnn",
                    "norm",
                    "deformable",
                    "ffn",
                    "norm",
                    "refine",
                ]
                * num_single_frame_decoder
                + [
                    "temp_gnn",
                    "gnn",
                    "norm",
                    "deformable",
                    "ffn",
                    "norm",
                    "refine",
                ]
                * (num_decoder - num_single_frame_decoder)
            )[2:],  # original setting is [2:]; [:] (no trimming) was also tried
            # Temporal attention exists only when `temporal` is enabled.
            temp_graph_model=dict(
                type="MultiheadFlashAttention",
                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
                num_heads=num_groups,
                batch_first=True,
                dropout=drop_out,
            )
            if temporal
            else None,
            graph_model=dict(
                type="MultiheadFlashAttention",
                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
                num_heads=num_groups,
                batch_first=True,
                dropout=drop_out,
            ),
            norm_layer=dict(type="LN", normalized_shape=embed_dims),
            # FFN input is twice `embed_dims` wide; output is `embed_dims`.
            ffn=dict(
                type="AsymmetricFFN",
                in_channels=embed_dims * 2,
                pre_norm=dict(type="LN"),
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * 4,
                num_fcs=2,
                ffn_drop=drop_out,
                act_cfg=dict(type="ReLU", inplace=True),
            ),
            deformable_model=dict(
                type="DeformableFeatureAggregation",
                embed_dims=embed_dims,
                num_groups=num_groups,
                num_levels=num_levels,
                num_cams=6,
                attn_drop=0.15,
                use_deformable_func=use_deformable_func,
                use_camera_embed=True,
                residual_mode="cat",
                # Key points for boxes: 7 fixed offsets (centre plus +/-0.45
                # along each axis) and 6 learnable points.
                kps_generator=dict(
                    type="SparseBox3DKeyPointsGenerator",
                    num_learnable_pts=6,
                    fix_scale=[
                        [0, 0, 0],
                        [0.45, 0, 0],
                        [-0.45, 0, 0],
                        [0, 0.45, 0],
                        [0, -0.45, 0],
                        [0, 0, 0.45],
                        [0, 0, -0.45],
                    ],
                ),
                # Key points for map polylines.
                kps_generator_map=dict(
                    type="SparsePoint3DKeyPointsGenerator",
                    embed_dims=embed_dims,
                    num_sample=num_sample,
                    num_learnable_pts=3,
                    fix_height=(0, 0.5, -0.5, 1, -1),
                    ground_height=-1.84023, # ground height in lidar frame
                ),
            ),
            refine_layer=dict(
                type="SparseBox3DRefinementModule",
                embed_dims=embed_dims,
                num_cls=num_classes,
                refine_yaw=True,
                with_quality_estimation=with_quality_estimation,
                num_sample_map=num_sample,  # map branch
                num_cls_map=num_map_classes,  # map branch
            ),
            # Detection target assignment. Both denoising group counts are 0
            # (presumably disabling denoising queries -- TODO confirm).
            sampler=dict(
                type="SparseBox3DTarget",
                num_dn_groups=0,
                num_temp_dn_groups=0,
                dn_noise_scale=[2.0] * 3 + [0.5] * 7,
                max_dn_gt=32,
                add_neg_dn=True,
                cls_weight=2.0,
                box_weight=0.25,
                reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
                # Class-specific regression weights, keyed by class index.
                cls_wise_reg_weights={
                    class_names.index("traffic_cone"): [
                        2.0,
                        2.0,
                        2.0,
                        1.0,
                        1.0,
                        1.0,
                        0.0,
                        0.0,
                        1.0,
                        1.0,
                    ],
                },
            ),
            loss_cls=dict(
                type="FocalLoss",
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=2.0,
            ),
            loss_reg=dict(
                type="SparseBox3DLoss",
                loss_box=dict(type="L1Loss", loss_weight=0.25),
                loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
                loss_yawness=dict(type="GaussianFocalLoss"),
                # The previous class list used "barrier" here:
                # cls_allow_reverse=[class_names.index("barrier")],
                cls_allow_reverse=[class_names.index("traffic_cone")], 
            ),
            decoder=dict(type="SparseBox3DDecoder"),
            reg_weights=[2.0] * 3 + [1.0] * 7,
            with_instance_id=True,  # previously False
            with_instance_id_map=True,  # previously False
            # ---- map branch: target assignment, losses, decoding ----
            sampler_map=dict(
                type="SparsePoint3DTarget",
                assigner=dict(
                    type='HungarianLinesAssigner',
                    cost=dict(
                        type='MapQueriesCost',
                        cls_cost=dict(type='FocalLossCost', weight=1.0),
                        reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
                    ),
                ),
                num_cls=num_map_classes,
                num_sample=num_sample,
                roi_size=roi_size,
            ),
            loss_cls_map=dict(
                type="FocalLoss",
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=1.0,
            ),

            loss_reg_map=dict(
                type="SparseLineLoss",
                loss_line=dict(
                    type='LinesL1Loss',
                    loss_weight=10.0,
                    beta=0.01,
                ),
                num_sample=num_sample,
                roi_size=roi_size,
            ),
            decoder_map=dict(type="SparsePoint3DDecoder"),
            # NOTE(review): 40 presumably equals num_sample (20) * 2 coordinates;
            # keep in sync if num_sample changes -- TODO confirm.
            reg_weights_map=[1.0] * 40,
            gt_cls_key_map="gt_map_labels",
            gt_reg_key_map="gt_map_pts",
            gt_id_key_map="map_instance_id",
            # Disabled option:
            # task_prefix='map',
        ),
        # A standalone `map_head` used to be configured here and is currently
        # disabled; its settings now appear as the `*_map` keys of `det_head`.
        # It was a second Sparse4DHead (decouple_attn=decouple_attn_map,
        # task_prefix='map', with_instance_id=False) built from:
        #   - InstanceBank: num_anchor=100, anchor="data/kmeans/kmeans_map_100.npy",
        #     SparsePoint3DKeyPointsGenerator handler,
        #     num_temp_instances=0 if temporal_map else -1, confidence_decay=0.6,
        #     feat_grad=True
        #   - SparsePoint3DEncoder anchor encoder (embed_dims, num_sample)
        #   - the same operation schedule as det_head, but sized with
        #     num_single_frame_decoder_map and without the [2:] trimming
        #   - the same attention / LN / AsymmetricFFN settings as det_head
        #   - DeformableFeatureAggregation with the SparsePoint3D key-point
        #     generator (num_learnable_pts=3, fix_height=(0, 0.5, -0.5, 1, -1),
        #     ground_height=-1.84023)
        #   - SparsePoint3DRefinementModule (num_sample, num_cls=num_map_classes)
        #   - the SparsePoint3DTarget sampler, FocalLoss (loss_weight=1.0),
        #     SparseLineLoss, SparsePoint3DDecoder and reg_weights=[1.0] * 40,
        #     with gt keys "gt_map_labels" / "gt_map_pts" / "map_instance_id"
        # Motion-prediction and planning head.
        # NOTE(review): task_config sets with_motion_plan=False, so this head is
        # presumably not built/used -- confirm in ATDRIVEHead.
        motion_plan_head=dict(
            type='MotionPlanningHead',
            fut_ts=fut_ts,
            fut_mode=fut_mode,
            ego_fut_ts=ego_fut_ts,
            ego_fut_mode=ego_fut_mode,
            # Anchor files are selected by the number of modes.
            motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
            plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
            embed_dims=embed_dims,
            decouple_attn=decouple_attn_motion,
            instance_queue=dict(
                type="InstanceQueue",
                embed_dims=embed_dims,
                queue_length=queue_length,
                tracking_threshold=0.2,
                # input_shape indices swapped and divided by the coarsest stride:
                # (256 / 32, 704 / 32) = (8.0, 22.0) with the current settings.
                feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
            ),
            # Three identical attention/FFN layers followed by one refinement.
            operation_order=(
                [
                    "temp_gnn",
                    "gnn",
                    "norm",
                    "cross_gnn",
                    "norm",
                    "ffn",                    
                    "norm",
                ] * 3 +
                [
                    "refine",
                ]
            ),
            # Note: temporal attention here uses plain MultiheadAttention, while
            # the other attention blocks use MultiheadFlashAttention.
            temp_graph_model=dict(
                type="MultiheadAttention",
                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
                num_heads=num_groups,
                batch_first=True,
                dropout=drop_out,
            ),
            graph_model=dict(
                type="MultiheadFlashAttention",
                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
                num_heads=num_groups,
                batch_first=True,
                dropout=drop_out,
            ),
            cross_graph_model=dict(
                type="MultiheadFlashAttention",
                embed_dims=embed_dims,
                num_heads=num_groups,
                batch_first=True,
                dropout=drop_out,
            ),
            norm_layer=dict(type="LN", normalized_shape=embed_dims),
            ffn=dict(
                type="AsymmetricFFN",
                in_channels=embed_dims,
                pre_norm=dict(type="LN"),
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * 2,
                num_fcs=2,
                ffn_drop=drop_out,
                act_cfg=dict(type="ReLU", inplace=True),
            ),
            refine_layer=dict(
                type="MotionPlanningRefinementModule",
                embed_dims=embed_dims,
                fut_ts=fut_ts,
                fut_mode=fut_mode,
                ego_fut_ts=ego_fut_ts,
                ego_fut_mode=ego_fut_mode,
            ),
            # ---- motion targets and losses ----
            motion_sampler=dict(
                type="MotionTarget",
            ),
            motion_loss_cls=dict(
                type='FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=0.2
            ),
            motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
            # ---- planning targets and losses ----
            planning_sampler=dict(
                type="PlanningTarget",
                ego_fut_ts=ego_fut_ts,
                ego_fut_mode=ego_fut_mode,
            ),
            plan_loss_cls=dict(
                type='FocalLoss',
                use_sigmoid=True,
                gamma=2.0,
                alpha=0.25,
                loss_weight=0.5,
            ),
            plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
            plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
            motion_decoder=dict(type="SparseBox3DMotionDecoder"),
            planning_decoder=dict(
                type="HierarchicalPlanningDecoder",
                ego_fut_ts=ego_fut_ts,
                ego_fut_mode=ego_fut_mode,
                use_rescore=True,
            ),
            num_det=50,
            num_map=10,
        ),
    ),
)

# ================== data ========================
# Previous nuScenes setup, kept for reference:
#   dataset_type = "NuScenes3DDataset"
#   data_root = "data/nuscenes/"
#   anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
#   file_client_args = dict(backend="disk")
dataset_type = "B2D_E2E_Dataset"
data_root = "data/bench2drive"
info_root = "data/infos"
map_root = "data/bench2drive/maps"
map_file = "data/infos/b2d_map_infos.pkl"
file_client_args = {"backend": "disk"}

# Annotation files live under `info_root`; the test split reuses the val file.
ann_file_train = f"{info_root}/b2d_infos_train.pkl"
ann_file_val = f"{info_root}/b2d_infos_val.pkl"
ann_file_test = f"{info_root}/b2d_infos_val.pkl"

# RGB mean/std normalisation applied to the camera images.
img_norm_cfg = {
    "mean": [123.675, 116.28, 103.53],
    "std": [58.395, 57.12, 57.375],
    "to_rgb": True,
}
# train_pipeline = [
#     dict(type="LoadMultiViewImageFromFiles", to_float32=True),
#     dict(
#         type="LoadPointsFromFile",
#         coord_type="LIDAR",
#         load_dim=5,
#         use_dim=5,
#         file_client_args=file_client_args,
#     ),
#     dict(type="ResizeCropFlipImage"),
#     dict(
#         type="MultiScaleDepthMapGenerator",
#         downsample=strides[:num_depth_layers],
#     ),
#     dict(type="BBoxRotation"),
#     dict(type="PhotoMetricDistortionMultiViewImage"),
#     dict(type="NormalizeMultiviewImage", **img_norm_cfg),
#     dict(
#         type="CircleObjectRangeFilter",
#         class_dist_thred=[55] * len(class_names),
#     ),
#     dict(type="InstanceNameFilter", classes=class_names),
#     dict(
#         type='VectorizeMap',
#         roi_size=roi_size,
#         simplify=False,
#         normalize=False,
#         sample_num=num_sample,
#         permute=True,
#     ),
#     dict(type="NuScenesSparse4DAdaptor"),
#     dict(
#         type="Collect",
#         keys=[
#             "img",
#             "timestamp",
#             "projection_mat",
#             "image_wh",
#             "gt_depth",
#             "focal",
#             "gt_bboxes_3d",
#             "gt_labels_3d",
#             'gt_map_labels', 
#             'gt_map_pts',
#             'gt_agent_fut_trajs',
#             'gt_agent_fut_masks',
#             'gt_ego_fut_trajs',
#             'gt_ego_fut_masks',
#             'gt_ego_fut_cmd',
#             'ego_status',

#             # 'map_infos',
#             'scene_name',
#             'sample_idx',
#             'prev',
#             'next',
#             # 'map_instance_id', #map_instance_id=info['track_globle_id'],
            

#         ],
#         meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id", "map_instance_id"],
#     ),
# ]
# test_pipeline = [
#     dict(type="LoadMultiViewImageFromFiles", to_float32=True),
#     dict(type="ResizeCropFlipImage"),
#     dict(type="NormalizeMultiviewImage", **img_norm_cfg),
#     dict(type="NuScenesSparse4DAdaptor"),
#     dict(
#         type="Collect",
#         keys=[
#             "img",
#             "timestamp",
#             "projection_mat",
#             "image_wh",
#             'ego_status',
#             'gt_ego_fut_cmd',
#         ],
#         meta_keys=["T_global", "T_global_inv", "timestamp"],
#     ),
# ]
# eval_pipeline = [
#     dict(
#         type="CircleObjectRangeFilter",
#         class_dist_thred=[55] * len(class_names),
#     ),
#     dict(type="InstanceNameFilter", classes=class_names),
#     dict(
#         type='VectorizeMap',
#         roi_size=roi_size,
#         simplify=True,
#         normalize=False,
#     ),
#     dict(
#         type='Collect', 
#         keys=[
#             'vectors',
#             "gt_bboxes_3d",
#             "gt_labels_3d",
#             'gt_agent_fut_trajs',
#             'gt_agent_fut_masks',
#             'gt_ego_fut_trajs',
#             'gt_ego_fut_masks', 
#             'gt_ego_fut_cmd',
#             'fut_boxes'
#         ],
#         meta_keys=['token', 'timestamp']
#     ),
# ]
# Training data pipeline: transforms are listed in the order they appear here.
train_pipeline = [
    {
        "type": "LoadMultiViewImageFromFilesInCeph",
        "to_float32": True,
        "file_client_args": file_client_args,
        "img_root": data_root,
    },
    {"type": "PhotoMetricDistortionMultiViewImage"},
    {"type": "ResizeCropFlipImage"},  # augmentation
    {
        "type": "LoadAnnotations3D_E2E",
        "with_bbox_3d": True,
        "with_label_3d": True,
        "with_attr_label": False,
        "with_vis_token": False,
        "with_future_anns": True,  # occ_flow gt
        "with_ins_inds_3d": True,  # ins_inds
        "ins_inds_add_1": True,  # ins_inds start from 1
    },
    # Disabled occupancy-flow label generation:
    # dict(type='GenerateOccFlowLabels',
    #      grid_conf=occflow_grid_conf,
    #      ignore_index=255,
    #      only_vehicle=True,
    #      filter_invisible=False,
    #      all_classes = class_names,
    #      vehicle_classes = ['car','van','truck','bicycle'],
    #      plan_classes = ['car','van','truck','bicycle','pedestrian'],
    #      ),
    {"type": "ObjectRangeFilterTrack", "point_cloud_range": point_cloud_range},
    {"type": "ObjectNameFilterTrack", "classes": class_names},
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {"type": "PadMultiViewImage", "size_divisor": 32},
    {"type": "DefaultFormatBundle3D", "class_names": class_names},
    # Map vectorisation and depth supervision.
    {
        "type": "VectorizeMapCarla",
        "roi_size": roi_size,
        "simplify": False,
        "normalize": False,
        "sample_num": num_sample,
        "permute": True,
    },
    {
        "type": "MultiScaleDepthMapGeneratorCarla",
        # one depth map per depth layer, at the matching strides
        "downsample": strides[:num_depth_layers],
    },
    {
        "type": "CustomCollect3D",
        "keys": [
            "gt_bboxes_3d",
            "gt_labels_3d",
            "gt_inds",
            "img",
            "timestamp",
            "l2g_r_mat",
            "l2g_t",
            "gt_fut_traj",
            "gt_fut_traj_mask",
            "gt_past_traj",
            "gt_past_traj_mask",
            "gt_sdc_bbox",
            "gt_sdc_label",
            # disabled: "gt_sdc_fut_traj", "gt_sdc_fut_traj_mask"
            "gt_ego_fut_trajs",
            "gt_ego_fut_masks",
            "gt_lane_labels",
            "gt_lane_bboxes",
            "gt_lane_masks",
            # disabled occupancy gt: "gt_segmentation", "gt_instance",
            #   "gt_centerness", "gt_offset", "gt_flow", "gt_backward_flow",
            #   "gt_occ_has_invalid_frame", "gt_occ_img_is_valid"
            # disabled future boxes for planning: "gt_future_boxes",
            #   "gt_future_labels"
            # planning
            "sdc_planning",
            "sdc_planning_mask",
            "command",
            # newly added keys
            "projection_mat",  # lidar2img
            "gt_agent_fut_trajs",
            "gt_agent_fut_masks",
            "gt_ego_fut_cmd",  # NOTE(review): possibly the same as "command"?
            "gt_map_pts",
            "gt_map_labels",
            "gt_depth",
        ],
        "meta_keys": [
            "filename",
            "ori_shape",
            "img_shape",
            "lidar2img",
            "depth2img",
            "cam2img",
            "pad_shape",
            "scale_factor",
            "box_mode_3d",
            "box_type_3d",
            "img_norm_cfg",
            "sample_idx",
            "prev_idx",
            "next_idx",
            "can_bus",
            "folder",
            "frame_idx",
            "T_global",
            "T_global_inv",
            "instance_id",
            "map_instance_id",
            "timestamp",
            # not collected: lidar2global
        ],
    },
]
# Evaluation-time pipeline (consumed by data["val"]).  After loading images and
# annotations it follows the same filter -> normalize -> pad -> format ->
# map/depth targets -> collect order as the tail of train_pipeline.
test_pipeline = [
    dict(type='LoadMultiViewImageFromFilesInCeph', to_float32=True,
            file_client_args=file_client_args, img_root=data_root),
    # dict(type="PhotoMetricDistortionMultiViewImage"),
    # NOTE: NormalizeMultiviewImage / PadMultiViewImage are intentionally NOT
    # listed here.  They run exactly once, after the object filters below.
    # Listing them in both places would presumably normalize images that are
    # already normalized (and repeat the to_rgb channel handling).
    dict(type='LoadAnnotations3D_E2E',
         with_bbox_3d=True,
         with_label_3d=True,
         with_attr_label=False,
         with_vis_token=False,
         with_future_anns=True,
         with_ins_inds_3d=True,
         ins_inds_add_1=True, # ins_inds start from 1
         ),
    #### dict(type='GenerateOccFlowLabels',
    #      grid_conf=occflow_grid_conf,
    #      ignore_index=255,
    #      only_vehicle=True,
    #      filter_invisible=False,
    #      all_classes = class_names,
    #      vehicle_classes = ['car','van','truck','bicycle'],
    #      plan_classes = ['car','van','truck','bicycle','pedestrian'],
    ####      ),
    # dict(
    #     type="MultiScaleFlipAug3D",
    #     img_scale=(1600, 900),
    #     pts_scale_ratio=1,
    #     flip=False,
    #     transforms=[
    #         dict(
    #             type="DefaultFormatBundle3D", class_names=class_names, with_label=False
    #         ),
    #         dict(
    #             type="CustomCollect3D", keys=[
    #                                         "img",
    #                                         "timestamp",
    #                                         "l2g_r_mat",
    #                                         "l2g_t",
    #                                         "gt_lane_labels",
    #                                         "gt_lane_bboxes",
    #                                         "gt_lane_masks",
    #                                         "gt_segmentation",
    #                                         "gt_instance",
    #                                         "gt_centerness",
    #                                         "gt_offset",
    #                                         "gt_flow",
    #                                         "gt_backward_flow",
    #                                         # "gt_occ_has_invalid_frame",
    #                                         # "gt_occ_img_is_valid",
    #                                         # planning
    #                                         "sdc_planning",
    #                                         "sdc_planning_mask",
    #                                         "command",
    #                                     ]
    #         ),
    #     ],
    # ),
    dict(type="ObjectRangeFilterTrack", point_cloud_range=point_cloud_range),
    dict(type="ObjectNameFilterTrack", classes=class_names),
    # Single normalize + pad step (same position as in train_pipeline).
    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
    dict(type="PadMultiViewImage", size_divisor=32),
    dict(type="DefaultFormatBundle3D", class_names=class_names),
    # add
    dict(
        type='VectorizeMapCarla',
        roi_size=roi_size,
        simplify=False,
        normalize=False,
        sample_num=num_sample,
        permute=True,
    ),
    dict(
        type="MultiScaleDepthMapGeneratorCarla",
        downsample=strides[:num_depth_layers],
    ),

    dict(
        type="CustomCollect3D",
        keys=[
            "gt_bboxes_3d",
            "gt_labels_3d",
            "gt_inds",
            "img",
            "timestamp",
            "l2g_r_mat",
            "l2g_t",
            "gt_fut_traj",
            "gt_fut_traj_mask",
            "gt_past_traj",
            "gt_past_traj_mask",
            "gt_sdc_bbox",
            "gt_sdc_label",
            # "gt_sdc_fut_traj",
            # "gt_sdc_fut_traj_mask",
            "gt_ego_fut_trajs",
            "gt_ego_fut_masks",
            "gt_lane_labels",
            "gt_lane_bboxes",
            "gt_lane_masks",
            # Occ gt
            # "gt_segmentation",
            # "gt_instance",
            # "gt_centerness",
            # "gt_offset",
            # "gt_flow",
            # "gt_backward_flow",
            # "gt_occ_has_invalid_frame",
            # "gt_occ_img_is_valid",

            ## gt future bbox for plan
            # "gt_future_boxes",
            # "gt_future_labels",
            # planning
            "sdc_planning",
            "sdc_planning_mask",
            "command",
            # new
            "projection_mat", #lidar2img
            "gt_agent_fut_trajs",
            "gt_agent_fut_masks",
            "gt_ego_fut_cmd", # == command?
            "gt_map_pts",
            "gt_map_labels",
            "gt_depth",

        ],
        meta_keys=['filename', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'sample_idx', 'prev_idx', 'next_idx',
                            'can_bus','folder','frame_idx',

                            'T_global', 'T_global_inv',  'instance_id', 'map_instance_id', 'timestamp',
                            # lidar2global
        ],
    ),
]

# Pipeline without annotation loading: load images, normalize, pad, then
# format and collect only the input keys listed at the bottom.
inference_only_pipeline = [
    {
        "type": "LoadMultiViewImageFromFilesInCeph",
        "to_float32": True,
        "file_client_args": file_client_args,
        "img_root": data_root,
    },
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {"type": "PadMultiViewImage", "size_divisor": 32},
    {
        "type": "MultiScaleFlipAug3D",
        "img_scale": (1600, 900),
        "pts_scale_ratio": 1,
        "flip": False,
        "transforms": [
            {
                "type": "DefaultFormatBundle3D",
                "class_names": class_names,
                "with_label": False,
            },
            {
                "type": "CustomCollect3D",
                "keys": [
                    "img",
                    "timestamp",
                    "l2g_r_mat",
                    "l2g_t",
                    "command",
                ],
            },
        ],
    },
]

# Input-modality switches passed to the datasets; only the camera flag is on.
input_modality = {
    "use_lidar": False,
    "use_camera": True,
    "use_radar": False,
    "use_map": False,
    "use_external": False,
}

# data_basic_config = dict(
#     type=dataset_type,
#     data_root=data_root,
#     classes=class_names,
#     map_classes=map_class_names,
#     modality=input_modality,
#     version="v1.0-trainval",
# )
# eval_config = dict(
#     **data_basic_config,
#     ann_file=anno_root + 'nuscenes_infos_val.pkl',
#     pipeline=eval_pipeline,
#     test_mode=True,
# )
# data_aug_conf = {
#     "resize_lim": (0.40, 0.47),
#     "final_dim": input_shape[::-1],
#     "bot_pct_lim": (0.0, 0.0),
#     "rot_lim": (-5.4, 5.4),
#     "H": 900,
#     "W": 1600,
#     "rand_flip": True,
#     "rot3d_range": [0, 0],
# }

# data = dict(
#     samples_per_gpu=batch_size,
#     workers_per_gpu=batch_size,
#     train=dict(
#         **data_basic_config,
#         ann_file=anno_root + "nuscenes_infos_train.pkl",
#         pipeline=train_pipeline,
#         test_mode=False,
#         data_aug_conf=data_aug_conf,
#         with_seq_flag=True,
#         sequences_split_num=2,
#         keep_consistent_seq_aug=True,
#     ),
#     val=dict(
#         **data_basic_config,
#         ann_file=anno_root + "nuscenes_infos_val.pkl",
#         pipeline=test_pipeline,
#         data_aug_conf=data_aug_conf,
#         test_mode=True,
#         eval_config=eval_config,
#     ),
#     test=dict(
#         **data_basic_config,
#         ann_file=anno_root + "nuscenes_infos_val.pkl",
#         pipeline=test_pipeline,
#         data_aug_conf=data_aug_conf,
#         test_mode=True,
#         eval_config=eval_config,
#     ),
# )

# Image augmentation settings shared by the dataset entries below.
# final_dim is input_shape reversed (presumably (H, W) from a (W, H) input_shape
# -- confirm against the consumer of this dict).
data_aug_conf = dict(
    resize_lim=(0.40, 0.47),
    final_dim=input_shape[::-1],
    bot_pct_lim=(0.0, 0.0),
    rot_lim=(-5.4, 5.4),
    H=900,
    W=1600,
    rand_flip=True,
    rot3d_range=[0, 0],
)
# data_aug_conf = {
#     "resize_lim": (1, 1),
#     "final_dim": (928, 1600), #input_shape[::-1],
#     "bot_pct_lim": (0.0, 0.0),
#     "rot_lim": (0, 0),
#     "H": 900,
#     "W": 1600,
#     "rand_flip": False,
#     "rot3d_range": [0, 0],
# }

# Detection evaluation settings: distance thresholds, recall/precision floors,
# true-positive metric names and a per-class range table.
# NOTE(review): in the visible part of this file eval_cfg is only referenced
# from commented-out dataset entries -- confirm whether it is still consumed.
eval_cfg = dict(
    dist_ths=[0.5, 1.0, 2.0, 4.0],
    dist_th_tp=2.0,
    min_recall=0.1,
    min_precision=0.1,
    mean_ap_weight=5,
    class_names=[
        'car',
        'van',
        'truck',
        'bicycle',
        'traffic_sign',
        'traffic_cone',
        'traffic_light',
        'pedestrian',
    ],
    tp_metrics=['trans_err', 'scale_err', 'orient_err', 'vel_err'],
    # The key spelling "err_name_maping" is a runtime key and is kept verbatim.
    err_name_maping=dict(
        trans_err='mATE',
        scale_err='mASE',
        orient_err='mAOE',
        vel_err='mAVE',
        attr_err='mAAE',
    ),
    class_range=dict(
        car=(50, 50),
        van=(50, 50),
        truck=(50, 50),
        bicycle=(40, 40),
        traffic_sign=(30, 30),
        traffic_cone=(30, 30),
        traffic_light=(30, 30),
        pedestrian=(40, 40),
    ),
)

# Dataloader configuration for the dataset named by `dataset_type`.
# NOTE(review): `val` reads ann_file_train, i.e. the same annotation file as
# `train` -- confirm this is intended (the commented-out variant below used
# ann_file_val).
# NOTE(review): no `test` entry is defined here; only `train` and `val`.
data = {
    "samples_per_gpu": batch_size,
    "workers_per_gpu": 0,  # batch_size,
    "train": {
        "type": dataset_type,
        "data_root": data_root,
        "ann_file": ann_file_train,
        "pipeline": train_pipeline,
        "classes": class_names,
        "name_mapping": NameMapping,
        "map_root": map_root,
        "map_file": map_file,
        "modality": input_modality,
        "patch_size": patch_size,
        "bev_size": (bev_h_, bev_w_),
        "queue_length": queue_length,
        "predict_frames": predict_steps,
        "past_frames": past_steps,
        "future_frames": fut_steps,
        "point_cloud_range": point_cloud_range,
        "box_type_3d": "LiDAR",
        # add
        "test_mode": False,
        "data_aug_conf": data_aug_conf,
        "with_seq_flag": True,
        "sequences_split_num": 1,  # 2,
        # keep_consistent_seq_aug=True,
    },
    "val": {
        "type": dataset_type,
        "data_root": data_root,
        "ann_file": ann_file_train,
        "pipeline": test_pipeline,
        "classes": class_names,
        "name_mapping": NameMapping,
        "map_root": map_root,
        "map_file": map_file,
        "modality": input_modality,
        "patch_size": patch_size,
        "bev_size": (bev_h_, bev_w_),
        "queue_length": queue_length,
        "predict_frames": predict_steps,
        "past_frames": past_steps,
        "future_frames": fut_steps,
        "point_cloud_range": point_cloud_range,
        "box_type_3d": "LiDAR",
        # add
        "test_mode": True,
        "data_aug_conf": data_aug_conf,
        "with_seq_flag": True,
        "sequences_split_num": 1,  # 2,
        # keep_consistent_seq_aug=True,
    },
    # val=dict(
    #     type=dataset_type,
    #     data_root=data_root,
    #     ann_file=ann_file_val,
    #     pipeline=train_pipeline, #test_pipeline,
    #     name_mapping=NameMapping,
    #     map_root=map_root,
    #     map_file=map_file,
    #     bev_size=(bev_h_, bev_w_),
    #     predict_frames=predict_steps,
    #     past_frames=past_steps,
    #     future_frames=fut_steps,
    #     classes=class_names,
    #     modality=input_modality,
    #     samples_per_gpu=1,
    #     point_cloud_range=point_cloud_range,
    #     eval_cfg=eval_cfg,
    #     #eval_mod=['det', 'track', 'map'],
    #     box_type_3d="LiDAR",
    # ),
    # test=dict(
    #     type=dataset_type,
    #     data_root=data_root,
    #     ann_file=ann_file_val,
    #     pipeline=test_pipeline,
    #     name_mapping=NameMapping,
    #     map_root=map_root,
    #     map_file=map_file,
    #     bev_size=(bev_h_, bev_w_),
    #     predict_frames=predict_steps,
    #     past_frames=past_steps,
    #     future_frames=fut_steps,
    #     classes=class_names,
    #     modality=input_modality,
    #     samples_per_gpu=1,
    #     point_cloud_range=point_cloud_range,
    #     eval_cfg=eval_cfg,
    #     #eval_mod=['det', 'track', 'map'],
    #     box_type_3d="LiDAR",
    # ),
    "shuffler_sampler": {"type": "DistributedGroupSampler"},
    "nonshuffler_sampler": {"type": "DistributedSampler"},
}

# real data
dataset_type_real = "NuScenes3DDataset"
# Info files are read from data/infos/ for 'trainval' and from its mini/
# subfolder for any other `version` value.
anno_root = "data/infos/" + ("" if version == "trainval" else "mini/")
# file_client_args = dict(backend="disk") already defined
# Training pipeline for the "real" dataset (dataset_type_real).
train_pipeline_real = [
    {"type": "LoadMultiViewImageFromFiles", "to_float32": True},
    {
        "type": "LoadPointsFromFile",
        "coord_type": "LIDAR",
        "load_dim": 5,
        "use_dim": 5,
        "file_client_args": file_client_args,
    },
    {"type": "ResizeCropFlipImage"},
    {
        "type": "MultiScaleDepthMapGenerator",
        "downsample": strides[:num_depth_layers],
    },
    {"type": "BBoxRotation"},
    {"type": "PhotoMetricDistortionMultiViewImage"},
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {
        "type": "CircleObjectRangeFilter",
        # One threshold of 55 per entry in class_names.
        # NOTE: needs to be changed so that it is unified/consistent.
        "class_dist_thred": [55] * len(class_names),
    },
    # NOTE: needs to be changed so that it is unified/consistent.
    {"type": "InstanceNameFilter", "classes": class_names},
    {
        "type": "VectorizeMap",
        "roi_size": roi_size,
        "simplify": False,
        "normalize": False,
        "sample_num": num_sample,
        "permute": True,
    },
    {"type": "NuScenesSparse4DAdaptor"},
    {
        "type": "Collect",
        "keys": [
            "img",
            "timestamp",
            "projection_mat",  # lidar2img
            "image_wh",
            "gt_depth",
            "focal",
            "gt_bboxes_3d",
            "gt_labels_3d",
            "gt_map_labels",
            "gt_map_pts",
            "gt_agent_fut_trajs",
            "gt_agent_fut_masks",
            "gt_ego_fut_trajs",
            "gt_ego_fut_masks",
            "gt_ego_fut_cmd",
            "ego_status",
            # 'map_infos',
            "scene_name",
            "sample_idx",
            "prev",
            "next",
            # 'map_instance_id', #map_instance_id=info['track_globle_id'],
        ],
        "meta_keys": [
            "T_global",
            "T_global_inv",
            "timestamp",
            "instance_id",
            "map_instance_id",
        ],
    },
]
# Test-time pipeline for the "real" dataset: image loading and preprocessing
# only, followed by collection of the input keys (no ground-truth keys).
test_pipeline_real = [
    {"type": "LoadMultiViewImageFromFiles", "to_float32": True},
    {"type": "ResizeCropFlipImage"},
    {"type": "NormalizeMultiviewImage", **img_norm_cfg},
    {"type": "NuScenesSparse4DAdaptor"},
    {
        "type": "Collect",
        "keys": [
            "img",
            "timestamp",
            "projection_mat",
            "image_wh",
            "ego_status",
            "gt_ego_fut_cmd",
        ],
        "meta_keys": ["T_global", "T_global_inv", "timestamp"],
    },
]
# Ground-truth pipeline used by eval_config for the "real" dataset: it filters
# annotations, vectorizes the map and collects GT keys (no image loading).
eval_pipeline_real = [
    {
        "type": "CircleObjectRangeFilter",
        # One threshold of 55 per entry in class_names.
        # NOTE: needs to be changed so that it is unified/consistent.
        "class_dist_thred": [55] * len(class_names),
    },
    # NOTE: needs to be changed so that it is unified/consistent.
    {"type": "InstanceNameFilter", "classes": class_names},
    {
        "type": "VectorizeMap",
        "roi_size": roi_size,
        "simplify": True,
        "normalize": False,
    },
    {
        "type": "Collect",
        "keys": [
            "vectors",
            "gt_bboxes_3d",
            "gt_labels_3d",
            "gt_agent_fut_trajs",
            "gt_agent_fut_masks",
            "gt_ego_fut_trajs",
            "gt_ego_fut_masks",
            "gt_ego_fut_cmd",
            "fut_boxes",
        ],
        "meta_keys": ["token", "timestamp"],
    },
]

# Fields shared by every "real" dataset entry (train / val / test / eval).
# NOTE(review): "version" is fixed to "v1.0-trainval" here regardless of the
# module-level `version` variable -- confirm this is intended.
data_basic_config = {
    "type": dataset_type_real,
    "data_root": data_root,
    "classes": class_names,
    "map_classes": map_class_names,
    "modality": input_modality,
    "version": "v1.0-trainval",
}
# Dataset description used for evaluation: shared fields plus the validation
# info file and the ground-truth-only pipeline.
eval_config = {
    **data_basic_config,
    "ann_file": anno_root + 'nuscenes_infos_val.pkl',
    "pipeline": eval_pipeline_real,
    "test_mode": True,
}

# Dataloader configuration for the "real" dataset.  `val` and `test` are
# configured identically (same info file, pipeline and eval_config).
data_real = {
    "samples_per_gpu": batch_size,
    "workers_per_gpu": 0,  # batch_size,
    "train": {
        **data_basic_config,
        "ann_file": anno_root + "nuscenes_infos_train.pkl",
        "pipeline": train_pipeline_real,
        "test_mode": False,
        "data_aug_conf": data_aug_conf,
        "with_seq_flag": True,
        "sequences_split_num": 2,
        "keep_consistent_seq_aug": True,
    },
    "val": {
        **data_basic_config,
        "ann_file": anno_root + "nuscenes_infos_val.pkl",
        "pipeline": test_pipeline_real,
        "data_aug_conf": data_aug_conf,
        "test_mode": True,
        "eval_config": eval_config,
    },
    "test": {
        **data_basic_config,
        "ann_file": anno_root + "nuscenes_infos_val.pkl",
        "pipeline": test_pipeline_real,
        "data_aug_conf": data_aug_conf,
        "test_mode": True,
        "eval_config": eval_config,
    },
}



# ================== training ========================
# AdamW with a 0.5x learning-rate multiplier for parameters matched by the
# "img_backbone" key.
optimizer = {
    "type": "AdamW",
    "lr": 4e-4,
    "weight_decay": 0.001,  # 0.001,
    "paramwise_cfg": {
        "custom_keys": {
            "img_backbone": {"lr_mult": 0.5},
        },
    },
}
# Gradient clipping settings (max_norm 25, norm_type 2).
optimizer_config = {"grad_clip": {"max_norm": 25, "norm_type": 2}}  # 25 to 35
# optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) # 25 to 35

# Same values as `optimizer`, kept as a separate dict object.
optimizer_source = {
    "type": "AdamW",
    "lr": 4e-4,
    "weight_decay": 0.001,  # 0.001,
    "paramwise_cfg": {
        "custom_keys": {
            "img_backbone": {"lr_mult": 0.5},
        },
    },
}
# AdamW settings without any per-parameter overrides (the paramwise_cfg
# variant is left commented out below).
optimizer_Discriminator = {
    "type": "AdamW",
    "lr": 4e-4,
    "weight_decay": 0.001,  # 0.001,
    # paramwise_cfg=dict(
    #     custom_keys={
    #         "img_backbone": dict(lr_mult=0.5),
    #     }
    # ),
}

# Learning-rate schedule: cosine annealing with a linear warmup.
lr_config = {
    "policy": "CosineAnnealing",
    "warmup": "linear",
    "warmup_iters": 1500,  # 500
    "warmup_ratio": 1.0 / 3,
    "min_lr_ratio": 1e-3,  # 5e-2, #1e-3,
}
# Iteration-based runner; the iteration budget is num_epochs worth of
# num_iters_per_epoch.
runner = {
    "type": "IterBasedRunnerBoth",
    "max_iters": num_iters_per_epoch * num_epochs,
}

# ================== eval ========================
# Evaluation switches: only detection is enabled.
# The key spelling "motion_threshhold" is a runtime key and is kept verbatim.
eval_mode = {
    "with_det": True,  # True, #False, #True,
    "with_tracking": False,  # True,
    "with_map": False,  # True,
    "with_motion": False,
    "with_planning": False,
    "tracking_threshold": 0.2,
    "motion_threshhold": 0.2,
}
# Evaluation interval in iterations: once every checkpoint_epoch_interval
# epochs' worth of iterations (the same product used by checkpoint_config).
evaluation = {
    "interval": num_iters_per_epoch * checkpoint_epoch_interval,
    "eval_mode": eval_mode,
}