# Copyright (c) Facebook, Inc. and its affiliates.
# Adapted for DUPS from AutoFocusFormer

from detectron2.config import CfgNode as CN


def add_maskformer2_config(cfg):
    """
    Install the MaskFormer, Mask2Former and DUPS default options on ``cfg``.

    ``cfg`` is modified in place and nothing is returned. The nodes
    ``cfg.INPUT`` (with its ``CROP`` child), ``cfg.SOLVER``, ``cfg.MODEL``
    (with its ``SEM_SEG_HEAD`` child) and ``cfg.TEST`` are only looked up,
    never created, so they must already exist on ``cfg``. The nodes
    ``cfg.MODEL.MASK_FORMER``, ``cfg.MODEL.MASK_FORMER.TEST`` and
    ``cfg.MODEL.DUPS`` are assigned fresh ``CN()`` instances here.
    """
    # ------------------------------------------------------------------
    # NOTE: options inherited from the original MaskFormer
    # ------------------------------------------------------------------

    # -- data --
    data_in = cfg.INPUT
    # Which dataset mapper to select.
    data_in.DATASET_MAPPER_NAME = "mask_former_semantic"
    # Color augmentation.
    data_in.COLOR_AUG_SSD = False
    # Random cropping is retried until no single category in the semantic
    # segmentation GT covers more than `SINGLE_CATEGORY_MAX_AREA` of the crop.
    data_in.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Pad image and segmentation GT in dataset mapper.
    data_in.SIZE_DIVISIBILITY = -1

    # -- solver --
    solver = cfg.SOLVER
    # Weight decay applied to embeddings.
    solver.WEIGHT_DECAY_EMBED = 0.0
    # Optimizer settings.
    solver.OPTIMIZER = "ADAMW"
    solver.BACKBONE_MULTIPLIER = 0.1
    solver.BETAS = (0.9, 0.999)
    solver.EPSILON = 1e-8
    solver.CHECKPOINT_PERIOD = 2500

    # -- mask_former model node --
    model = cfg.MODEL
    model.MASK_FORMER = CN()
    mask_former = model.MASK_FORMER

    # Loss.
    mask_former.DEEP_SUPERVISION = True
    mask_former.NO_OBJECT_WEIGHT = 0.1
    mask_former.CLASS_WEIGHT = 1.0
    mask_former.DICE_WEIGHT = 1.0
    mask_former.MASK_WEIGHT = 20.0

    # Transformer.
    mask_former.NHEADS = 8
    mask_former.DROPOUT = 0.1
    mask_former.DIM_FEEDFORWARD = 2048
    mask_former.ENC_LAYERS = 0
    mask_former.DEC_LAYERS = 6
    mask_former.PRE_NORM = False

    mask_former.HIDDEN_DIM = 256
    mask_former.NUM_OBJECT_QUERIES = 100

    mask_former.TRANSFORMER_IN_FEATURE = "res5"
    mask_former.ENFORCE_INPUT_PROJ = False

    # Inference.
    mask_former.TEST = CN()
    inference = mask_former.TEST
    inference.SEMANTIC_ON = True
    inference.INSTANCE_ON = False
    inference.PANOPTIC_ON = False
    inference.OBJECT_MASK_THRESHOLD = 0.0
    inference.OVERLAP_THRESHOLD = 0.0
    inference.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # `backbone.size_divisibility` is sometimes set to 0 for some backbones
    # (e.g. ResNet); this option can be used to override it.
    mask_former.SIZE_DIVISIBILITY = 32

    # Only used by the MetaLoss version.
    mask_former.METALOSS_WEIGHT = 5.0

    # -- pixel decoder --
    sem_seg_head = model.SEM_SEG_HEAD
    sem_seg_head.MASK_DIM = 256
    # Transformer layers added inside the pixel decoder.
    sem_seg_head.TRANSFORMER_ENC_LAYERS = 0
    # Pixel decoder implementation.
    sem_seg_head.PIXEL_DECODER_NAME = "MSDeformAttnPixelDecoder"
    sem_seg_head.FPN_COMMON_STRIDE = 4
    sem_seg_head.IN_FEATURES = ["res2", "res3", "res4", "res5"]
    sem_seg_head.IGNORE_VALUE = 255
    sem_seg_head.NUM_CLASSES = 150
    sem_seg_head.LOSS_WEIGHT = 1.0

    # ------------------------------------------------------------------
    # NOTE: extra options introduced by maskformer2
    # ------------------------------------------------------------------

    # Transformer module.
    mask_former.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"

    # LSJ aug.
    data_in.IMAGE_SIZE = 1024
    data_in.MIN_SCALE = 0.1
    data_in.MAX_SCALE = 2.0

    # MSDeformAttn encoder.
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
    sem_seg_head.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8

    # Point loss.
    # Number of points sampled during training for a mask point head.
    mask_former.TRAIN_NUM_POINTS = 112 * 112
    # Oversampling parameter for PointRend point sampling during training
    # (parameter `k` in the original paper).
    mask_former.OVERSAMPLE_RATIO = 3.0
    # Importance sampling parameter for PointRend point sampling during
    # training (parameter `beta` in the original paper).
    mask_former.IMPORTANCE_SAMPLE_RATIO = 0.75

    # ------------------------------------------------------------------
    # DUPS MaskFormer options
    # ------------------------------------------------------------------
    mask_former.UPSAMPLING_WEIGHT = 10
    mask_former.NUM_RESOLUTION_SCALES = 4
    mask_former.ORACLE_TEACHER_RATIO = 0.0
    mask_former.MASK_DECODER_ALL_LEVELS = False
    mask_former.MASK_DIM = 256
    mask_former.SHEPARD_POWER = 6.0
    mask_former.SHEPARD_POWER_LEARNABLE = True
    mask_former.DECODER_LEVELS = 3
    sem_seg_head.MLP_RATIO = 4.0
    sem_seg_head.NHEADS = 8
    sem_seg_head.DROPOUT = 0.0

    # ------------------------------------------------------------------
    # DUPS backbone options
    # ------------------------------------------------------------------
    model.DUPS = CN()
    dups = model.DUPS
    dups.NAME = [
        "MixResViT",
        "MixResNeighbour",
        "MixResNeighbour",
        "MixResNeighbour",
        "MixResNeighbour",
        "MixResNeighbour",
        "MixResViT",
    ]
    dups.EMBED_DIM = [512, 256, 128, 64, 128, 256, 512]
    dups.DEPTHS = [1, 1, 1, 4, 4, 16, 4]
    dups.NUM_HEADS = [16, 8, 4, 2, 4, 8, 16]
    dups.PATCH_SIZES = [32, 16, 8, 4, 8, 16, 32]
    dups.SPLIT_RATIO = [4, 4, 4, 4, 4, 4, 4]
    dups.MLP_RATIO = [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]
    dups.UPSCALE_RATIO = [0.0, 0.7, 0.7, 0.6, 0.0, 0.0, 0.0]
    # NOTE(review): this list has 4 entries while the other list-valued DUPS
    # options above and below have 7 -- confirm this is intended.
    dups.DROP_RATE = [0.0, 0.0, 0.0, 0.0]
    dups.DROP_PATH_RATE = 0.0
    dups.ATTN_DROP_RATE = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    dups.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
    dups.CLUSTER_SIZE = [8, 8, 8, 8, 8, 8, 8]
    dups.NBHD_SIZE = [48, 48, 48, 48, 48, 48, 48]
    dups.KEEP_OLD_SCALE = True
    dups.ADD_IMAGE_DATA_TO_ALL = False
    dups.LAYER_SCALE = 0.0
    dups.NUM_REGISTER_TOKENS = 0
    dups.DYNAMIC_UPSAMPLING_RATIOS = True
    dups.DYNAMIC_UPSAMPLING_THRESHOLD = [0.0, 0.01, 0.02, 0.04, 0.0, 0.0, 0.0]

    # -- test-time options --
    # NOTE(review): "SW" presumably means sliding-window inference -- confirm
    # against the code that reads these keys.
    evaluation = cfg.TEST
    evaluation.SW_STRIDE = [768, 768]
    evaluation.SW_CROP_SIZE = [1024, 1024]