from yacs.config import CfgNode as CN

# Root config node, created directly from a mapping of the top-level defaults.
_C = CN(
    {
        "SEED": 12345,
        "OUTPUT_DIR": "output",
        "LOAD_TUNED": True,
    }
)

# Data
# Built from a single mapping; the nested "TP" dict is converted by CN into
# its own child node, so `_C.DATA.TP.<KEY>` access works as before.
_C.DATA = CN(
    {
        "PATH": "./src/data/",
        "TASK": "TP",
        "DATASET_NAME": "sdd",
        "OBSERVE_LENGTH": 8,
        "PREDICT_LENGTH": 12,
        "SKIP": 1,
        "BATCH_SIZE": 128,
        "NUM_WORKERS": 8,
        "NORMALIZED": False,
        # TP
        "TP": {
            "STATE": "state_pva",
            "PRED_STATE": "state_p",
            "ACCEPT_NAN": False,
        },
    }
)

# Model
# Built from a single mapping; the nested "FLOW" dict is converted by CN into
# its own child node, so `_C.MODEL.FLOW.<KEY>` access works as before.
_C.MODEL = CN(
    {
        "ENCODER_TYPE": "trajectron",
        "TYPE": "COPY_LAST",
        "FLOW": {
            "ARCHITECTURE": "realNVP",
            "N_BLOCKS": 3,
            "N_HIDDEN": 2,
            "HIDDEN_SIZE": 64,
            "CONDITIONING_LENGTH": 16,
        },
    }
)

# Solver
# Solver defaults, declared as one mapping handed to the CN constructor.
_C.SOLVER = CN(
    {
        "OPTIMIZER": "adam",
        "LR": 5e-3,
        "ITER": 100,
        "SAVE_EVERY": 10,
        "USE_SCHEDULER": False,
        "VALIDATION": True,
        "WEIGHT_DECAY": 0.0,
        "DEQUANTIZE": False,
    }
)

# MGF
# MGF defaults, declared as one mapping handed to the CN constructor.
_C.MGF = CN(
    {
        "ENABLE": True,
        "CLUSTER_N": 7,
        "VAR_INIT": 0.7,
        "VAR_LEARNABLE": True,
        "POST_CLUSTER": 500,
    }
)
