from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from yacs.config import CfgNode as CN


# Root node of the default configuration tree (a yacs CfgNode). Every option a
# YAML file or command-line override may set must be declared below.
_C = CN()

# ----- BASIC SETTINGS -----
_C.RAND_SEED = 42
_C.NAME = "default"
# Absolute, machine-specific default; override it in the experiment YAML.
_C.OUTPUT_DIR = "/home/cifar10/output"
# NOTE(review): the *_STEP values look like frequencies (e.g. validate/save
# every N epochs, log every N iterations) — confirm units in the training loop.
_C.VALID_STEP = 5
_C.SAVE_STEP = 5
_C.SHOW_STEP = 20
_C.PIN_MEMORY = True
_C.INPUT_SIZE = (224, 224)  # presumably (height, width) — TODO confirm order
_C.COLOR_SPACE = "RGB"
_C.RESUME_MODEL = ""  # empty string presumably means "do not resume" — confirm
_C.RESUME_MODE = "all"
_C.EVAL_MODE = False
# NOTE(review): the default DATASET.DATASET is "IMBALANCEDCIFAR10" but this
# default is 100 classes; CIFAR-10 configs presumably override it — confirm.
_C.CLASSES_NUM = 100


# ----- DATASET BUILDER -----
_C.DATASET = CN()
# CAM-based dataset options; generating and using the CAM dataset are
# controlled by two independent flags.
_C.DATASET.GENERATE_CAM_BASED_DATASET = False
_C.DATASET.USE_CAM_BASED_DATASET = False
_C.DATASET.CAM_DATA_JSON_SAVE_PATH = ''
_C.DATASET.CAM_DATA_SAVE_PATH = ''
_C.DATASET.CAM_NUMBER_THRES = 0

_C.DATASET.DATASET = "IMBALANCEDCIFAR10"
_C.DATASET.TRAIN_JSON = ""
_C.DATASET.VALID_JSON = ""
# Imbalanced-CIFAR options. RATIO is presumably the smallest/largest class-size
# ratio (0.01 -> imbalance factor 100) — confirm in the dataset class.
_C.DATASET.IMBALANCECIFAR = CN()
_C.DATASET.IMBALANCECIFAR.RATIO = 0.01
_C.DATASET.IMBALANCECIFAR.RANDOM_SEED = 0
# AugMix options. These keys are lowercase, unlike the rest of the tree; YAML
# files must use exactly these names, so they are left as-is.
_C.DATASET.AUGMIX = CN()
_C.DATASET.AUGMIX.all_ops = False
_C.DATASET.AUGMIX.randaug = False
_C.DATASET.AUGMIX.aug_severity = 3
_C.DATASET.AUGMIX.width = 3
_C.DATASET.AUGMIX.depth = -1
_C.DATASET.AUGMIX.alpha = 1.0
# NOTE(review): judging by the "Curriculum Test" comment below, "CUDA" here
# presumably names a curriculum-learning scheme, not NVIDIA CUDA/GPU usage —
# confirm in the dataset code.
_C.DATASET.use_cuda = False
_C.DATASET.CUDA = CN()
_C.DATASET.CUDA.accept_rate = .6
_C.DATASET.CUDA.update_epoch = 1
_C.DATASET.CUDA.num_test = 10  # Curriculum Test

# ----- NETWORK BUILDER -----
_C.NETWORK = CN()
_C.NETWORK.PRETRAINED_MODEL = ''
_C.NETWORK.LOAD_BACKBONE_ONLY = False
# NOTE(review): presumably the momentum/EMA coefficient of a moving-average
# model copy — confirm in the model code.
_C.NETWORK.MA_MODEL_ALPHA = 0.999
# MoCo (momentum contrast) options; presumably only consulted when MOCO is True.
_C.NETWORK.MOCO = False
_C.NETWORK.MOCO_K = 65536  # presumably the length of the negative-key queue
# Feature (projection) dimension. It previously defaulted to 65536, an apparent
# copy of MOCO_K. The MoCo reference implementation pairs K=65536 and T=0.07
# with dim=128. A 65536 x 65536 float32 queue would need ~17 GB.
_C.NETWORK.MOCO_DIM = 128
_C.NETWORK.MOCO_T = 0.07  # presumably the contrastive softmax temperature
# ----- BACKBONE BUILDER -----
_C.BACKBONE = CN()
_C.BACKBONE.TYPE = "res50"
# List of backbone types, presumably one per branch in the multi-network setup
# (note its default differs from TYPE above) — confirm in the model builder.
_C.BACKBONE.MULTI_NETWORK_TYPE = ['res32_cifar']
_C.BACKBONE.FREEZE = False
# NOTE(review): distinct from NETWORK.PRETRAINED_MODEL; check which one the
# loading code actually reads.
_C.BACKBONE.PRETRAINED_MODEL = ""
_C.BACKBONE.SHARE_LEVEL = 1

# ----- MODULE BUILDER -----
_C.MODULE = CN()
_C.MODULE.TYPE = "GAP"  # presumably global average pooling — confirm


# Top-level flag (not nested under MODULE or CLASSIFIER).
_C.DROPOUT = False

# ----- CLASSIFIER BUILDER -----
_C.CLASSIFIER = CN()
_C.CLASSIFIER.TYPE = "FC"
_C.CLASSIFIER.COS_SCALE = 16  # presumably the logit scale of a cosine classifier
_C.CLASSIFIER.SEMI_TYPE = "mlp"
# List of classifier types, presumably one per network branch — confirm.
_C.CLASSIFIER.MULTI_NETWORK_TYPE = ['FC']
_C.CLASSIFIER.BIAS = True
_C.CLASSIFIER.NUM = 0

# ----- LOSS BUILDER -----

_C.LOSS = CN()
_C.LOSS.LOSS_TYPE = "CrossEntropy"
_C.LOSS.HCM_N = 5
# Presumably weights of the individual loss terms; all default to 0.0, so
# experiment YAMLs are expected to set them — confirm in the loss code.
_C.LOSS.CON_RATIO = 0.0
_C.LOSS.HCM_RATIO = 0.0
_C.LOSS.CE_RATIO = 0.0
# Name (string) of a forward variant; presumably resolved by name elsewhere.
_C.LOSS.forward = "forward_org"


# NOTE: "CLASIIFIER" is misspelled, but the key is part of the YAML interface.
# Config files and code must use this exact spelling, so do not rename it in
# isolation.
_C.LOSS.MULTI_CLASIIFIER_LOSS = CN()
_C.LOSS.MULTI_CLASIIFIER_LOSS.DIVERSITY_FACTOR = 0.0
_C.LOSS.MULTI_CLASIIFIER_LOSS.DIVERSITY_FACTOR_HCM = 0.0


# ----- TRAIN BUILDER -----
_C.TRAIN = CN()
_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.MAX_EPOCH = 60
_C.TRAIN.SHUFFLE = True
_C.TRAIN.NUM_WORKERS = 8
_C.TRAIN.TENSORBOARD = CN()
_C.TRAIN.TENSORBOARD.ENABLE = True

_C.TRAIN.COMBINER = CN()
# The parentheses only wrap the line: this value is the str "default", not a tuple.
_C.TRAIN.COMBINER.TYPE = (
    "default"
)
_C.TRAIN.COMBINER.ALPHA = 1.0

# Two-stage switches. DRW/DRS presumably mean deferred re-weighting /
# re-sampling, starting at START_EPOCH — confirm in the trainer.
_C.TRAIN.TWO_STAGE = CN()
_C.TRAIN.TWO_STAGE.DRW = False
_C.TRAIN.TWO_STAGE.DRS = False
_C.TRAIN.TWO_STAGE.START_EPOCH = 1

# ----- SAMPLER BUILDER -----
_C.TRAIN.SAMPLER = CN()
_C.TRAIN.SAMPLER.TYPE = "default"
# List of sampler types, presumably one per network branch — confirm.
_C.TRAIN.SAMPLER.MULTI_NETWORK_TYPE = ["default"]

# Sub-options for specific sampler types.
_C.TRAIN.SAMPLER.WEIGHTED_SAMPLER = CN()
_C.TRAIN.SAMPLER.WEIGHTED_SAMPLER.TYPE = "balance"
_C.TRAIN.SAMPLER.BBN_SAMPLER = CN()
_C.TRAIN.SAMPLER.BBN_SAMPLER.TYPE = "reverse"


# ----- OPTIMIZER -----
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.TYPE = "SGD"
_C.TRAIN.OPTIMIZER.BASE_LR = 0.001
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 1e-4


# ----- LR SCHEDULER -----
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.TYPE = "multistep"
# Presumably epochs (MAX_EPOCH defaults to 60) at which the LR is multiplied by
# LR_FACTOR when TYPE is "multistep" — confirm in the scheduler builder.
_C.TRAIN.LR_SCHEDULER.LR_STEP = [40, 50]
_C.TRAIN.LR_SCHEDULER.LR_FACTOR = 0.1
_C.TRAIN.LR_SCHEDULER.WARM_EPOCH = 5
# Presumably only used by a cosine schedule type — confirm.
_C.TRAIN.LR_SCHEDULER.COSINE_DECAY_END = 0
_C.TRAIN.LR_SCHEDULER.ETA_MIN = 1e-4

# Distributed-training switches.
_C.TRAIN.DISTRIBUTED = False
_C.TRAIN.SYNCBN = False

# ----- TEST / EVALUATION -----
_C.TEST = CN()
_C.TEST.BATCH_SIZE = 32
_C.TEST.NUM_WORKERS = 8
_C.TEST.MODEL_FILE = ""  # checkpoint path; empty by default

# ----- TRANSFORMS -----
_C.TRANSFORMS = CN()
_C.TRANSFORMS.MULTI_AUG = False
# Names of the train / test transforms, presumably applied in this order.
_C.TRANSFORMS.TRAIN_TRANSFORMS = ("random_resized_crop", "random_horizontal_flip")
_C.TRANSFORMS.TEST_TRANSFORMS = ("shorter_resize_for_crop", "center_crop")

# Per-transform parameters, apparently keyed by the upper-cased transform name.
_C.TRANSFORMS.PROCESS_DETAIL = CN()
_C.TRANSFORMS.PROCESS_DETAIL.RANDOM_CROP = CN()
_C.TRANSFORMS.PROCESS_DETAIL.RANDOM_CROP.PADDING = 4
_C.TRANSFORMS.PROCESS_DETAIL.RANDOM_RESIZED_CROP = CN()
# These match torchvision RandomResizedCrop's defaults: scale=(0.08, 1.0), ratio=(3/4, 4/3).
_C.TRANSFORMS.PROCESS_DETAIL.RANDOM_RESIZED_CROP.SCALE = (0.08, 1.0)
_C.TRANSFORMS.PROCESS_DETAIL.RANDOM_RESIZED_CROP.RATIO = (0.75, 1.333333333)

def update_config(cfg, args):
    """Merge a YAML config file and command-line overrides into ``cfg``.

    ``cfg`` is defrosted, updated in place and frozen again. Overrides from
    ``args.opts`` are applied after the file, so they take precedence.

    Args:
        cfg: a yacs ``CfgNode`` (for example ``_C`` or a clone of it).
        args: parsed command-line arguments exposing ``cfg`` (path of the YAML
            file to merge) and, optionally, ``opts`` (a flat
            ``[KEY, VALUE, KEY, VALUE, ...]`` list). ``opts`` may be missing,
            ``None`` or empty, meaning "no overrides".

    Raises:
        Whatever ``cfg.merge_from_file`` / ``cfg.merge_from_list`` raise (e.g.
        unreadable file, unknown key, odd-length override list). ``cfg`` is
        frozen again even in that case.
    """
    cfg.defrost()
    try:
        cfg.merge_from_file(args.cfg)
        # An optional argparse argument that was not given defaults to None.
        # yacs' merge_from_list calls len() on its argument and would raise
        # TypeError, so a missing/None/empty opts is treated as "no overrides".
        opts = getattr(args, "opts", None)
        if opts:
            cfg.merge_from_list(opts)
    finally:
        # Re-freeze even when a merge fails, so a caller that catches the
        # error is not left holding a silently mutable config.
        cfg.freeze()