from yacs.config import CfgNode
from yacs.config import CfgNode as CN
from yacs_stubgen import build_pyi

# YACS overrides these settings from YAML; every YAML variable MUST be defined here first,
# as this is the master list of ALL attributes.
# Root node of the default configuration; every section below is attached to it.
_C = CN({"DESCRIPTION": "Default config"})

# Defaults for the DATAMODULE section, declared as a single mapping.
_C.DATAMODULE = CN(
    {
        "BATCH_SIZE": 128,
        "EVAL_BATCH_SIZE_MULTIPLIER": 0.25,
        "NUM_WORKERS": 5,
        "PIN_MEMORY": True,
        "FEATURE_EXTRACTOR_MODE": False,
    }
)

# Defaults for the DATASET section, declared as a single mapping.
_C.DATASET = CN(
    {
        "RESOLUTION": [224, 224],
        "PADDING": [0, 0, 0, 0],
        "TIME_SERIES_LENGTH": 100,
        "CLAMP_VALUE": 20,
        "IMAGE_FMT": "JPEG",
        "VIDEO_FRAMES": 10,
        "RANDOM_FRAMES": False,
        "CACHE_DIR": "/data/cache",
        "SUBJECT_LIST": ["ALL"],
        "ROIS": ["all"],
        "NAME": "ALL",
        "ROOT": "/data/VWE",
        "DARK_POSTFIX": "",
    }
)

# Defaults for the top-level POSITION_ENCODING section.
_C.POSITION_ENCODING = CN(
    {
        "MAX_STEPS": 100,
        "FEATURES": 32,
        "PERIODS": 100,
    }
)

# MODEL node, created together with its BACKBONE child. Nested dicts become child
# CfgNodes, so _C.MODEL.BACKBONE.SD.MLP_DIM etc. resolve exactly as before.
# Further MODEL children are attached by later statements in this file.
_C.MODEL = CN(
    {
        "BACKBONE": {
            "NAME": "resnet18",
            "CACHE_DIR": "/data/cache",
            "DISABLE_BN": False,
            "BN_MOMENTUM": -1.0,
            "LAYERS": ["layer3"],
            "PRETRAINED": True,
            "FREEZE": True,
            "SD": {
                "MLP_DIM": 768,
                "MLP_DEPTH": 2,
            },
        },
    }
)


# Defaults for MODEL.NECK, declared as one nested mapping. Each nested dict becomes
# a child CfgNode, so paths such as _C.MODEL.NECK.CONV_HEAD.USE are unchanged.
_C.MODEL.NECK = CN(
    {
        "NAME": "TopyNeck",
        "CONV_HEAD": {
            "USE": True,
            "NAME": "ConvHead",
            "KERNELS": [5, 5],
            "LAST_KERNELS": [3, 3],
            "MAX_DIM": 1280,
            "KERNEL_SIZE": 3,
            "DEPTH": 2,
            "WIDTH": 256,
            "BN": True,
            "LN": False,
            "CONV1X1": False,
            "REDUCE_DIM": False,
            "SKIP_CONNECTION": False,
        },
        "POOL_HEAD": {
            "USE": False,
            "NAME": "AvgMaxPoolHead",
        },
        "IMAGE_SHIFTER": {
            "USE": False,
            "IN_LAYER": "layer3",
            "WIDTH": 128,
            "DEPTH": 2,
            "PE": {
                "USE": True,
                "MAX_STEPS": 100,
                "PERIODS": 100,
                "FEATURES": 32,
            },
        },
        # legacy
        "CONV_TYPE": "simple",
        "CONCAT_BEFORE_CONV": False,
        "CONCAT_LATENT_RESOLUTION": [28, 28],
        "DIM": 256,
        "REDUCE_DIM": False,
        "BN": True,
        "SIMPLE_CONV": {
            "DEPTH": 4,
            "KERNEL_SIZE": 7,
        },
    }
)

_C.MODEL.MAX_TRAIN_VOXELS = 10000

# Defaults for MODEL.NEURON_PROJECTOR.
_C.MODEL.NEURON_PROJECTOR = CN(
    {
        "SEPARATE_LAYERS": False,
        "DEPTH": 3,
        "WIDTH": 64,
        "NUM_NEURON_LATENT": 1,
        "MU_SCALE": 0.9,
        "SIGMA_SCALE": 0.01,
        "USE_CONSTANT_SIGMA": True,
        "CONSTANT_SIGMA": 0.01,
        "BATCH_NORM": False,
    }
)

# Defaults for MODEL.NEURON_SHIFTER.
_C.MODEL.NEURON_SHIFTER = CN(
    {
        "USE": True,
        "NUM_REPEAT": 32,
        "DEPTH": 3,
        "WIDTH": 64,
    }
)

# Defaults for MODEL.LAYER_GATE.
_C.MODEL.LAYER_GATE = CN(
    {
        "USE": True,
        "DEPTH": 3,
        "WIDTH": 64,
        "MEAN": "mean",  # 'mean' or 'geometric_mean'
        "SKIP": False,
    }
)

# legacy
_C.MODEL.HEAD = CN({"BOTTLENECK_DIM": 256})

# Defaults for the LOSS section, including the SYNC and DARK sub-sections.
_C.LOSS = CN(
    {
        "NAME": "SmoothL1Loss",
        "SMOOTH_L1_BETA": 0.1,
        # NOTE(review): "HCP" is listed twice below — confirm whether that is
        # intentional. SUBJECT_WEIGHT has one entry per prefix (8 and 8).
        "SUBJECT_PREFIX": ["NSD", "B5K", "HCP", "EEG2", "ALG", "HCP", "MEG1", "fMRI1"],
        "SUBJECT_WEIGHT": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        "SYNC": {
            "USE": True,
            "STAGE": "VAL",
            "SKIP_EPOCHS": 1,
            "EMA_BETA": 0.9,
            "EMA_BIAS_CORRECTION": False,
            "UPDATE_RULE": "raw",
            "EXP_SCALE": 1.0,
            "EXP_SHIFT": 0.0,
            "LOG_SHIFT": 10.0,
            "EMA_KEY": "running_grad",
        },
        "DARK": {
            "USE": False,
            "MAX_EPOCH": 100,
            "IGNORE_OTHER_ROIS": False,
            "GT_ROIS": ["htroi_1"],
            "GT_SCALE_UP_COEF": 3.0,
            "IGNORE_GT": False,
            "ANNEAL": {
                "T": 30,
            },
        },
    }
)

# Defaults for the OPTIMIZER section; SCHEDULER is kept as its last child.
_C.OPTIMIZER = CN(
    {
        "NAME": "AdaBelief",
        "LR": 1e-2,
        "FINETUNE_BACKBONE_LR_RATIO": 0.1,
        "NEURON_PROJECTOR_LR_RATIO": 10.0,
        "WEIGHT_DECAY": 1e-4,
        "BACKBONE_WEIGHT_DECAY": 0.0,
        "NECK_WEIGHT_DECAY": 1e-4,
        "NEURON_PROJECTOR_WEIGHT_DECAY": 1e-2,
        "LAYER_GATE_WEIGHT_DECAY": 1e-2,
        "VOXEL_WEIGHT_DECAY": 1e-4,
        "GATE_REGULARIZER": 1e-4,
        "MU_REGULARIZER_PDIST": 0.0,
        "MU_REGULARIZER_MCENTER": 3e-3,
        "MU_REGULARIZER_PCENTER": 1e-4,
        "X_SHIFT_SMOOTH_REGULARIZER": 0.0,
        "X_SHIFT_ZERO_REGULARIZER": 0.0,
        "P_MU_SHIFT_REGULARIZER": 0.0,
        "LR_DECAY_RATE": [1.0],
        "LR_DECAY_STEP": [1000],
        "WARMUP_STEPS": 10,
        "SCHEDULER": {
            "T_INITIAL": 50,
            "T_MULT": 1.0,
            "CYCLE_DECAY": 0.5,
            "CYCLE_LIMIT": 3,
            "WARMUP_T": 10,
            "K_DECAY": 1.5,
            "LR_MIN": 3e-4,
            "LR_MIN_WARMUP": 1e-6,
        },
    }
)

# STAGE_2 is still registered on the root right after OPTIMIZER, as before.
_C.STAGE_2 = CN({"FIT_TO_VALIDATION": False})

# Defaults for the TRAINER section and its CALLBACKS sub-tree.
# CALLBACKS.LOGGER is attached by a later statement in this file.
_C.TRAINER = CN(
    {
        "DEVICES": 1,
        "PRECISION": 16,
        "GRADIENT_CLIP_VAL": 0.5,
        "MAX_EPOCHS": 1000,
        "STAGE_2_MAX_EPOCHS": 100,
        "STAGE_2_LR": 1e-2,
        "STAGE_2_WD": 1e-2,
        "STAGE_2_EMA": False,
        "STAGE_2_EMA_BETA": 0.999,
        "MAX_STEPS": -1,
        "ACCUMULATE_GRAD_BATCHES": 1,
        "VAL_CHECK_INTERVAL": 1.0,
        "LIMIT_TRAIN_BATCHES": 0.1,
        "LIMIT_VAL_BATCHES": 1.0,
        "LOG_TRAIN_N_STEPS": 100,
        "CALLBACKS": {
            "BACKBONE": {
                "UN_FREEZE_AT_EPOCH": 114514,
                "INITIAL_RATIO_LR": 0.0,
                "LR_MULTIPLY_EFFICIENT": 1.6,  # 1.6^5 ~= 10
                "SHOULD_ALIGN": True,
                "TRAIN_BN": True,
                "VERBOSE": True,
            },
            "EARLY_STOP": {
                "PATIENCE": 21,
                "SUBJECT": "mean",
            },
            "CHECKPOINT": {
                "SAVE_TOP_K": 1,
                "REMOVE": False,
                "LOAD_BEST_ON_VAL": False,
                "LOAD_BEST_ON_END": True,
            },
            "SAVE_OUTPUT": False,
        },
    }
)

# Defaults for the MODEL_SOUP section.
_C.MODEL_SOUP = CN(
    {
        "USE": True,
        "RECIPE": "greedy",
        "GREEDY_TARGET": "heldout",
    }
)

_C.STAGE = "pretrain"

# Defaults for the FINETUNE section.
_C.FINETUNE = CN(
    {
        "SOURCE": "single",
        "USE_LINEAR": True,
        "SOUP": "uniform",
        "SOUP_TARGET": "heldout",
        "TOP_N": 10,
        "TRAIN_SHARED": False,
    }
)

# Empty LOGGER node under TRAINER.CALLBACKS (no default keys are declared here).
_C.TRAINER.CALLBACKS.LOGGER = CN()

_C.RESULTS_DIR = "/data/ray_results/"

# Defaults for the ANALYSIS section.
_C.ANALYSIS = CN(
    {
        "SAVE_LAST_LINEAR_LAYER": False,  # TODO: implement for TopyNeck
        "TRANSFER": False,
        "SAVE_NEURON_LOCATION": False,
        "DRAW_NEURON_LOCATION": False,
    }
)

# Alias of CfgNode, so that callers can import `AutoConfig` from this module and
# use it in annotations or `isinstance` checks.
AutoConfig = CN
# Build the type stub for this config via yacs_stubgen. `_C` is the CfgNode object
# and `var_name` must match the name of that variable in this module ("_C").
# NOTE(review): this call executes on every import of the module; presumably it
# writes a .pyi file derived from `__file__` — confirm against the yacs_stubgen docs.
build_pyi(_C, __file__, var_name="_C")
