_base_ = ["../_base_/default_runtime.py"]
seed = None

# Image crop size and patch size; reused by the model, transform and dataset
# sections below. 518 = 37 * 14, so each side is an exact multiple of patch_size.
# Other (crop_h, crop_w) values tried: (560, 840), (630, 1120).
crop_h = 518
crop_w = 518
patch_size = 14

# misc custom setting
# Other (batch_size, num_worker) pairs tried:
#   (2, 4), (8, 8), (16, 96), (16, 128), (24, 48), (48, 320), (64, 192), (96, 96)
batch_size = 96  # bs: total bs in all gpus
num_worker = 320
mix_prob = 0
clip_grad = 3.0
empty_cache = True  # alternative: False
enable_amp = True
amp_dtype = "bfloat16"
# NOTE(review): evaluator hooks are still registered in `hooks` at the bottom of
# this file; confirm they are skipped when evaluate is False.
evaluate = False
find_unused_parameters = True  # alternative: False

only_weight_backbone = False  # alternative: True
# model settings
model = dict(
    type="Concerto-v2m2_upcast",
    # Patch grid of the image crop: 518 // 14 = 37 patches per side.
    patch_h=crop_h // patch_size,
    patch_w=crop_w // patch_size,
    view_num=2,
    # Other image backbones tried: "dinov2_vits14", "dinov2_vitg14".
    dinomodel_name="dinov2_vitl14",
    # NOTE(review): 1184 == 96 + 192 + 384 + 512, i.e. sum(enc_channels[1:]) of the
    # backbone below; presumably tied to dinov2_upcast_level=3 -- confirm.
    # Other values tried: 512, 1088, 1232, 992, 960, 896, 64, 384.
    backbone_out_channels=1184,
    embedding_channels=64,
    student_pretrained=False,
    feature_type="patch",
    dinov2_upcast_level=3,
    # backbone - student & teacher
    backbone=dict(
        type="Concerto-PTv3-v2m1",
        # presumably 3 (coord) + 3 (color) + 3 (normal), matching the *_feat_keys
        # of the Collect transform -- TODO confirm
        in_channels=9,
        order=("z", "z-trans", "hilbert", "hilbert-trans"),
        stride=(2, 2, 2, 2),
        # Smaller variant tried:
        #   enc_depths=(2, 2, 2, 6, 2),
        #   enc_channels=(32, 64, 128, 256, 512),
        #   enc_num_head=(2, 4, 8, 16, 32),
        enc_depths=(3, 3, 3, 12, 3),
        enc_channels=(48, 96, 192, 384, 512),
        enc_num_head=(3, 6, 12, 24, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        shuffle_orders=True,
        pre_norm=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=True,  # alternative: False
        upcast_mode=False,  # alternative: True
        upcast_normalization=False,
        upcast_layers=4,
        traceable=True,
        mask_token=True,
    ),
    # Overrides for the teacher copy of the backbone: every dropout-style rate is
    # set to 0 here, whereas the shared backbone config above uses drop_path=0.3.
    teacher_custom=dict(
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
    ),
    # NOTE(review): 1088 == 192 + 384 + 512, i.e. sum(enc_channels[2:]);
    # presumably tied to up_cast_level=2 -- confirm. Other value tried: 896.
    head_in_channels=1088,
    head_hidden_channels=4096,
    head_embed_channels=256,
    head_num_prototypes=4096,
    # NOTE(review): 384 is the embedding width of DINOv2 ViT-S/14, while
    # dinomodel_name selects ViT-L/14 (width 1024) -- confirm this value is
    # intended / actually consumed by the selected dino_head_type.
    DINOhead_in_channels=384,
    DINOhead_hidden_channels=4096,
    DINOhead_embed_channels=256,
    DINOhead_num_prototypes=4096,  # other value tried: 384
    # Same counts as global_view_num / local_view_num of MultiViewGeneratorV3
    # in `transform`.
    num_global_view=2,
    num_local_view=4,
    # *_start -> *_base schedules, each with a warmup ratio of 0.05.
    mask_size_start=0.1,
    mask_size_base=0.4,
    mask_size_warmup_ratio=0.05,
    mask_ratio_start=0.3,
    mask_ratio_base=0.7,
    mask_ratio_warmup_ratio=0.05,
    mask_jitter=0.01,
    teacher_temp_start=0.04,
    teacher_temp_base=0.07,
    teacher_temp_warmup_ratio=0.05,
    student_temp=0.1,
    # Loss weights; the active set is 1:1:2:2:0 over 6 and sums to 1.
    # Other mask:roll_mask:unmask:dinov2:origin ratios tried:
    #   0:0:0:5:5 (/10), 2:2:4:2:0 (/10), 1:1:2:4:0 (/8), 1:1:2:6:0 (/10),
    #   2:2:4:0:0 (/8), 0:0:0:1:0 (dinov2 only)
    mask_loss_weight=1 / 6,
    roll_mask_loss_weight=1 / 6,
    unmask_loss_weight=2 / 6,
    dinov2_loss_weight=2 / 6,
    origin_loss_weight=0 / 6,
    # presumably the teacher EMA momentum range -- TODO confirm.
    # Alternative tried: momentum_base=0, momentum_final=0.
    momentum_base=0.994,
    momentum_final=1,
    match_max_k=8,
    match_max_r=0.32,
    up_cast_level=2,
    normalize_method="sk",
    dino_sk=False,  # alternative: True
    cos_head=False,  # alternative: True
    dino_cos_shift=True,  # alternative: False
    dino_loss_type="cos",
    sonata_loss_type="entropy",
    dino_head_type="predictor",
    dino_head_update_method="copy",  # alternative: "ema"
    sonata_head_type="symmetric",  # alternative: "predictor"
    sonata_head_update_method="ema",  # alternative: "copy"
    sonata_model_type="online",  # alternative: "offline"
)

# scheduler settings
epoch = 100
base_lr = 0.004  # other values tried: 0.003, 0.002
lr_decay = 0.9  # layer-wise lr decay

base_wd = 0.04  # wd scheduler enable in hooks
final_wd = 0.2  # wd scheduler enable in hooks

# Layer-wise lr decay over the encoder blocks.
# Blocks are numbered in forward order across all stages; the last block keeps
# base_lr (exponent 0) and every earlier block is scaled by one more factor of
# lr_decay. One entry is produced per block, keyed by its "enc{stage}.block{i}."
# parameter-name prefix.
_enc_depths = model["backbone"]["enc_depths"]
_num_blocks = sum(_enc_depths)
param_dicts = [
    dict(
        keyword=f"enc{stage}.block{block}.",
        lr=base_lr * lr_decay ** (_num_blocks - flat_index - 1),
    )
    for flat_index, (stage, block) in enumerate(
        (stage, block)
        for stage, depth in enumerate(_enc_depths)
        for block in range(depth)
    )
]
# Drop the temporaries so they do not remain as module-level config names.
del _enc_depths, _num_blocks

optimizer = dict(type="AdamW", lr=base_lr, weight_decay=base_wd)
# max_lr holds base_lr first, followed by one peak lr per param_dicts entry in
# the same order.
# NOTE(review): assuming these keys map onto torch.optim.lr_scheduler.OneCycleLR
# -- confirm against the scheduler builder.
scheduler = dict(
    type="OneCycleLR",
    max_lr=[base_lr, *(group["lr"] for group in param_dicts)],
    pct_start=0.05,
    anneal_strategy="cos",
    div_factor=10.0,
    final_div_factor=1000.0,
)

# Per-channel mean / std passed to the "Imgnormalize" image transform below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# dataset settings
# Training pipeline; the same list object is passed to every dataset in `data`.
transform = [
    dict(
        type="ImgAugmentation",
        crop_h=crop_h,
        crop_w=crop_w,
        patch_h=crop_h // patch_size,
        patch_w=crop_w // patch_size,
        patch_size=patch_size,
        # Only normalization is active on the image side. The photometric
        # augmentations below are disabled; uncomment to re-enable.
        imgtransforms=[
            # dict(type="ImgRandomHorizontalFlip", p=0.5),
            # dict(type="ImgRandomColorJitter", brightness=0.2, contrast=0.2, saturation=0.1, p=0.8),
            # dict(type="ImgRandomGrayscale", p=0.),
            # dict(type="ImgPixelContrast", threshold = 0.02, p=0.5),
            # dict(type="ImgChromaticJitter", p=0.95, std=0.05),
            # dict(type="ImgGaussianBlur", p=0.5),
            # dict(type="ImgToTensor"),
            dict(
                type="Imgnormalize",
                mean=IMAGENET_DEFAULT_MEAN,
                std=IMAGENET_DEFAULT_STD,
            ),
        ],
    ),
    dict(type="GridSample", grid_size=0.02, hash_type="fnv", mode="train"),
    # Keep a copy of the grid-sampled coordinates under "origin_coord" before the
    # view generator runs.
    dict(type="Copy", keys_dict={"coord": "origin_coord"}),
    dict(
        type="MultiViewGeneratorV3",
        # 2 global + 4 local views; same counts as num_global_view /
        # num_local_view in `model`.
        global_view_num=2,
        global_view_scale=(0.4, 1.0),
        local_view_num=4,
        local_view_scale=(0.1, 0.4),
        global_shared_transform=[
            dict(
                type="RandomColorJitter",
                brightness=0.4,
                contrast=0.4,
                saturation=0.2,
                hue=0.02,
                p=0.8,
            ),
            dict(type="ChromaticTranslation", p=0.95, ratio=0.05),
            # dict(type="ChromaticJitter", p=0.95, std=0.05),
            # dict(type="SphereCrop", point_max=102400, mode="given"),
            dict(type="NormalizeColor"),
        ],
        # Geometric augmentations only; the color transforms for global views are
        # listed in global_shared_transform above.
        global_transform=[
            dict(type="CenterShift", apply_z=True),
            dict(type="RandomScale", scale=[0.9, 1.1]),
            dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.8),
            dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="x", p=0.8),
            dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="y", p=0.8),
            dict(type="RandomFlip", p=0.5),
            dict(type="RandomJitter", sigma=0.005, clip=0.02),
            dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]),
        ],
        # Same geometric chain as global_transform, followed by the color
        # transforms.
        local_transform=[
            dict(type="CenterShift", apply_z=True),
            dict(type="RandomScale", scale=[0.9, 1.1]),
            dict(type="RandomRotate", angle=[-1, 1], axis="z", center=[0, 0, 0], p=0.8),
            dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="x", p=0.8),
            dict(type="RandomRotate", angle=[-1 / 64, 1 / 64], axis="y", p=0.8),
            dict(type="RandomFlip", p=0.5),
            dict(type="RandomJitter", sigma=0.005, clip=0.02),
            dict(type="ElasticDistortion", distortion_params=[[0.2, 0.4], [0.8, 1.6]]),
            # dict(type="ChromaticAutoContrast", p=0.2, blend_factor=None),
            dict(
                type="RandomColorJitter",
                brightness=0.4,
                contrast=0.4,
                saturation=0.2,
                hue=0.02,
                p=0.8,
            ),
            dict(type="ChromaticTranslation", p=0.95, ratio=0.05),
            # dict(type="ChromaticJitter", p=0.95, std=0.05),
            dict(type="NormalizeColor"),
        ],
        max_size=65536,
        dinov2_max_size=65536,  # other value tried: 102400
        dinov2_scale=(0.8, 1),
    ),
    # NOTE: no separate "DINOv2ViewGenerator" stage (dino_coord / dino_color /
    # dino_normal views) is used in this config; the dinov2_* options are passed
    # to MultiViewGeneratorV3 above.
    dict(type="ToTensor"),
    dict(type="Add", keys_dict={"grid_size": 0.02}),  # same value as GridSample grid_size
    dict(
        type="Collect",
        # Not collected in this config: coord, color, origin_coord, origin_color,
        # dino_mask_index, dino_feature, dino_coord, dino_color, dino_normal,
        # dino_offset.
        keys=(
            "global_origin_coord",
            "global_coord",
            "global_color",
            "global_offset",
            "local_origin_coord",
            "local_coord",
            "local_color",
            "local_offset",
            "grid_size",
            "name",
            "imgs",
            "global_mask_index",
            "img_num",
        ),
        offset_keys_dict=dict(),
        # Three keys per view type; presumably concatenated into the 9-channel
        # backbone input (in_channels=9) -- TODO confirm.
        global_feat_keys=("global_coord", "global_color", "global_normal"),
        local_feat_keys=("local_coord", "local_color", "local_normal"),
    ),
]

# Training data: a concatenation of seven datasets. Every entry receives the same
# crop / patch settings and the shared `transform` pipeline, with loop=1.
# The number after each dataset name is the count recorded by the original author.
data = dict(
    train=dict(
        type="ConcatDataset",
        datasets=[
            # ARKitScenes
            # NOTE(review): loaded through "ScanNetPPDatasetALL_img"; presumably the
            # preprocessed layout is shared -- confirm.
            dict(
                type="ScanNetPPDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=["train", "val", "test"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/arkit_img_compressed/arkitscenes_trans_all_split",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
            # ScanNet 1,613
            # NOTE(review): also loaded through "ScanNetPPDatasetALL_img" -- confirm.
            dict(
                type="ScanNetPPDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=["train", "val", "test"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/sc_img_3",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
            # ScanNet++ 1,016
            dict(
                type="ScanNetPPDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=[
                    "train",
                    "val",
                    "test",
                ],
                data_root="/high_perf_store3/l3_data/xxx/datasets/scppv2_img/scppv2_img",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
            # S3DIS 272
            dict(
                type="S3DISDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                # Per-area alternative:
                # split=["Area_1", "Area_2", "Area_3", "Area_4", "Area_5", "Area_6"],
                split=["train"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/s3dis_origin/s3dis_preprocessed_split",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
            # HM3D 11,493
            # NOTE(review): the dataset type is "RE10KDatasetALL_img" although
            # data_root points at HM3D; confirm this class reuse is intentional.
            dict(
                type="RE10KDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=["train", "val"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/hm3d_img/hm3d_preprocessed",
                transform=transform,
                test_mode=False,
                # force_label=False,
                loop=1,
            ),
            # Structured3D 21,821
            dict(
                type="Structured3DDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=["train", "val", "test"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/st3d_img/st3d_img",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
            # RE10K 89,976
            dict(
                type="RE10KDatasetALL_img",
                crop_h=crop_h,
                crop_w=crop_w,
                patch_size=patch_size,
                split=["train", "val", "test"],
                data_root="/high_perf_store3/l3_data/xxx/datasets/re10k/re10k_img",
                transform=transform,
                test_mode=False,
                loop=1,
            ),
        ],
    )
)

# Training hooks, in registration order.
# NOTE(review): `evaluate = False` is set at the top of this file while two
# evaluator hooks are registered here; confirm they are skipped in that case.
hooks = [
    dict(type="CheckpointLoader"),
    dict(type="ModelHook"),
    # Weight-decay schedule from base_wd (0.04) to final_wd (0.2).
    dict(type="WeightDecaySchedular", base_value=base_wd, final_value=final_wd),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(
        type="PCAEvaluator",
        eval_step=1,
        point_size=0.03,
        # Point-cloud-only evaluation set; no image keys are collected here.
        dataset=dict(
            type="DefaultDataset",
            split=["scannet", "s3dis", "hm3d", "structured3d"],
            data_root="data/collection",
            transform=[
                dict(type="CenterShift", apply_z=True),
                dict(
                    type="GridSample",
                    grid_size=0.02,
                    hash_type="fnv",
                    mode="train",
                    return_grid_coord=True,
                ),
                dict(type="CenterShift", apply_z=False),
                dict(type="NormalizeColor"),
                dict(type="ToTensor"),
                dict(
                    type="Collect",
                    keys=("coord", "grid_coord", "segment", "name", "split"),
                    feat_keys=("coord", "color", "normal"),
                ),
            ],
            test_mode=False,
        ),
    ),
    dict(
        type="InternalMatchingEvaluator",
        eval_step=1,
        point_size=0.03,
        segment_ignore_index=(-1, 0, 1),
        dataset=dict(
            type="DefaultDataset",
            split=["scannet", "s3dis", "hm3d", "structured3d"],
            data_root="data/collection",
            transform=[
                dict(
                    type="MultiViewsGenerator",
                    # One "global" and one "local" view per sample; view_trans
                    # lists one transform chain per entry of view_names, in the
                    # same order.
                    view_names=("global", "local"),
                    view_repeats=(1, 1),
                    view_keys=("coord", "color", "normal", "segment"),
                    view_trans=[
                        # view augmentations for global view
                        [
                            dict(type="CenterShift", apply_z=True),
                            dict(type="NormalizeColor"),
                            dict(
                                type="GridSample",
                                grid_size=0.02,
                                hash_type="fnv",
                                mode="train",
                                keys=("coord", "color", "normal", "segment"),
                                return_grid_coord=False,
                            ),
                        ],
                        # view augmentations for local view
                        [
                            dict(type="SphereCrop", scale=0.35, mode="random"),
                            dict(type="CenterShift", apply_z=True),
                            dict(
                                type="RandomRotate",
                                angle=[-1, 1],
                                axis="z",
                                center=[0, 0, 0],
                                p=1,
                            ),
                            dict(type="RandomFlip", p=0.5),
                            dict(type="RandomJitter", sigma=0.005, clip=0.02),
                            dict(
                                type="ChromaticAutoContrast", p=0.2, blend_factor=None
                            ),
                            dict(type="ChromaticTranslation", p=0.95, ratio=0.05),
                            dict(type="NormalizeColor"),
                            dict(
                                type="GridSample",
                                grid_size=0.02,
                                hash_type="fnv",
                                mode="train",
                                keys=("coord", "color", "normal", "segment"),
                                return_grid_coord=True,
                            ),
                        ],
                    ],
                ),
                dict(type="ToTensor"),
                dict(type="Add", keys_dict={"grid_size": 0.02}),
                dict(
                    type="Collect",
                    # NOTE(review): no transform listed in this evaluator pipeline
                    # visibly produces "global_mask_index" (view_keys above are
                    # coord / color / normal / segment); confirm that
                    # MultiViewsGenerator or the dataset emits it, or drop the key.
                    keys=(
                        "global_coord",
                        "global_color",
                        "global_segment",
                        "global_offset",
                        "local_coord",
                        "local_color",
                        "local_segment",
                        "local_offset",
                        "grid_size",
                        "name",
                        "split",
                        "global_mask_index",
                    ),
                    offset_keys_dict=dict(),
                    global_feat_keys=("global_coord", "global_color", "global_normal"),
                    local_feat_keys=("local_coord", "local_color", "local_normal"),
                ),
            ],
            test_mode=False,
        ),
    ),
    dict(type="CheckpointSaver", save_freq=10),
]
