# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import os
from detectron2.config import CfgNode as CN


def add_mask_former_default_config(cfg):
    """Register the MaskFormer default options on a detectron2 config.

    Mutates ``cfg`` in place and returns nothing:
      * adds MaskFormer-specific keys under ``INPUT``, ``SOLVER`` and
        ``MODEL.SEM_SEG_HEAD``;
      * creates the new ``MODEL.MASK_FORMER`` node (with a nested ``TEST``
        node) and the ``MODEL.SWIN`` backbone node.
    """
    # ---- data pipeline --------------------------------------------------
    inp = cfg.INPUT
    # name of the dataset mapper to select
    inp.DATASET_MAPPER_NAME = "mask_former_semantic"
    # color augmentation toggle
    inp.COLOR_AUG_SSD = False
    # Random cropping is retried until no single category in the semantic
    # segmentation GT occupies more than this fraction of the crop.
    inp.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # padding of image and segmentation GT in the dataset mapper
    inp.SIZE_DIVISIBILITY = -1

    # ---- solver ---------------------------------------------------------
    solver = cfg.SOLVER
    solver.TEST_IMS_PER_BATCH = 1  # test batch size
    solver.WEIGHT_DECAY_EMBED = 0.0  # weight decay on embeddings
    solver.OPTIMIZER = "ADAMW"
    solver.BACKBONE_MULTIPLIER = 0.1

    # ---- MaskFormer model -----------------------------------------------
    mask_former = CN()
    cfg.MODEL.MASK_FORMER = mask_former

    # loss settings
    mask_former.DEEP_SUPERVISION = True
    mask_former.NO_OBJECT_WEIGHT = 0.1
    mask_former.DICE_WEIGHT = 1.0
    mask_former.MASK_WEIGHT = 20.0

    # transformer settings
    mask_former.NHEADS = 8
    mask_former.DROPOUT = 0.1
    mask_former.DIM_FEEDFORWARD = 2048
    mask_former.ENC_LAYERS = 0
    mask_former.DEC_LAYERS = 6
    mask_former.PRE_NORM = False
    mask_former.HIDDEN_DIM = 256
    mask_former.NUM_OBJECT_QUERIES = 100
    mask_former.TRANSFORMER_IN_FEATURE = "res5"
    mask_former.ENFORCE_INPUT_PROJ = False

    # inference settings
    inference = CN()
    mask_former.TEST = inference
    inference.PANOPTIC_ON = False
    inference.USE_GT = False
    inference.OBJECT_MASK_THRESHOLD = 0.0
    inference.OVERLAP_THRESHOLD = 0.0
    inference.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False

    # Some backbones (e.g. ResNet) report `backbone.size_divisibility` as 0;
    # this value can be used to override it.
    mask_former.SIZE_DIVISIBILITY = 32

    # ---- pixel decoder (semantic segmentation head) ----------------------
    head = cfg.MODEL.SEM_SEG_HEAD
    head.MASK_DIM = 256
    # number of transformer layers added inside the pixel decoder
    head.TRANSFORMER_ENC_LAYERS = 0
    head.PIXEL_DECODER_NAME = "BasePixelDecoder"

    # ---- Swin transformer backbone --------------------------------------
    swin = CN()
    cfg.MODEL.SWIN = swin
    swin_defaults = (
        ("PRETRAIN_IMG_SIZE", 224),
        ("PATCH_SIZE", 4),
        ("EMBED_DIM", 96),
        ("DEPTHS", [2, 2, 6, 2]),
        ("NUM_HEADS", [3, 6, 12, 24]),
        ("WINDOW_SIZE", 7),
        ("MLP_RATIO", 4.0),
        ("QKV_BIAS", True),
        ("QK_SCALE", None),
        ("NORM_INDICES", None),
        ("PROJECTION", False),
        ("PROJECT_DIM", 256),
        ("DROP_RATE", 0.0),
        ("ATTN_DROP_RATE", 0.0),
        ("DROP_PATH_RATE", 0.3),
        ("APE", False),
        ("PATCH_NORM", True),
        ("OUT_FEATURES", ["res2", "res3", "res4", "res5"]),
    )
    # setattr goes through CfgNode.__setattr__, exactly like `swin.KEY = v`
    for key, value in swin_defaults:
        setattr(swin, key, value)


def add_our_config(cfg):
    cfg.TEST.SLIDING_WINDOW = False
    cfg.TEST.SLIDING_TILE_SIZE = 224
    cfg.TEST.SLIDING_OVERLAP = 2 / 3.0
    # whether to use dense crf
    cfg.TEST.DENSE_CRF = False
    cfg.DATASETS.SAMPLE_PER_CLASS = -1
    cfg.DATASETS.SAMPLE_SEED = 0
    # embedding head
    cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512
    cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024
    cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2
    # clip_adapter
    cfg.MODEL.CLIP_ADAPTER = CN()
    cfg.MODEL.CLIP_ADAPTER.TYPE = "maskformer"
    cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild"
    # for predefined
    cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."]
    # for learnable prompt
    cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = ""
    cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16"
    cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean"
    cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0
    cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4
    cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False
    cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True
    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True
    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7
    # for mask prompt]
    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3
    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False

    # for MaPLe
    cfg.MODEL.MAPLE = CN()
    cfg.MODEL.SAVE_DIR = os.getenv("MODEL_SAVE_DIR")
    cfg.MODEL.MAPLE.N_CTX = 2
    cfg.MODEL.MAPLE.CTX_INIT = "a photo of a"
    cfg.MODEL.MAPLE.INPUT_SIZE = (224, 224)
    cfg.MODEL.MAPLE.PROMPT_DEPTH = 9
    # ADE:
    # cfg.MODEL.MAPLE.DIR = f"{cfg.MODEL.SAVE_DIR}/ade20k_150/MaPLe/vit_l14_c2_ep5_batch8_2ctx_cross_datasets_0shots/seed1"
    # Scannet++:
    # cfg.MODEL.MAPLE.DIR = f"{cfg.MODEL.SAVE_DIR}/scannetpp/MaPLe/vit_b16_c2_ep5_batch4_2ctx_cross_datasets_0shots/seed1"
    # KITTI-360:
    cfg.MODEL.MAPLE.DIR = f"{cfg.MODEL.SAVE_DIR}/kitti360/MaPLe/vit_l14_c2_ep5_batch8_2ctx_cross_datasets_0shots/seed1"
    cfg.MODEL.MAPLE.LOAD_EPOCH = 5    

    # for RPO
    cfg.MODEL.RPO = CN()
    cfg.MODEL.RPO.N_CTX = 24
    cfg.MODEL.RPO.CTX_INIT = "a photo of a _."
    cfg.MODEL.RPO.INPUT_SIZE = (224, 224)
    # ADE:
    # cfg.MODEL.RPO.DIR = "./RPO/output/rpo/base2new/train_base/ade20k_150/shots_0/RPO/main_vitl14/seed1"
    # Scannet++:
    # cfg.MODEL.RPO.DIR = "./RPO/output/rpo/base2new/train_base/scannetpp_negative/shots_0/RPO/main_vitl14/seed12"
    # KITTI-360:
    # cfg.MODEL.RPO.DIR = "./RPO/output/rpo/base2new/train_base/kitti360/shots_0/RPO/main_vitl14/seed1"
    cfg.MODEL.RPO.DIR = "" # Set model dir based on the dataset
    cfg.MODEL.RPO.LOAD_EPOCH = 5

    # for OPENDAS
    cfg.MODEL.OPENDAS = CN()
    cfg.MODEL.OPENDAS.N_CTX_TEXT = 4
    cfg.MODEL.OPENDAS.N_CTX_VISION = 8
    cfg.MODEL.OPENDAS.CTX_INIT = "a photo of a"
    cfg.MODEL.OPENDAS.INPUT_SIZE = (224, 224)   
    
    # For OPENDAS - ADE (uncomment this):
    cfg.MODEL.OPENDAS.DIR = f"{cfg.MODEL.SAVE_DIR}/ade20k_150_negative/OpenDASBasic/vit_l14_c2_ep10_batch16_12+8ctx_use_both_losses_0shots/seed1"
    cfg.MODEL.OPENDAS.LOAD_EPOCH = 12
    cfg.MODEL.OPENDAS.PROMPT_DEPTH_TEXT = 12
    cfg.MODEL.OPENDAS.PROMPT_DEPTH_VISION = 12
    cfg.MODEL.OPENDAS.N_CTX_TEXT = 8
    cfg.MODEL.OPENDAS.N_CTX_VISION = 12
    
    # For OPENDAS -KITTI-360 (uncomment this):
    # cfg.MODEL.OPENDAS.DIR = f"{cfg.MODEL.SAVE_DIR}/kitti360_negative/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_use_both_losses_0shots/seed428"
    # cfg.MODEL.OPENDAS.LOAD_EPOCH = 12
    # cfg.MODEL.OPENDAS.PROMPT_DEPTH_VISION = 12
    # cfg.MODEL.OPENDAS.PROMPT_DEPTH_TEXT = 12
    
    # For OPENDAS - Scannet++:
    # cfg.MODEL.OPENDAS.DIR = f"{cfg.MODEL.SAVE_DIR}/scannetpp_similar_negative_v2/OpenDAS/vit_l14_c2_ep10_batch16_2+2ctx_d24_use_both_losses_0shots/seed429"    
    # cfg.MODEL.OPENDAS.LOAD_EPOCH = 8
    # cfg.MODEL.OPENDAS.PROMPT_DEPTH_VISION = 24
    # cfg.MODEL.OPENDAS.PROMPT_DEPTH_TEXT = 24

    # For CoCoOp
    cfg.MODEL.COCOOP = CN()
    cfg.MODEL.COCOOP.N_CTX = 4
    cfg.MODEL.COCOOP.CTX_INIT = "a photo of a"
    cfg.MODEL.COCOOP.DIR = f""
    cfg.MODEL.COCOOP.LOAD_EPOCH = 4

    # wandb
    cfg.WANDB = CN()
    cfg.WANDB.PROJECT = "open_vocab_seg"
    cfg.WANDB.NAME = None


def add_ovseg_config(cfg):
    """Populate ``cfg`` in place with every option open_vocab_seg reads.

    The MaskFormer defaults are registered first, followed by the
    project-specific options.
    """
    for register in (add_mask_former_default_config, add_our_config):
        register(cfg)
