from yacs.config import CfgNode as CN

# Root node holding every default option; get_cfg() returns clones of it.
_C = CN()

# -------------------------------------------------------------
# Model options
# -------------------------------------------------------------
_C.MODEL = CN()

# Model architecture
_C.MODEL.ARCH = "unet"

# Activation function for the main (class/segmentation) output
_C.MODEL.ACT_FUNC = "softmax"

# Activation function for the regression output
# NOTE(review): "REG" presumably means regression (e.g. the region-size
# outputs enabled by DATA.REGION_SIZE) — confirm in the model builder.
_C.MODEL.REG_ACT_FUNC = "relu"

# The path of pretrained model (empty string: no pretrained weights)
_C.MODEL.PRETRAINED = ""

# The number of classes to predict
_C.MODEL.NUM_CLASSES = 8

# If True, treat the task as multi-label (several classes may be positive
# at once) instead of single-label — confirm exact handling in model/loss code
_C.MODEL.MULTI_LABEL = False

# The number of input channels
_C.MODEL.INPUT_CHANNELS = 3

# Rate of dropout
# NOTE(review): the negative default looks like a "dropout disabled"
# sentinel — confirm in the model builder.
_C.MODEL.DROPOUT = -1.0

# -------------------------------------------------------------
# Dataset options
# -------------------------------------------------------------
_C.DATA = CN()

# Name of the dataset
_C.DATA.NAME = "retinal-lesions"

# The root directory of dataset
_C.DATA.DATA_ROOT = ""

# The label values of pixels in the mask
_C.DATA.LABEL_VALUES = [255]

# Only used by the retinal-lesions dataset.
# If True, convert the data setting to binary classification.
_C.DATA.BINARY = False

# If True, also compute and output the region size
_C.DATA.REGION_SIZE = False

# If True, also compute and output the region size normalized by the area
_C.DATA.NORMALIZE_REGION_SIZE = False

# If True, also compute and output the number of regions
_C.DATA.REGION_NUMBER = False

# Per-channel mean of the raw pixels, in R, G, B order
# (the defaults are the commonly used ImageNet statistics).
_C.DATA.MEAN = [0.485, 0.456, 0.406]

# Per-channel std of the raw pixels, in R, G, B order
# (the defaults are the commonly used ImageNet statistics).
_C.DATA.STD = [0.229, 0.224, 0.225]

# Target size for image resizing
_C.DATA.RESIZE = (512, 512)

# How many subprocesses to use for data loading.
_C.DATA.NUM_WORKERS = 8

# Class name used by the retinal-lesions class setting
_C.DATA.CLASS_NAME = "hard_exudate"


# -------------------------------------------------------------
# Loss options
# -------------------------------------------------------------
_C.LOSS = CN()

# Name of loss function
_C.LOSS.NAME = "bce_logit"

# The target value that is ignored and does not contribute to parameter
# optimization (-100 matches PyTorch's default ignore_index)
_C.LOSS.IGNORE_INDEX = -100

# For some losses, the background index is required
# NOTE(review): -1 presumably means "no background class" — confirm.
_C.LOSS.BACKGROUND_INDEX = -1

# Loss-specific hyper-parameter (weight) of the loss
_C.LOSS.ALPHA = 0.1

# Step size for adjusting ALPHA during training.
# If zero, ALPHA is kept unchanged for the whole training.
# NOTE(review): unit (epochs vs. iterations) is decided by the trainer — confirm.
_C.LOSS.ALPHA_STEP_SIZE = 0

# Factor by which ALPHA is increased each time an adjustment is triggered
_C.LOSS.ALPHA_FACTOR = 5

# Temperature (used only by losses that take a temperature — confirm which)
_C.LOSS.TEMP = 20.0

# Label smoothing for soft bce loss
_C.LOSS.LABEL_SMOOTHING = 0.1

# Per-class loss weights
# NOTE(review): an empty list presumably means "no class weighting" — confirm
# in the loss builder.
_C.LOSS.CLASS_WEIGHTS = []


# -------------------------------------------------------------
# Optimizer options
# -------------------------------------------------------------
_C.SOLVER = CN()

# Optimization method
_C.SOLVER.OPTIMIZING_METHOD = "adam"

# Base learning rate
_C.SOLVER.BASE_LR = 0.1

# Minimal learning rate during scheduling
_C.SOLVER.MIN_LR = 1e-5

# Learning rate policy
_C.SOLVER.LR_POLICY = "reduce_on_plateau"

# FACTOR, PATIENCE and REDUCE_MODE are used by the "reduce_on_plateau" policy.
# NOTE(review): they presumably map to the factor / patience / mode arguments
# of PyTorch's ReduceLROnPlateau — confirm in the scheduler builder.
_C.SOLVER.FACTOR = 0.5

_C.SOLVER.PATIENCE = 6

_C.SOLVER.REDUCE_MODE = "min"

# Final learning rate of the cosine policy
_C.SOLVER.COSINE_END_LR = 0.0

# MOMENTUM, DAMPENING and NESTEROV mirror SGD-style optimizer arguments.
# NOTE(review): presumably ignored when OPTIMIZING_METHOD is "adam" — confirm.
# Momentum.
_C.SOLVER.MOMENTUM = 0.9

# Momentum dampening.
_C.SOLVER.DAMPENING = 0.0

# Nesterov momentum.
_C.SOLVER.NESTEROV = True

# Exponential decay factor.
_C.SOLVER.GAMMA = 0.1

# L2 regularization.
_C.SOLVER.WEIGHT_DECAY = 1e-4

# Step size for 'exp' and 'cos' policies (in epochs).
_C.SOLVER.STEP_SIZE = 10

# Maximal number of epochs.
_C.SOLVER.MAX_EPOCH = 300

# Number of warmup epochs.
_C.SOLVER.WARMUP_EPOCH = 0

# -------------------------------------------------------------
# Training options
# -------------------------------------------------------------
_C.TRAIN = CN()

# Total mini-batch size.
_C.TRAIN.BATCH_SIZE = 4

# Train data list path (relative to _C.DATA.DATA_ROOT, or an absolute path)
_C.TRAIN.DATA_PATH = ""

# Evaluate the model every EVAL_PERIOD epochs
# (presumably on the VAL split — confirm in the trainer).
_C.TRAIN.EVAL_PERIOD = 1

# Save model checkpoint every checkpoint period epochs.
_C.TRAIN.CHECKPOINT_PERIOD = 1

# If True, calculate metrics (AUC/F1/Dice/...) in the training phase.
# May be very costly due to the large number of training samples.
_C.TRAIN.CALCULATE_METRIC = False

# Resume training from the latest checkpoint in the output directory
_C.TRAIN.AUTO_RESUME = False

# -------------------------------------------------------------
# Validation options
# -------------------------------------------------------------
_C.VAL = CN()

# Total mini-batch size.
_C.VAL.BATCH_SIZE = 4

# Val data list path (relative to _C.DATA.DATA_ROOT, or an absolute path)
_C.VAL.DATA_PATH = ""

# -------------------------------------------------------------
# Test options (only used when running the test script)
# -------------------------------------------------------------
_C.TEST = CN()

# Test data list path (relative to _C.DATA.DATA_ROOT, or an absolute path)
_C.TEST.DATA_PATH = ""

# Total mini-batch size.
_C.TEST.BATCH_SIZE = 4

# The path of the checkpoint file to test.
# If empty, the model indicated in the best_checkpoint file is loaded.
_C.TEST.CHECKPOINT_PATH = ""

# If True, load the model indicated in the best_checkpoint file.
# NOTE(review): this overlaps with the empty-CHECKPOINT_PATH rule above and
# with MODEL_EPOCH below; their precedence is decided by the test script —
# confirm there.
_C.TEST.BEST_CHECKPOINT = True

# Epoch of the model to be tested (1-based)
# NOTE(review): the default 0 presumably means "not set" — confirm.
_C.TEST.MODEL_EPOCH = 0

# If True, save the predicted results into one numpy array file
_C.TEST.SAVE_PREDICTS = False

# -------------------------------------------------------------
# Predict mode
# -------------------------------------------------------------
_C.PREDICT = CN()

# One of: ["batch", "singlescale", "multiscale", "sliding"]
_C.PREDICT.MODE = "batch"

# Scale factors; only used in "multiscale" mode
_C.PREDICT.SCALES = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]

# If True, apply flip augmentation at test time
_C.PREDICT.FLIP = False


# -----------------------------------------------------------------------------
# Tensorboard Visualization Options
# -----------------------------------------------------------------------------
_C.TENSORBOARD = CN()

# If True, log to a summary writer; this automatically logs
# loss, lr and metrics during train/eval.
_C.TENSORBOARD.ENABLE = False

# If True, plot the score of each class.
_C.TENSORBOARD.PLOT_CLASS_SCORE = False

# Path of a txt file providing class names
_C.TENSORBOARD.CLASSES_NAMES_PATH = ""

# -------------------------------------------------------------
# Misc options
# -------------------------------------------------------------
# Output base directory.
_C.OUTPUT_DIR = "./tmp"

# Random seed. Note that non-determinism may still be present due to
# non-deterministic operator implementations in GPU operator libraries.
_C.RNG_SEED = 1

# Logging period in iterations
_C.LOG_PERIOD = 10

# If True, log the model info.
_C.LOG_MODEL_INFO = True

# The device name
_C.DEVICE = "cuda:0"

# If True, perform test after training
_C.PERFORM_TEST = False

# Threshold for determining the positive segmentation results
_C.THRES = 0.5


def get_cfg() -> CN:
    """Return a fresh copy of the module's default configuration.

    The result is a clone of ``_C`` rather than ``_C`` itself, so callers
    may merge or mutate it without altering the module-level defaults.
    """
    defaults_copy = _C.clone()
    return defaults_copy
