# =========== misc config ==============
# Optimizer settings. The nested dicts are plain data; they are presumably
# turned into an optimizer by the training script -- confirm against the
# builder that consumes `optimizer_wrapper`.
optimizer_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=4e-4,
        weight_decay=0.01,
    ),
    paramwise_cfg=dict(
        # NOTE(review): presumably parameters whose name matches 'backbone'
        # get their learning rate scaled by lr_mult -- verify.
        custom_keys=dict(
            backbone=dict(lr_mult=0.1),
        ),
    ),
)
grad_max_norm = 35  # presumably the gradient-clipping max norm -- confirm
amp = False         # presumably toggles automatic mixed precision -- confirm

# =========== base config ==============
# General run settings. Units and consumers are not visible in this file;
# the notes below are assumptions to be confirmed against the training script.
seed = 1
print_freq = 50                # presumably a logging interval -- confirm unit
eval_freq = 1                  # presumably an evaluation interval -- confirm unit
max_epochs = 6
load_from = None               # no checkpoint path is configured here
find_unused_parameters = True  # presumably forwarded to the DDP wrapper -- confirm

# =========== data config ==============
# Label layout: 0 noise, 1~16 objects, 17 empty.
ignore_label = 0
empty_idx = 17
cls_dims = 18  # equals the number of label values 0..17 listed above
# NOTE(review): presumably [x_min, y_min, z_min, x_max, y_max, z_max] -- confirm.
pc_range = [-50.0, -50.0, -5.0, 50.0, 50.0, 3.0]
image_size = [864, 1600]  # presumably [height, width] -- TODO confirm
resize_lim = [1.0, 1.0]   # both limits are 1.0
flip = True
num_frames = 1
offset = 0

# =========== model config =============
# Shared model hyper-parameters. Several of them are referenced by the
# `model` dict in this file; the rest are presumably read by other
# configs or scripts -- confirm before removing any.
_dim_ = 128  # embedding width; reused for embed_dims and the neck out_channels
num_cams = 6
num_heads = 4
num_levels = 4
drop_out = 0.1
semantics_activation = 'identity'
semantic_dim = 17
include_opa = True
wempty = False
freeze_perception = False

# Anchor / sampling settings.
num_anchor = 6400
random_samples = 6400
scale_range = [0.01, 2.5]
num_learnable_pts = 6
learnable_scale = 3
scale_multiplier = 5
num_encoder = 4
return_layer_idx = [2, 3]

# Model config (type 'GaussianLifterV2'). Several entries reuse module-level
# constants defined earlier in this file: num_anchor, _dim_, ignore_label,
# semantic_dim, include_opa, pc_range, empty_idx and random_samples.
model = dict(
    type='GaussianLifterV2',
    freeze=False,
    num_anchor=num_anchor,
    embed_dims=_dim_,
    ignore_label=ignore_label,
    anchor_grad=False,
    feat_grad=False,
    semantic_dim=semantic_dim,
    include_opa=include_opa,
    num_samples=128,
    max_depth=72.0,
    pc_range=pc_range,
    voxel_size=0.5,
    # Assuming pc_range is [min xyz, max xyz], its extents (100, 100, 8)
    # divided by voxel_size 0.5 give exactly 200 x 200 x 16.
    occ_resolution=[200, 200, 16],
    empty_label=empty_idx,
    anchors_per_pixel=1,
    random_sampling=False,
    projection_in=None,
    initializer=dict(
        type='ResNetSecondFPN',
        img_backbone_out_indices=[0, 1, 2, 3],
        img_backbone_config=dict(
            type='ResNet',
            depth=101,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            frozen_stages=1,
            norm_cfg=dict(type='BN2d', requires_grad=False),
            norm_eval=True,
            style='caffe',
            with_cp=True,
            # The original DCNv2 prints a log message when load_state_dict
            # is performed.
            dcn=dict(
                type='DCNv2',
                deform_groups=1,
                fallback_on_stride=False,
            ),
            stage_with_dcn=(False, False, True, True),
        ),
        # NOTE(review): the spelling 'neck_confifg' is kept verbatim because
        # it is a runtime keyword; presumably it matches the parameter name
        # expected by ResNetSecondFPN -- verify before renaming.
        neck_confifg=dict(
            type='SECONDFPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=[_dim_] * 4,
            upsample_strides=[0.5, 1, 2, 4],
        ),
    ),
    initializer_img_downsample=None,
    # pretrained_path="work_dir/surroundocc/lifterv2/gs12800/lifter.pth",
    deterministic=False,
    random_samples=random_samples,
)


# Loss config: a 'MultiLoss' holding a single 'PixelDistributionLoss' entry.
loss = dict(
    type='MultiLoss',
    loss_cfgs=[
        dict(
            type='PixelDistributionLoss',
            weight=1.0,
            use_sigmoid=False,
            # NOTE(review): presumably maps loss argument names to keys of
            # the model output dict -- confirm against MultiLoss.
            input_dict=dict(
                pixel_logits='pixel_logits',
                pixel_gt='pixel_gt',
            ),
        ),
    ],
)

# Dataset root shared by the train and val dataset configs below.
data_path = 'data/surroundocc'

# Train / val dataset configs: identical apart from the `imageset` file.
# num_frames, offset and empty_idx come from the data config section of
# this file.
train_dataset_config = dict(
    type='NuScenes_Scene_SurroundOcc_Dataset',
    data_path=data_path,
    num_frames=num_frames,
    offset=offset,
    empty_idx=empty_idx,
    imageset='data/nuscenes_temporal_infos_train.pkl',
)

val_dataset_config = dict(
    type='NuScenes_Scene_SurroundOcc_Dataset',
    data_path=data_path,
    num_frames=num_frames,
    offset=offset,
    empty_idx=empty_idx,
    imageset='data/nuscenes_temporal_infos_val.pkl',
)

# Dataset wrapper configs: identical apart from `phase`.
# image_size, resize_lim and flip come from the data config section of
# this file.
train_wrapper_config = dict(
    type='NuScenes_Scene_Occ_DatasetWrapper',
    final_dim=image_size,
    resize_lim=resize_lim,
    flip=flip,
    phase='train',
)

# NOTE(review): the val wrapper receives the same `flip` value as training;
# whether the wrapper applies it when phase='val' is not visible here.
val_wrapper_config = dict(
    type='NuScenes_Scene_Occ_DatasetWrapper',
    final_dim=image_size,
    resize_lim=resize_lim,
    flip=flip,
    phase='val',
)

# Data loader settings: training shuffles, validation does not.
train_loader_config = dict(
    batch_size=1,
    shuffle=True,
    num_workers=8,
)

val_loader_config = dict(
    batch_size=1,
    shuffle=False,
    num_workers=8,
)