import typing as T

from yacs.config import CfgNode as CN

# Root config object, typed as the AutoConfig schema declared at the bottom of
# this module. Quoted as a forward reference because AutoConfig is defined
# later in the file; an unquoted name would be evaluated eagerly at import
# time (PEP 526 module-level annotations) and raise NameError.
_C: "AutoConfig"

class Datamodule(CN):
    """Key schema for ``AutoConfig.DATAMODULE``.

    Annotation-only: declares key names and types, assigns no values.
    """
    BATCH_SIZE: int
    EVAL_BATCH_SIZE_MULTIPLIER: float
    NUM_WORKERS: int
    PIN_MEMORY: bool
    FEATURE_EXTRACTOR_MODE: bool

class Dataset(CN):
    """Key schema for ``AutoConfig.DATASET``.

    Annotation-only: declares key names and types, assigns no values.
    Sequence-typed keys are unparameterized here; their element types are
    not stated in this module.
    """
    RESOLUTION: T.Sequence
    PADDING: T.Sequence
    TIME_SERIES_LENGTH: int
    CLAMP_VALUE: int
    IMAGE_FMT: str
    VIDEO_FRAMES: int
    RANDOM_FRAMES: bool
    CACHE_DIR: str
    SUBJECT_LIST: T.Sequence
    ROIS: T.Sequence
    NAME: str
    ROOT: str
    DARK_POSTFIX: str

class Position_encoding(CN):
    """Key schema for ``AutoConfig.POSITION_ENCODING``.

    Annotation-only: declares key names and types, assigns no values.
    """
    MAX_STEPS: int
    FEATURES: int
    PERIODS: int

class Sd(CN):
    """Key schema for ``AutoConfig.MODEL.BACKBONE.SD``.

    Annotation-only: declares key names and types, assigns no values.
    """
    MLP_DIM: int
    MLP_DEPTH: int

class Backbone(CN):
    """Key schema for ``AutoConfig.MODEL.BACKBONE``.

    Annotation-only: declares key names and types, assigns no values.
    """
    NAME: str
    CACHE_DIR: str
    DISABLE_BN: bool
    BN_MOMENTUM: float
    LAYERS: T.Sequence
    PRETRAINED: bool
    FREEZE: bool
    SD: Sd  # nested node, see class Sd

class Conv_head(CN):
    """Key schema for ``AutoConfig.MODEL.NECK.CONV_HEAD``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    NAME: str
    KERNELS: T.Sequence
    LAST_KERNELS: T.Sequence
    MAX_DIM: int
    KERNEL_SIZE: int
    DEPTH: int
    WIDTH: int
    BN: bool
    LN: bool
    CONV1X1: bool
    REDUCE_DIM: bool
    SKIP_CONNECTION: bool

class Pool_head(CN):
    """Key schema for ``AutoConfig.MODEL.NECK.POOL_HEAD``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    NAME: str

class Pe(CN):
    """Key schema for ``AutoConfig.MODEL.NECK.IMAGE_SHIFTER.PE``.

    Annotation-only: declares key names and types, assigns no values.
    Carries the same MAX_STEPS / PERIODS / FEATURES keys as
    ``Position_encoding`` plus a ``USE`` flag.
    """
    USE: bool
    MAX_STEPS: int
    PERIODS: int
    FEATURES: int

class Image_shifter(CN):
    """Key schema for ``AutoConfig.MODEL.NECK.IMAGE_SHIFTER``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    IN_LAYER: str
    WIDTH: int
    DEPTH: int
    PE: Pe  # nested node, see class Pe

class Simple_conv(CN):
    """Key schema for ``AutoConfig.MODEL.NECK.SIMPLE_CONV``.

    Annotation-only: declares key names and types, assigns no values.
    """
    DEPTH: int
    KERNEL_SIZE: int

class Neck(CN):
    """Key schema for ``AutoConfig.MODEL.NECK``.

    Annotation-only: declares key names and types, assigns no values.
    Groups the CONV_HEAD, POOL_HEAD, IMAGE_SHIFTER and SIMPLE_CONV sub-nodes.
    """
    NAME: str
    CONV_HEAD: Conv_head
    POOL_HEAD: Pool_head
    IMAGE_SHIFTER: Image_shifter
    CONV_TYPE: str
    CONCAT_BEFORE_CONV: bool
    CONCAT_LATENT_RESOLUTION: T.Sequence
    DIM: int
    REDUCE_DIM: bool
    BN: bool
    SIMPLE_CONV: Simple_conv

class Neuron_projector(CN):
    """Key schema for ``AutoConfig.MODEL.NEURON_PROJECTOR``.

    Annotation-only: declares key names and types, assigns no values.
    """
    SEPARATE_LAYERS: bool
    DEPTH: int
    WIDTH: int
    NUM_NEURON_LATENT: int
    MU_SCALE: float
    SIGMA_SCALE: float
    USE_CONSTANT_SIGMA: bool
    CONSTANT_SIGMA: float
    BATCH_NORM: bool

class Neuron_shifter(CN):
    """Key schema for ``AutoConfig.MODEL.NEURON_SHIFTER``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    NUM_REPEAT: int
    DEPTH: int
    WIDTH: int

class Layer_gate(CN):
    """Key schema for ``AutoConfig.MODEL.LAYER_GATE``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    DEPTH: int
    WIDTH: int
    MEAN: str
    SKIP: bool

class Head(CN):
    """Key schema for ``AutoConfig.MODEL.HEAD``.

    Annotation-only: declares key names and types, assigns no values.
    """
    BOTTLENECK_DIM: int

class Model(CN):
    """Key schema for ``AutoConfig.MODEL``.

    Annotation-only: declares key names and types, assigns no values.
    Each CN-typed key is a nested node described by the class it names.
    """
    BACKBONE: Backbone
    NECK: Neck
    MAX_TRAIN_VOXELS: int
    NEURON_PROJECTOR: Neuron_projector
    NEURON_SHIFTER: Neuron_shifter
    LAYER_GATE: Layer_gate
    HEAD: Head

class Sync(CN):
    """Key schema for ``AutoConfig.LOSS.SYNC``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    STAGE: str
    SKIP_EPOCHS: int
    EMA_BETA: float
    EMA_BIAS_CORRECTION: bool
    UPDATE_RULE: str
    EXP_SCALE: float
    EXP_SHIFT: float
    LOG_SHIFT: float
    EMA_KEY: str

class Anneal(CN):
    """Key schema for ``AutoConfig.LOSS.DARK.ANNEAL``.

    Annotation-only: declares key names and types, assigns no values.
    """
    # Config key literally named ``T``. It is annotated but never assigned, so
    # it does not bind a name and the module's ``typing as T`` alias is
    # unaffected.
    T: int

class Dark(CN):
    """Key schema for ``AutoConfig.LOSS.DARK``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    MAX_EPOCH: int
    IGNORE_OTHER_ROIS: bool
    GT_ROIS: T.Sequence
    GT_SCALE_UP_COEF: float
    IGNORE_GT: bool
    ANNEAL: Anneal  # nested node, see class Anneal

class Loss(CN):
    """Key schema for ``AutoConfig.LOSS``.

    Annotation-only: declares key names and types, assigns no values.
    """
    NAME: str
    SMOOTH_L1_BETA: float
    # NOTE(review): SUBJECT_PREFIX and SUBJECT_WEIGHT look like parallel
    # sequences (one weight per prefix) -- confirm against the consumer.
    SUBJECT_PREFIX: T.Sequence
    SUBJECT_WEIGHT: T.Sequence
    SYNC: Sync
    DARK: Dark

class Scheduler(CN):
    """Key schema for ``AutoConfig.OPTIMIZER.SCHEDULER``.

    Annotation-only: declares key names and types, assigns no values.
    """
    T_INITIAL: int
    T_MULT: float
    CYCLE_DECAY: float
    CYCLE_LIMIT: int
    WARMUP_T: int
    K_DECAY: float
    LR_MIN: float
    LR_MIN_WARMUP: float

class Optimizer(CN):
    """Key schema for ``AutoConfig.OPTIMIZER``.

    Annotation-only: declares key names and types, assigns no values.
    """
    NAME: str
    LR: float
    FINETUNE_BACKBONE_LR_RATIO: float
    NEURON_PROJECTOR_LR_RATIO: float
    # Global weight decay followed by per-component overrides.
    WEIGHT_DECAY: float
    BACKBONE_WEIGHT_DECAY: float
    NECK_WEIGHT_DECAY: float
    NEURON_PROJECTOR_WEIGHT_DECAY: float
    LAYER_GATE_WEIGHT_DECAY: float
    VOXEL_WEIGHT_DECAY: float
    # Regularizer coefficients.
    GATE_REGULARIZER: float
    MU_REGULARIZER_PDIST: float
    MU_REGULARIZER_MCENTER: float
    MU_REGULARIZER_PCENTER: float
    X_SHIFT_SMOOTH_REGULARIZER: float
    X_SHIFT_ZERO_REGULARIZER: float
    P_MU_SHIFT_REGULARIZER: float
    LR_DECAY_RATE: T.Sequence
    LR_DECAY_STEP: T.Sequence
    WARMUP_STEPS: int
    SCHEDULER: Scheduler  # nested node, see class Scheduler

class Stage_2(CN):
    """Key schema for ``AutoConfig.STAGE_2``.

    Annotation-only: declares key names and types, assigns no values.
    """
    FIT_TO_VALIDATION: bool

class Backbone_1(CN):
    """Key schema for ``AutoConfig.TRAINER.CALLBACKS.BACKBONE``.

    Annotation-only: declares key names and types, assigns no values.
    Suffixed ``_1``, presumably to avoid colliding with the model-level
    ``Backbone`` class that types ``AutoConfig.MODEL.BACKBONE``.
    """
    UN_FREEZE_AT_EPOCH: int
    INITIAL_RATIO_LR: float
    LR_MULTIPLY_EFFICIENT: float
    SHOULD_ALIGN: bool
    TRAIN_BN: bool
    VERBOSE: bool

class Early_stop(CN):
    """Key schema for ``AutoConfig.TRAINER.CALLBACKS.EARLY_STOP``.

    Annotation-only: declares key names and types, assigns no values.
    """
    PATIENCE: int
    SUBJECT: str

class Checkpoint(CN):
    """Key schema for ``AutoConfig.TRAINER.CALLBACKS.CHECKPOINT``.

    Annotation-only: declares key names and types, assigns no values.
    """
    SAVE_TOP_K: int
    REMOVE: bool
    LOAD_BEST_ON_VAL: bool
    LOAD_BEST_ON_END: bool

class Logger(CN):
    """Key schema for ``AutoConfig.TRAINER.CALLBACKS.LOGGER``.

    Declares no keys: the node is present in the schema but has no typed
    sub-keys here. The docstring serves as the class body.
    """

class Callbacks(CN):
    """Key schema for ``AutoConfig.TRAINER.CALLBACKS``.

    Annotation-only: declares key names and types, assigns no values.
    """
    BACKBONE: Backbone_1  # callbacks' BACKBONE node, not the model Backbone
    EARLY_STOP: Early_stop
    CHECKPOINT: Checkpoint
    SAVE_OUTPUT: bool
    LOGGER: Logger

class Trainer(CN):
    """Key schema for ``AutoConfig.TRAINER``.

    Annotation-only: declares key names and types, assigns no values.
    """
    DEVICES: int
    PRECISION: int
    GRADIENT_CLIP_VAL: float
    MAX_EPOCHS: int
    # Stage-2 overrides; see also the separate top-level ``AutoConfig.STAGE_2``.
    STAGE_2_MAX_EPOCHS: int
    STAGE_2_LR: float
    STAGE_2_WD: float
    STAGE_2_EMA: bool
    STAGE_2_EMA_BETA: float
    MAX_STEPS: int
    ACCUMULATE_GRAD_BATCHES: int
    VAL_CHECK_INTERVAL: float
    LIMIT_TRAIN_BATCHES: float
    LIMIT_VAL_BATCHES: float
    LOG_TRAIN_N_STEPS: int
    CALLBACKS: Callbacks  # nested node, see class Callbacks

class Model_soup(CN):
    """Key schema for ``AutoConfig.MODEL_SOUP``.

    Annotation-only: declares key names and types, assigns no values.
    """
    USE: bool
    RECIPE: str
    GREEDY_TARGET: str

class Finetune(CN):
    """Key schema for ``AutoConfig.FINETUNE``.

    Annotation-only: declares key names and types, assigns no values.
    """
    SOURCE: str
    USE_LINEAR: bool
    SOUP: str
    SOUP_TARGET: str
    TOP_N: int
    TRAIN_SHARED: bool

class Analysis(CN):
    """Key schema for ``AutoConfig.ANALYSIS``.

    Annotation-only: declares key names and types, assigns no values.
    """
    SAVE_LAST_LINEAR_LAYER: bool
    TRANSFER: bool
    SAVE_NEURON_LOCATION: bool
    DRAW_NEURON_LOCATION: bool

class AutoConfig(CN):
    """Root key schema of the yacs config; the module-level ``_C`` is
    annotated with this type.

    Annotation-only: declares top-level key names and types, assigns no
    values. Each CN-typed key is a nested node described by the class it
    names. The actual values presumably come from a defaults file / YAML
    merged elsewhere -- TODO confirm.
    """
    DESCRIPTION: str
    DATAMODULE: Datamodule
    DATASET: Dataset
    POSITION_ENCODING: Position_encoding
    MODEL: Model
    LOSS: Loss
    OPTIMIZER: Optimizer
    STAGE_2: Stage_2
    TRAINER: Trainer
    MODEL_SOUP: Model_soup
    STAGE: str
    FINETUNE: Finetune
    RESULTS_DIR: str
    ANALYSIS: Analysis
