from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN

# HRNet64 backbone settings.
# Stages 2, 3 and 4 use 2, 3 and 4 branches respectively; every branch stacks
# four 'BASIC' blocks and the branch widths are 64, 128, 256, 512.
# NOTE(review): unlike HRNET_32 / HRNET_18 in this file, no PRETRAINED_LAYERS
# key is defined here -- confirm that is intentional.
HRNET_64 = CN({
    'STEM_INPLANES': 64,
    'FINAL_CONV_KERNEL': 1,
    'WITH_HEAD': True,
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'NUM_BLOCKS': [4, 4],
        'NUM_CHANNELS': [64, 128],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'NUM_BLOCKS': [4, 4, 4],
        'NUM_CHANNELS': [64, 128, 256],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'NUM_BLOCKS': [4, 4, 4, 4],
        'NUM_CHANNELS': [64, 128, 256, 512],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
})


# HRNet48 backbone settings.
# Stages 2, 3 and 4 use 2, 3 and 4 branches respectively; every branch stacks
# four 'BASIC' blocks and the branch widths are 48, 96, 192, 384.
# NOTE(review): unlike HRNET_32 / HRNET_18 in this file, no PRETRAINED_LAYERS
# key is defined here -- confirm that is intentional.
HRNET_48 = CN({
    'STEM_INPLANES': 64,
    'FINAL_CONV_KERNEL': 1,
    'WITH_HEAD': True,
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'NUM_BLOCKS': [4, 4],
        'NUM_CHANNELS': [48, 96],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'NUM_BLOCKS': [4, 4, 4],
        'NUM_CHANNELS': [48, 96, 192],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'NUM_BLOCKS': [4, 4, 4, 4],
        'NUM_CHANNELS': [48, 96, 192, 384],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
})


# HRNet32 backbone settings.
# Stages 2, 3 and 4 use 2, 3 and 4 branches respectively; every branch stacks
# four 'BASIC' blocks and the branch widths are 32, 64, 128, 256.
HRNET_32 = CN({
    'PRETRAINED_LAYERS': ['*'],
    'STEM_INPLANES': 64,
    'FINAL_CONV_KERNEL': 1,
    'WITH_HEAD': True,
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'NUM_BLOCKS': [4, 4],
        'NUM_CHANNELS': [32, 64],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'NUM_BLOCKS': [4, 4, 4],
        'NUM_CHANNELS': [32, 64, 128],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'NUM_BLOCKS': [4, 4, 4, 4],
        'NUM_CHANNELS': [32, 64, 128, 256],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
})


# HRNet18 backbone settings.
# Stages 2, 3 and 4 use 2, 3 and 4 branches respectively; every branch stacks
# four 'BASIC' blocks and the branch widths are 18, 36, 72, 144.
HRNET_18 = CN({
    'PRETRAINED_LAYERS': ['*'],
    'STEM_INPLANES': 64,
    'FINAL_CONV_KERNEL': 1,
    'WITH_HEAD': True,
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'NUM_BLOCKS': [4, 4],
        'NUM_CHANNELS': [18, 36],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'NUM_BLOCKS': [4, 4, 4],
        'NUM_CHANNELS': [18, 36, 72],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'NUM_BLOCKS': [4, 4, 4, 4],
        'NUM_CHANNELS': [18, 36, 72, 144],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
})

# HRNet2x20 backbone settings.
# This variant defines an explicit STAGE1 (two 'BOTTLENECK' branches) followed
# by stages 2-4 with 3, 4 and 5 'BASIC' branches of widths 20, 40, 80, 160, 320.
# It sets only FINAL_CONV_KERNEL at the top level (no STEM_INPLANES/WITH_HEAD).
HRNET2X_20 = CN({
    'FINAL_CONV_KERNEL': 1,
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'NUM_BLOCKS': [4, 4],
        'NUM_CHANNELS': [32, 64],
        'BLOCK': 'BOTTLENECK',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 3,
        'NUM_BLOCKS': [4, 4, 4],
        'NUM_CHANNELS': [20, 40, 80],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 4,
        'NUM_BLOCKS': [4, 4, 4, 4],
        'NUM_CHANNELS': [20, 40, 80, 160],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 5,
        'NUM_BLOCKS': [4, 4, 4, 4, 4],
        'NUM_CHANNELS': [20, 40, 80, 160, 320],
        'BLOCK': 'BASIC',
        'FUSE_METHOD': 'SUM',
    },
})

# Lookup table from backbone name (string key) to the config node defined
# for it in this module.
MODEL_CONFIGS = dict(
    hrnet18=HRNET_18,
    hrnet32=HRNET_32,
    hrnet48=HRNET_48,
    hrnet64=HRNET_64,
    hrnet2x20=HRNET2X_20,
)
