import os

from yacs.config import CfgNode as CN


# Root config node; every option below is registered on it together with its
# default value.
_C = CN()

# Directory options. update_config() overwrites each of these on the config it
# is given when the matching command-line argument (modelDir, logDir, dataDir)
# is set, and uses DATA_DIR as the base directory for DATASET.ROOT,
# MODEL.PRETRAINED and TEST.MODEL_FILE.
_C.OUTPUT_DIR = ''
_C.LOG_DIR = ''
_C.DATA_DIR = ''
# General runtime options. Nothing in this file reads them.
# NOTE(review): the names suggest device selection, data-loader settings and
# logging frequency -- confirm against the training/test scripts.
_C.GPUS = (0,)
_C.WORKERS = 4
_C.PRINT_FREQ = 20
_C.AUTO_RESUME = False
_C.PIN_MEMORY = True
_C.RANK = 0
_C.USE_GPU = True
_C.VERSION = 'v0'
_C.LOCAL = False

# Cudnn related params
# NOTE(review): presumably copied onto the cuDNN backend flags by the entry
# scripts -- nothing in this file reads them, confirm against the callers.
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'pose_hrnet'
_C.MODEL.INIT_WEIGHTS = True
# Path of the pretrained weights; update_config() resolves it against DATA_DIR.
_C.MODEL.PRETRAINED = ''
# NOTE(review): unlike PRETRAINED, this path is not touched by update_config().
_C.MODEL.PRETRAINED_AUGMENTER = ''
_C.MODEL.NUM_JOINTS = 17
_C.MODEL.TAG_PER_JOINT = True
_C.MODEL.TARGET_TYPE = 'gaussian'
_C.MODEL.IMAGE_SIZE = [256, 256]  # width * height, ex: 192 * 256
_C.MODEL.HEATMAP_SIZE = [64, 64]  # width * height, ex: 24 * 32
_C.MODEL.SIGMA = 2
# new_allowed=True lets a merged config add keys under EXTRA that are not
# declared in this file.
_C.MODEL.EXTRA = CN(new_allowed=True)
_C.MODEL.BACKBONE = 'resnet'

# Additional model options. Their consumers live outside this file.
# NOTE(review): judging by the names these configure the augmenter network and
# the keypoint-classification head -- confirm against the model code.
_C.MODEL.RN_NAME = 'augmenter'
_C.MODEL.TPEN = False
_C.MODEL.TRN = False
_C.MODEL.KP_CLASS = False
_C.MODEL.KP_CLASS_NUM = 99
_C.MODEL.NUM_FEAT = 352
_C.MODEL.KP_EMB = 64
_C.MODEL.NUM_INTER_FEAT = 0
_C.MODEL.TUNE_HM = False
_C.MODEL.COLLECT_FEAT = False

# Loss options.
_C.LOSS = CN()
_C.LOSS.USE_OHKM = False
_C.LOSS.TOPK = 8
_C.LOSS.USE_TARGET_WEIGHT = True
_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
# Losses for semi-supervised
# Boolean switches, all off by default, plus the initially empty LIST.
# NOTE(review): how LIST relates to the individual switches is not visible in
# this file -- confirm against the loss-building code.
_C.LOSS.LIST = []
_C.LOSS.SUPERVISED = False
_C.LOSS.CONSISTENCY = False
_C.LOSS.RECONSTRUCTION = False
_C.LOSS.RECONSTRUCTION_CONS = False
_C.LOSS.KP_CLASS = False
_C.LOSS.KP_CLASS_CONSISTENCY = False

# Per-loss weights, all defaulting to 1.
# NOTE(review): KP_SIMILARITY_WEIGHT has no matching boolean switch above.
_C.LOSS.SUPERVISED_WEIGHT = 1.
_C.LOSS.CONSISTENCY_WEIGHT = 1.
_C.LOSS.RECONSTRUCTION_WEIGHT = 1.
_C.LOSS.RECONSTRUCTION_CONS_WEIGHT = 1.
_C.LOSS.KP_CLASS_WEIGHT = 1.
_C.LOSS.KP_SIMILARITY_WEIGHT = 1.
_C.LOSS.KP_CLASS_CONSISTENCY_WEIGHT = 1.

# Variant selectors for the consistency and reconstruction losses.
# NOTE(review): the set of accepted values is defined by the consuming code,
# not by this file.
_C.LOSS.CONSISTENCY_TYPE = 'mse'
_C.LOSS.RECONSTRUCTION_TYPE = 'perceptual'

# DATASET related params
_C.DATASET = CN()
# Dataset root; update_config() joins DATA_DIR in front of it.
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'mpii'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'valid'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.HYBRID_JOINTS_TYPE = ''
_C.DATASET.SELECT_DATA = False

# Additional params
# NOTE(review): the three lists default to empty; their expected element
# format is defined by the dataset code, not by this file.
_C.DATASET.CENTER_SCALE = False
_C.DATASET.LEGEND = []
_C.DATASET.SYMM_LDMARKS = []
_C.DATASET.SEMANTIC_KP_LABELS = []

# training data augmentation
_C.DATASET.FLIP = True
_C.DATASET.FLIP_PROB = 0.5
_C.DATASET.SCALE_FACTOR = 0.25
_C.DATASET.ROT_FACTOR = 30
_C.DATASET.PROB_HALF_BODY = 0.0
_C.DATASET.NUM_JOINTS_HALF_BODY = 8
_C.DATASET.COLOR_RGB = True
_C.DATASET.HEADBOXES = False

# Semi-supervised learning
# NOTE(review): LABELS_SPLIT_FILE is not resolved against DATA_DIR by
# update_config() -- confirm how the dataset code interprets the path.
_C.DATASET.LABELLED_ONLY = False
_C.DATASET.UNLABELLED_ONLY = False
_C.DATASET.LABELS_SPLIT_FILE = ''

# train
_C.TRAIN = CN()

# Learning-rate settings.
# NOTE(review): presumably LR is multiplied by LR_FACTOR at each epoch listed
# in LR_STEP -- confirm against the scheduler set-up in the training script.
_C.TRAIN.LR_FACTOR = 0.1
_C.TRAIN.LR_STEP = [90, 110]
_C.TRAIN.LR = 0.001

# Optimizer selection and its hyper-parameters.
_C.TRAIN.OPTIMIZER = 'adam'
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.NESTEROV = False
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

# Epoch range of a training run.
_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 140

# Resuming from a checkpoint.
# NOTE(review): CHECKPOINT is not resolved against DATA_DIR by update_config().
_C.TRAIN.RESUME = False
_C.TRAIN.CHECKPOINT = ''

# Batch and augmentation settings.
# NOTE(review): BS / BS_LABELLED are presumably the total batch size and the
# number of labelled samples in it -- confirm against the data-loader code.
_C.TRAIN.BS = 32
_C.TRAIN.SHUFFLE = True
_C.TRAIN.BS_LABELLED = 0
_C.TRAIN.UNLABELLED_PERCENTAGE = 1.
_C.TRAIN.USE_GT_HM_REC = False
_C.TRAIN.AUG_FUNC = 'rotation'  # or 'perspective'
_C.TRAIN.AUG_ROT = 45
_C.TRAIN.AUG_VAR = 0.15

# testing
_C.TEST = CN()

# NOTE(review): presumably the test batch size (number of images per device)
# -- confirm against the test data loader.
_C.TEST.BS = 32
# Switches for test-time processing of the predictions.
_C.TEST.FLIP_TEST = False
_C.TEST.POST_PROCESS = False
_C.TEST.SHIFT_HEATMAP = False

_C.TEST.USE_GT_BBOX = False
_C.TEST.NO_GT_LABELS = False
_C.TEST.PRED_THRESHOLD = 0.1

# nms
_C.TEST.IMAGE_THRE = 0.1
_C.TEST.NMS_THRE = 0.6
_C.TEST.SOFT_NMS = False
_C.TEST.OKS_THRE = 0.5
_C.TEST.IN_VIS_THRE = 0.0
_C.TEST.COCO_BBOX_FILE = ''
_C.TEST.BBOX_THRE = 1.0
# Model file to evaluate; when non-empty, update_config() resolves it against
# DATA_DIR.
_C.TEST.MODEL_FILE = ''

# debug
# All debug switches default to False.
# NOTE(review): presumably DEBUG is the master switch and the SAVE_* flags
# select which images/heatmaps get written -- confirm against the callers.
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
_C.DEBUG.SAVE_BATCH_IMAGES_GT = False
_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
_C.DEBUG.SAVE_HEATMAPS_GT = False
_C.DEBUG.SAVE_HEATMAPS_PRED = False


def update_config(cfg, args):
    """Merge file and command-line overrides into ``cfg``, then freeze it.

    Overrides are applied in this order, later ones winning: the config file
    ``args.cfg``, the override list ``args.opts``, then the directory
    arguments ``args.modelDir``, ``args.logDir`` and ``args.dataDir`` (each
    only when truthy).  Afterwards the data-related paths are resolved
    against ``cfg.DATA_DIR``.

    Args:
        cfg: Config node to update in place.  It must provide ``defrost``,
            ``merge_from_file``, ``merge_from_list`` and ``freeze``.
        args: Parsed command-line arguments exposing ``cfg``, ``opts``,
            ``modelDir``, ``logDir`` and ``dataDir``.

    Returns:
        None.  ``cfg`` is modified in place and left frozen.
    """
    cfg.defrost()
    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)

    if args.modelDir:
        cfg.OUTPUT_DIR = args.modelDir

    if args.logDir:
        cfg.LOG_DIR = args.logDir

    if args.dataDir:
        cfg.DATA_DIR = args.dataDir

    # The dataset root is always joined: an empty ROOT therefore resolves to
    # DATA_DIR followed by a path separator.
    cfg.DATASET.ROOT = os.path.join(cfg.DATA_DIR, cfg.DATASET.ROOT)

    # Optional file paths are only prefixed when they are actually set.
    # Joining an empty value would produce '<DATA_DIR>/', a truthy string
    # that makes "is a path configured?" checks fire for a file that was
    # never requested.
    if cfg.MODEL.PRETRAINED:
        cfg.MODEL.PRETRAINED = os.path.join(
            cfg.DATA_DIR, cfg.MODEL.PRETRAINED
        )

    if cfg.TEST.MODEL_FILE:
        cfg.TEST.MODEL_FILE = os.path.join(
            cfg.DATA_DIR, cfg.TEST.MODEL_FILE
        )

    cfg.freeze()


if __name__ == '__main__':
    # Script mode: dump the default configuration as text to the file named
    # by the first command-line argument.
    import sys

    dump_path = sys.argv[1]
    with open(dump_path, 'w') as dump_file:
        dump_file.write(str(_C) + '\n')
