
---
dataset         : "MNIST"
task0_labels    : [0, 1, 2, 3, 4]                 # type: List[int], options in [0,9]
# NOTE(review): task1_labels is identical to task0_labels — confirm this is
# intended (a disjoint two-task split would use e.g. [5, 6, 7, 8, 9]).
task1_labels    : [0, 1, 2, 3, 4]                 # type: List[int], options in [0,9]


full_model  : "Linear"
full_kwargs :
  input_size      : 784  # 28*28 = 784 pixels per MNIST image
  hidden_size     : ${control.hidden_size}
  num_classes     : 10

teacher0_model  : "Linear"
teacher0_kwargs :
  input_size      : 784  # 28*28 = 784 pixels per MNIST image
  hidden_size     : ${control.hidden_size}
  num_classes     : 5    # presumably len(task0_labels) — confirm

teacher1_model  : "Linear"
teacher1_kwargs :
  input_size      : 784  # 28*28 = 784 pixels per MNIST image
  hidden_size     : ${control.hidden_size}
  num_classes     : 5    # presumably len(task1_labels) — confirm

student_model  : "DoubleHeadLinear"
student_kwargs :
  input_size      : 784  # 28*28 = 784 pixels per MNIST image
  hidden_size     : ${control.hidden_size}
  num_classes_per_head : 5  # presumably matches the teachers' num_classes — confirm

student_equal_heads: false                        # type: bool

epochs_teacher  : 50                              # type: int > 0
epochs_t0       : 50                              # type: int > 0
epochs_t1       : 50                              # type: int > 0
BATCH_SIZE      : 64                              # type: int > 0
seed            : ${control.set_seed}             # type: int
learning_rate   : 0.01                            # type: float
momentum        : 0.0                             # type: float
detect_anomaly_flag : false                       # type: bool


# Shared values, referenced elsewhere in this file via ${control.*} interpolation.
control:
  hidden_size   : 10                            # type: int
  hidden_size_2 : 10                            # type: int
  set_seed      : 1                             # type: int
  tasks_key     : null                          # type: Optional[str]
  alpha         : 0                             # type: int as written — NOTE(review): confirm whether a float is expected