#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

from .base_config import BaseConfig

_config_dict = dict(
    # Model state parameters
    OVERIDE_CFG_DIR = "",
    DUMP_INTERMEDITE=False,
    DUMP_RATE=dict(TRAIN=0.,TEST=0),

    MODEL=dict(
        LOAD_PROPOSALS=False,
        MASK_ON=False,
        KEYPOINT_ON=False,


        BACKBONE=dict(
            # Freeze parameters if FREEZE_AT >= 1
            FREEZE_AT=2,
        ),
        RESNETS=dict(
            NUM_CLASSES=None,
            DEPTH=None,
            # res4 for C4 backbone, res2..5 for FPN backbone
            OUT_FEATURES=["res4"],
            # Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
            NUM_GROUPS=1,
            # Options: FrozenBN, GN, "SyncBN", "BN"
            NORM="FrozenBN",
            ACTIVATION=dict(
                NAME="ReLU",
                INPLACE=True,
            ),
            # Whether init last bn weight of each BasicBlock or BottleneckBlock to 0
            ZERO_INIT_RESIDUAL=False,
            # Baseline width of each group.
            # Scaling this parameters will scale the width of all bottleneck layers.
            WIDTH_PER_GROUP=64,
            # Use True only for the original MSRA ResNet; use False for C2 and Torch models
            STRIDE_IN_1X1=True,
            # Output width of res2. Scaling this parameters will scale the width of all 1x1 convs
            # For R18 and R34, this needs to be set to 64
            RES5_DILATION=1,
            RES2_OUT_CHANNELS=256,
            STEM_OUT_CHANNELS=64,

            # Deep Stem
            DEEP_STEM=False,
        ),
        FPN=dict(
            # Names of the input feature maps to be used by FPN
            # They must have contiguous power of 2 strides
            # e.g., ["res2", "res3", "res4", "res5"]
            IN_FEATURES=[],
            OUT_CHANNELS=256,
            # Options: "" (no norm), "GN"
            NORM="",
            # Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
            FUSE_TYPE="sum",
        ),
        ANCHOR_GENERATOR=dict(
            # NAME="DefaultAnchorGenerator",
            # Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
            # Format: list[list[int]]. SIZES[i] specifies the list of sizes
            # to use for IN_FEATURES[i]; len(SIZES) == len(IN_FEATURES) must be true,
            # or len(SIZES) == 1 is true and size list SIZES[0] is used for all
            # IN_FEATURES.
            SIZES=[[32, 64, 128, 256, 512]],
            # Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
            # ratios are generated by an anchor generator.
            # Format: list[list[int]]. ASPECT_RATIOS[i] specifies the list of aspect ratios
            # to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
            # or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
            # for all IN_FEATURES.
            ASPECT_RATIOS=[[0.5, 1.0, 2.0]],
            # Anchor angles.
            # list[float], the angle in degrees, for each input feature map.
            # ANGLES[i] specifies the list of angles for IN_FEATURES[i].
            ANGLES=[[-90, 0, 90]],
            # Relative offset between the center of the first anchor and the top-left corner of img
            # Units: fraction of feature map stride (e.g., 0.5 means half stride)
            # Allowed values are floats in [0, 1) range inclusive.
            # Recommended value is 0.5, although it is not expected to affect model accuracy.
            OFFSET=0.0,
        ),
        # NMS type during inference
        # Format: str. (e.g., 'normal' means using normal nms)
        # Allowed values are 'normal', 'softnms-linear', 'softnms-gaussian', 'cluster'
        NMS_TYPE='normal',
    ),
)


class BaseDetectionConfig(BaseConfig):
    """Detection config that registers the module's default detection settings.

    Args:
        d: forwarded unchanged as the first positional argument of
            ``BaseConfig.__init__``.
        **kwargs: forwarded unchanged to ``BaseConfig.__init__``.
    """

    def __init__(self, d=None, **kwargs):
        # The base class handles ``d``/``kwargs`` first; the detection
        # defaults are registered afterwards. How registration interacts
        # with ``d`` is decided by BaseConfig._register_configuration,
        # which is not visible here.
        super(BaseDetectionConfig, self).__init__(d, **kwargs)
        defaults = _config_dict
        self._register_configuration(defaults)


# Shared module-level instance, built with no overrides (d=None, no kwargs).
config = BaseDetectionConfig(d=None)
