from yacs.config import CfgNode as CN

# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be used for either training or testing, the
# corresponding name is suffixed with _TRAIN for a training parameter
# or _TEST for a test-specific parameter.
# For example, the number of images per batch during training would be
# IMAGES_PER_BATCH_TRAIN, while the number of images per batch during
# testing would be IMAGES_PER_BATCH_TEST.

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

# Root node of the configuration tree; every section below is attached to it.
_C = CN()

# Model / backbone options, built in one shot from a keyword dict.
_C.MODEL = CN(dict(
    # Device to run on: 'cuda' or 'cpu'
    DEVICE='cuda',
    # GPU id(s), given as a string
    DEVICE_ID='0',
    # Backbone name
    NAME='dino',
    # Path to the pretrained backbone weights
    PRETRAIN_PATH='',
    # Path to pretrained projection weights (presumably the projection head; confirm in the loader)
    PRETRAIN_PROJ_PATH='',
    # Distributed-training flag
    DIST_TRAIN=False,
    # Initialization source: an ImageNet-pretrained backbone, or a self-trained
    # model for the whole network. Options: 'imagenet', 'self', 'finetune'
    PRETRAIN_CHOICE='imagenet',
    # Drop-path rate (presumably stochastic depth in the backbone; confirm)
    DROP_PATH=0.1,
    # Transformer type selector; note the default is the *string* 'None', not None.
    # NOTE(review): semantics not evident from this file, confirm against the model factory
    Transformer_TYPE='None',
    # Stride as a two-element list (presumably of the patch embedding; confirm)
    STRIDE_SIZE=[16, 16],
    # Block pattern: one-branch or three-branch
    BLOCK_PATTERN='normal',
    # Momentum used when updating features
    FEAT_MOMEN=0.95,
    # Momentum used when updating prototypes
    PROTO_MOMEN=0.9,
    # Z-score threshold used to filter prototypes
    PROTO_THRES=0.75,
    # Number of prototypes
    PROTO_NUM=8,
    # Task setting and its training stage
    TASK_TYPE='UNINCD',
    UNINCD_STAGE='pretrain',
    # Whether to use all patches
    ALL_PATCHES=False,
    # Output size of the all-patches pooling layer, i.e. patches per spatial dimension
    POOL_SIZE=4,
))

# -----------------------------------------------------------------------------
# LOSS
# -----------------------------------------------------------------------------
# Loss switches and weights, built in one shot from a keyword dict.
_C.LOSS = CN(dict(
    # Whether to use the distillation loss
    USE_DISTI_LOSS=True,
    # Whether to use the supervised contrastive (SupCL) loss
    USE_SCL_LOSS=True,
    # Whether to use the unsupervised contrastive (UnCL) loss
    USE_UNCL_LOSS=True,
    # Whether to use the labeled-unlabeled contrastive loss with pseudo labels
    USE_LUCL_LOSS=True,
    # Whether to use the novel-class contrastive loss with pseudo labels
    USE_NCL_LOSS=True,
    # Weight of the labeled prototypical loss
    W_PRO_L=5.0,
    # Weight of the prototypical loss
    W_PRO=2.0,
    # Weight of the SE loss
    W_SE=0.2,
    # Weight of the unlabeled cross-entropy loss
    W_UCE=0.1,
    # Weight of the supervised contrastive (SupCon) loss
    W_SCL=0.9,
    # Weight of the unsupervised contrastive (UnCon) loss
    W_SSCL=0.1,
    # Weight of the novel-class contrastive loss with pseudo labels
    W_NCL=0.8,
    # Temperature for the info logits
    TEMP_INFO=1.0,
    # Whether to use the mutual-information-maximization (MIM) loss;
    # the _L / _U suffixes presumably mean labeled / unlabeled data (confirm)
    MIM_L=False,
    MIM_U=False,
    # Whether to use the sparse-diverse regularization (SDR) loss
    SDR=False,
    # Weight of the KL term of the SDR loss (presumably; confirm in the loss code)
    SDR_KL_W=9.5e-4,
    # Weight of the KLDU term. NOTE(review): meaning not evident here, confirm
    KLDU_W=1e-1,
))

# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
# Input image geometry.
_C.INPUT = CN(dict(
    # Image size used during training
    SIZE_TRAIN=[384, 128],
    # Image size used during testing
    SIZE_TEST=[384, 128],
    # Crop size
    SIZE_CROP=[224, 224],
))

# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# Name of the dataset to use.
# This is a plain str. It used to be written as ('Cifar100'), but the
# parentheses did not make a tuple. Keep it a str: yacs requires config-file
# and CLI overrides to match the default's type.
_C.DATASETS.NAMES = 'Cifar100'
# Root directories of the supported datasets
_C.DATASETS.CIFAR_10_ROOT = ''
_C.DATASETS.CIFAR_100_ROOT = ''
_C.DATASETS.IMAGENET_100_ROOT = ''
# Proportion of training labels (presumably the labeled fraction of the
# training set; confirm against the dataset split code)
_C.DATASETS.PROP_TRAIN_LABELS = 0.6

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
# Data-loading options.
_C.DATALOADER = CN(dict(
    # Number of data-loading workers
    NUM_WORKERS=8,
    # Number of instances per batch
    NUM_INSTANCE=4,
    # Number of views used for contrastive learning
    N_VIEWS=2,
))

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
# Optimization / training-schedule options.
_C.SOLVER = CN(dict(
    # Optimizer name
    OPTIMIZER_NAME="Adam",
    # Maximum number of training epochs
    MAX_EPOCHS=100,
    # Base learning rate
    BASE_LR=3e-4,
    # Random seed
    SEED=1234,
    # Epoch bound for dataset / memory updating. NOTE(review): confirm exact semantics in the trainer
    UPDATE_MAX=5,
    # Maximum epoch for updating unlabeled prototypes
    UPDATE_UP_MAX=5,
    # Weight decay
    WEIGHT_DECAY=0.0005,
    # Beta coefficients for Adam-family optimizers (Adam / AdamW)
    BETA1=0.9,
    BETA2=0.999,
    # Checkpoint saving period, in epochs
    CHECKPOINT_PERIOD=10,
    # Training-log period, in iterations
    LOG_PERIOD=100,
    # Validation period, in epochs
    EVAL_PERIOD=10,
    # Number of images per training batch
    IMS_PER_BATCH=4,
    # Use pseudo labels to filter source-target pairs and keep only correctly matched pairs
    WITH_PSEUDO_LABEL_FILTER=False,
    # Maximum number of KMeans iterations
    KMEANS_MAX_ITER=100,
))

# ---------------------------------------------------------------------------- #
# TEST
# ---------------------------------------------------------------------------- #
# Test-time options.
_C.TEST = CN(dict(
    # Number of images per batch during testing
    IMS_PER_BATCH=128,
    # Path to the trained model weights
    WEIGHT="",
    # Whether to compute evaluation scores (a bool, not a string)
    EVAL=False,
))

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Directory for checkpoints and saved logs of the trained model
setattr(_C, "OUTPUT_DIR", "")

# ---------------------------------------------------------------------------- #
# Debug options
# ---------------------------------------------------------------------------- #
# Debug options. (A dead `_C.DEBUG = ""` store that was immediately
# overwritten by the CfgNode below has been removed.)
_C.DEBUG = CN()
# Debug switch, stored as an int (presumably 0 = off). Keep it an int so that
# existing YAML/CLI overrides still pass yacs's type check.
_C.DEBUG.IS_DEBUG = 0
# Ratio used in debug mode (presumably a dataset down-sampling ratio; confirm)
_C.DEBUG.DS_RATIO = 2e-3
# File paths for debug save / load and for dumping features (confirm usage in callers)
_C.DEBUG.SAVE_FILEPATH = ''
_C.DEBUG.LOAD_FILEPATH = ''
_C.DEBUG.FEAT_SAVEPATH = ''
