# ------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# Written by Chunyu Wang (chnuwa@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import yaml

import numpy as np
from easydict import EasyDict as edict

# Root configuration object. The per-topic sections defined further down in
# this file (CUDNN, NETWORK, DATASET, ...) are attached to it as nested edicts.
config = edict({
    'OUTPUT_DIR': 'output',
    'LOG_DIR': 'log',
    'DATA_DIR': '',
    'MODEL': 'multiview_transpose',
    'GPUS': '0,1',
    'WORKERS': 8,
    'PRINT_FREQ': 100,
})

# Cudnn related params
config.CUDNN = edict({
    'BENCHMARK': True,
    'DETERMINISTIC': False,
    'ENABLED': True,
})

# common params for NETWORK
config.NETWORK = edict()
config.NETWORK.NAME = 'pose_hrnet'
config.NETWORK.INIT_WEIGHTS = True
config.NETWORK.PRETRAINED = ''
config.NETWORK.NUM_JOINTS = 17
config.NETWORK.TAG_PER_JOINT = True
config.NETWORK.TARGET_TYPE = 'gaussian'
config.NETWORK.IMAGE_SIZE = [256, 256]  # width * height, ex: 192 * 256
config.NETWORK.HEATMAP_SIZE = [64, 64]  # width * height, ex: 24 * 32
config.NETWORK.PATCH_SIZE = [64, 64]
config.NETWORK.SIGMA = 2
config.NETWORK.HIDDEN_HEATMAP_DIM = -1
config.NETWORK.TRANSFORMER_DEPTH = 2
config.NETWORK.TRANSFORMER_HEADS = 2
config.NETWORK.TRANSFORMER_MLP_RATIO = 2
config.NETWORK.POS_EMBEDDING_TYPE = 'learnable'
config.NETWORK.DIM = 2
config.NETWORK.MULTI_TRANSFORMER_DEPTH = [12, 12]
config.NETWORK.MULTI_TRANSFORMER_HEADS = [16, 16]
config.NETWORK.MULTI_DIM = [48, 48]
config.NETWORK.INIT = False
config.NETWORK.NUM_BRANCHES = 1
# NOTE(review): 'BASEconfigHANNEL' looks like 'BASE_CHANNEL' mangled by a
# '_C' -> 'config' search-and-replace. The mangled key is kept so that any
# code or YAML already using it keeps working; the intended spelling is added
# with the same default so experiment files may use either key. The two keys
# are independent entries -- confirm which one the model code reads before
# removing the mangled one.
config.NETWORK.BASEconfigHANNEL = 32
config.NETWORK.BASE_CHANNEL = 32
# Container for extra, network-specific params; empty by default.
config.NETWORK.EXTRA = edict()

# common params for NETWORK (alternative set; commented out and not applied)
# config.NETWORK = edict()
# config.NETWORK.PRETRAINED = 'models/pytorch/imagenet/resnet50-19c8e357.pth'
# config.NETWORK.NUM_JOINTS = 20
# config.NETWORK.HEATMAP_SIZE = np.array([80, 80])
# config.NETWORK.IMAGE_SIZE = np.array([320, 320])
# config.NETWORK.SIGMA = 2
# config.NETWORK.TARGET_TYPE = 'gaussian'
# config.NETWORK.AGGRE = True

# # Transformer
# config.NETWORK.DIM_MODEL = 256
# config.NETWORK.DIM_FEEDFORWARD = 1024
# config.NETWORK.ENCODER_LAYERS = 3
# config.NETWORK.N_HEAD = 8
# # 2D positional encoding
# config.NETWORK.POS_EMBEDDING = 'sine'
# config.NETWORK.ATTENTION_ACTIVATION = 'relu'

# config.NETWORK.INIT_WEIGHTS = True

# # TransFusion
# config.NETWORK.FUSION = True            # Fuse 2 views or not
# config.NETWORK.POS_EMB_3D = 'geometry'      # 3D position embedding type: none, learnable, geometry
# config.NETWORK.REG_HEAD = True
# config.NETWORK.GAMMA = 10

# # pose_resnet related params
# config.NETWORK.EXTRA = edict()
# config.NETWORK.EXTRA.NUM_LAYERS = 50
# config.NETWORK.EXTRA.DECONV_WITH_BIAS = False
# config.NETWORK.EXTRA.FINAL_CONV_KERNEL = 1
# config.NETWORK.EXTRA.NUM_DECONV_FILTERS = 1

# LOSS related params
config.LOSS = edict({'USE_TARGET_WEIGHT': True})

# DATASET related params
config.DATASET = edict({
    # ROOT is later joined onto config.DATA_DIR by update_dir().
    'ROOT': '../data/h36m/',
    'TRAIN_DATASET': 'mixed_dataset',
    'TEST_DATASET': 'multi_view_h36m',
    'TRAIN_SUBSET': 'train',
    'TEST_SUBSET': 'validation',
    'ROOTIDX': 0,
    'DATA_FORMAT': 'jpg',
    'BBOX': 2000,
    'CROP': True,
    'WITH_DAMAGE': True,

    # training data augmentation
    'SCALE_FACTOR': 0,
    'ROT_FACTOR': 0,
})

# train
config.TRAIN = edict({
    # learning-rate schedule
    'LR_FACTOR': 0.1,
    'LR_STEP': [90, 110],
    'LR': 0.001,

    # optimizer
    'OPTIMIZER': 'adam',
    'MOMENTUM': 0.9,
    'WD': 0.0001,
    'NESTEROV': False,
    'GAMMA1': 0.99,
    'GAMMA2': 0.0,

    # epoch range
    'BEGIN_EPOCH': 0,
    'END_EPOCH': 140,

    'RESUME': False,

    # data loading
    'BATCH_SIZE': 8,
    'SHUFFLE': True,
})

# testing
config.TEST = edict({
    'BATCH_SIZE': 8,
    'STATE': '',
    'POST_PROCESS': False,
    'SHIFT_HEATMAP': False,
    'USE_GT_BBOX': False,
    'IMAGE_THRE': 0.1,
    'NMS_THRE': 0.6,
    'OKS_THRE': 0.5,
    'IN_VIS_THRE': 0.0,
    # BBOX_FILE is later joined onto config.DATA_DIR by update_dir().
    'BBOX_FILE': '',
    'BBOX_THRE': 1.0,
    'MATCH_IOU_THRE': 0.3,
    'DETECTOR': 'fpn_dcn',
    'DETECTOR_DIR': '',
    'MODEL_FILE': '',
    'HEATMAP_LOCATION_FILE': 'predicted_heatmaps.h5',
})

# debug
config.DEBUG = edict({
    'DEBUG': True,
    'SAVE_BATCH_IMAGES_GT': True,
    'SAVE_BATCH_IMAGES_PRED': True,
    'SAVE_HEATMAPS_GT': True,
    'SAVE_HEATMAPS_PRED': True,
})

# pictorial structure
config.PICT_STRUCT = edict({
    'FIRST_NBINS': 16,
    'PAIRWISE_FILE': '',
    'RECUR_NBINS': 2,
    'RECUR_DEPTH': 10,
    'LIMB_LENGTH_TOLERANCE': 150,
    'GRID_SIZE': 2000,
    'DEBUG': False,
    'TEST_PAIRWISE': False,
    'SHOW_ORIIMG': False,
    'SHOW_CROPIMG': False,
    'SHOW_HEATIMG': False,
})


def _update_dict(k, v):
    """Merge the section dict *v* into ``config[k]``.

    Some entries of *v* are normalised in place before merging:

    * ``DATASET``: non-empty ``MEAN`` / ``STD`` become numpy arrays; string
      elements are passed through ``eval`` first.
    * ``NETWORK``: ``HEATMAP_SIZE`` / ``IMAGE_SIZE`` become numpy arrays; a
      single int ``s`` is expanded to ``[s, s]``.

    Raises:
        ValueError: if *v* holds a key that is absent from ``config[k]``.
            Keys visited before the offending one have already been applied.
    """
    if k == 'DATASET':
        for stat_key in ('MEAN', 'STD'):
            if stat_key in v and v[stat_key]:
                # NOTE(review): eval() runs arbitrary expressions taken from
                # the experiment file -- only load trusted configs.
                v[stat_key] = np.array(
                    [eval(item) if isinstance(item, str) else item
                     for item in v[stat_key]])
    if k == 'NETWORK':
        for size_key in ('HEATMAP_SIZE', 'IMAGE_SIZE'):
            if size_key not in v:
                continue
            size = v[size_key]
            # A bare int is used for both entries of the size array.
            v[size_key] = np.array(
                [size, size] if isinstance(size, int) else size)

    section = config[k]
    for name, value in v.items():
        if name not in section:
            raise ValueError("{}.{} not exist in config.py".format(k, name))
        section[name] = value


def update_config(config_file):
    """Overlay the YAML experiment file *config_file* onto the global config.

    Every top-level key of the file must already exist in ``config``. Values
    that are dicts are merged through ``_update_dict``; any other value
    replaces the existing entry (``SCALES`` is stored as a tuple in slot 0 of
    the existing entry instead).

    Raises:
        ValueError: if a top-level key of the file is absent from ``config``
            (or, via ``_update_dict``, a nested key is absent from its
            section).
    """
    # NOTE(review): FullLoader is not meant for untrusted input; consider
    # yaml.safe_load if experiment files never rely on python-specific tags.
    with open(config_file) as stream:
        loaded = yaml.load(stream, Loader=yaml.FullLoader)

    exp_config = edict(loaded)
    for key, value in exp_config.items():
        if key not in config:
            raise ValueError("{} not exist in config.py".format(key))
        if isinstance(value, dict):
            _update_dict(key, value)
        elif key == 'SCALES':
            config[key][0] = (tuple(value))
        else:
            config[key] = value


def _to_plain(value):
    """Recursively turn *value* into plain built-in containers for YAML."""
    if isinstance(value, dict):
        # Covers edict too (it is tested with isinstance(v, dict) elsewhere
        # in this file); rebuilding as a plain dict drops the subclass.
        return {key: _to_plain(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_plain(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_plain(item) for item in value)
    if isinstance(value, np.ndarray):
        # _update_dict stores sizes / statistics as numpy arrays.
        return value.tolist()
    return value


def gen_config(config_file):
    """Write the current global ``config`` to *config_file* as YAML.

    The config tree is converted to plain dicts and lists at every nesting
    level before dumping. Converting only the first level (as was done
    previously) left deeper edicts such as ``NETWORK.EXTRA`` and any numpy
    arrays to be serialised as python-object tags rather than plain YAML
    mappings and sequences.

    Args:
        config_file: path of the YAML file to create or overwrite.
    """
    cfg = _to_plain(config)

    with open(config_file, 'w') as f:
        yaml.dump(cfg, f, default_flow_style=False)


def update_dir(model_dir, log_dir, data_dir):
    """Apply directory overrides and prefix data paths with ``DATA_DIR``.

    Each truthy argument replaces the matching config entry (``OUTPUT_DIR``,
    ``LOG_DIR``, ``DATA_DIR``). Afterwards ``DATASET.ROOT``,
    ``TEST.BBOX_FILE`` and ``NETWORK.PRETRAINED`` are rewritten as
    ``os.path.join(config.DATA_DIR, <current value>)``.

    NOTE(review): the join is applied to the current values on every call, so
    calling this more than once prefixes the paths again.
    """
    overrides = (
        ('OUTPUT_DIR', model_dir),
        ('LOG_DIR', log_dir),
        ('DATA_DIR', data_dir),
    )
    for attr, override in overrides:
        if override:
            setattr(config, attr, override)

    data_root = config.DATA_DIR
    prefixed = (
        (config.DATASET, 'ROOT'),
        (config.TEST, 'BBOX_FILE'),
        (config.NETWORK, 'PRETRAINED'),
    )
    for section, field in prefixed:
        setattr(section, field,
                os.path.join(data_root, getattr(section, field)))


def get_model_name(cfg):
    """Build the short and full model names for *cfg*.

    Args:
        cfg: config object exposing ``MODEL`` and ``NETWORK.IMAGE_SIZE``
            (indexed as ``[width, height]``).

    Returns:
        Tuple ``(name, full_name)`` where ``name`` is ``'<MODEL>_999'`` and
        ``full_name`` is ``'<height>x<width>_<name>_d1'``.
    """
    # The layer count (999) and the deconv suffix ('d1') are fixed
    # placeholders here; they are not derived from cfg.NETWORK.EXTRA.
    name = '{}_{}'.format(cfg.MODEL, 999)
    deconv_suffix = 'd1'

    image_size = cfg.NETWORK.IMAGE_SIZE
    height = image_size[1]
    width = image_size[0]
    full_name = '{}x{}_{}_{}'.format(height, width, name, deconv_suffix)

    return name, full_name


if __name__ == '__main__':
    # Script usage: write the current default configuration to the YAML path
    # given as the first command-line argument.
    import sys

    output_path = sys.argv[1]
    gen_config(output_path)

