from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from yacs.config import CfgNode as CN

# Root node of the default configuration tree. The top-level options are
# supplied as one mapping; key order matches the original declaration order.
_C = CN(
    {
        # ----- BASIC SETTINGS -----
        "NAME": "MCFM_default",
        "OUTPUT_DIR": "./output",
        "VALID_STEP": 20,
        "SAVE_STEP": 20,
        "SHOW_STEP": 100,
        "INPUT_SIZE": (32, 32),
        "COLOR_SPACE": "RGB",
        "CPU_MODE": False,
        "use_best_model": False,
        "PRETRAINED_MODEL": "",
        # NOTE(review): the key is spelled "availabel_cudas" (sic). It is kept
        # verbatim because consumers presumably look it up by this exact name;
        # confirm before renaming.
        "availabel_cudas": "",
        "use_current_task_for_distill": True,
        "multi_centroid_classify": False,
        "use_base_half": False,
        "first_task_mix": True,
        "pre_current_loss_balance": False,
        "use_IB": False,
        "IB_alpha": 1000.0,
        "plus_mix_cls": True,
        "mix_cls_alpha": 1.0,
        "rate": 1.0,
        "use_mix_cls": True,
        "beta": 1e-6,
        "use_weight": True,
        "re_mix": False,
    }
)
# ----- DATASET BUILDER -----
_C.DATASET = CN(
    {
        # Options listed by the original author:
        # mnist, mnist28, CIFAR10, CIFAR100, imagenet, svhn
        "dataset_name": "CIFAR100",
        "dataset": "Torchvision_Datasets_Split",
        "data_json_file": "",
        "data_root": "./datasets",
        "all_classes": 100,
        "all_tasks": 10,
        "split_seed": 0,
        "val_length": 0,
        "use_svhn_extra": True,
    }
)

# ----- Mixup -----
_C.Mixup = CN(
    {
        "mixup_alpha1": 1.0,
        "mixup_alpha2": 1.0,
        "all": False,
        "mix_balance": True,
    }
)

# ----- exemplar_manager -----
_C.exemplar_manager = CN(
    {
        "store_original_imgs": True,
        "memory_budget": 2000,
        "mng_approach": "herding",
        "norm_exemplars": True,
        "centroid_order": "herding",
        # NOTE(review): -1 looks like a "not fixed / disabled" sentinel;
        # confirm against the code that reads this option.
        "fixed_exemplar_num": -1,
    }
)

# ----- BACKBONE BUILDER -----
_C.BACKBONE = CN(
    {
        "TYPE": "res32_512",
        "PRETRAINED_BACKBONE": "",
    }
)

# ----- MODULE BUILDER -----
_C.MODULE = CN({"TYPE": "GAP"})

# ----- resume -----
_C.RESUME = CN(
    {
        "use_resume": False,
        "resumed_file": "",
        "resumed_model_path": "",
        "resumed_pre_tasks_model": "",
    }
)

# NOTE: `_C.CLASSIFIER` is intentionally not defined at this point.
# The complete CLASSIFIER node (TYPE, BIAS, LOSS_FACTOR and the NECK sub-node)
# is built once in the CLASSIFIER BUILDER section near the end of this file.
# A second copy here was dead code: the later assignment to `_C.CLASSIFIER`
# replaced it wholesale, so edits made here silently had no effect.


# ----- DISTILL -----
_C.DISTILL = CN(
    {
        "ENABLE": False,
        "LOSS_FACTOR": 1.0,
        # NOTE(review): integer selector; which value means softmax and which
        # sigmoid is not visible from this file. Confirm in the consumer.
        "softmax_sigmoid": 0,
        "TEMP": 2.0,
    }
)

# ----- LOSS BUILDER -----
_C.LOSS = CN({"LOSS_TYPE": "CrossEntropy"})

# ----- TRAIN BUILDER -----
# Nested plain dicts are converted into child CfgNode objects by the CfgNode
# constructor, so TENSORBOARD / OPTIMIZER / LR_SCHEDULER end up as sub-nodes
# exactly as if each had been created with its own CN().
_C.TRAIN = CN(
    {
        "BATCH_SIZE": 128,
        "MAX_EPOCH": 90,
        "IB_EPOCH": 60,
        "SHUFFLE": True,
        "NUM_WORKERS": 4,
        "TENSORBOARD": {
            "ENABLE": True,
        },
        "SUM_GRAD": False,
        "OPTIMIZER": {
            "TYPE": "SGD",
            "BASE_LR": 0.001,
            "IB_BASE_LR": 0.01,
            "MOMENTUM": 0.9,
            "WEIGHT_DECAY": 1e-4,
        },
        "LR_SCHEDULER": {
            "TYPE": "multistep",
            "LR_STEP": [30, 60],
            "IB_LR_STEP": [90, 120],
            "LR_FACTOR": 0.1,
            "WARM_EPOCH": 5,
            "COSINE_DECAY_END": 0,
        },
    }
)

# ----- CLASSIFIER BUILDER -----
# Assigning here replaces whatever `_C.CLASSIFIER` held before this point.
# The nested NECK dict is converted into a child CfgNode by the constructor.
_C.CLASSIFIER = CN(
    {
        "TYPE": "FC",
        "BIAS": True,
        "LOSS_FACTOR": 1.0,
        "NECK": {
            "ENABLE": False,
            "TYPE": "Linear",
            "NUM_FEATURES": 2048,
            "NUM_OUT": 128,
            "HIDDEN_DIM": 512,
            "MARGIN": 1.0,
            "WEIGHT_INTER_LOSS": False,
            "WEIGHT_INTRA_LOSS": False,
            "INTER_DISTANCE": True,
            "INTRA_DISTANCE": True,
            "LOSS_FACTOR": 0.5,
            "distance_loss": False,
        },
    }
)
