# Shared settings. The train/val/test sections below read these values through
# ${data.*} interpolation, so this file is expected to be mounted under the
# `data` config package (NOTE(review): confirm against the Hydra defaults list).
target: experiments.data_generalization.CNNDataset
normalize: false
path: dataset/cnn_generalization_splits.json
statistics_path: dataset/statistics.pth
input_channels: 3
num_classes: 10
# NOTE(review): a 28x28 spatial size alongside 3 input channels — confirm
# both values match the underlying dataset.
img_shape: [28, 28]
# Underscore-prefixed helpers; combined below into max_kernel_size.
_max_kernel_height: 11
_max_kernel_width: 11
max_kernel_size:
  - ${data._max_kernel_height}
  - ${data._max_kernel_width}
linear_as_conv: true
flattening_method: repeat_nodes  # repeat_nodes or extra_layer
max_spatial_resolution: 49  # 7x7 feature map size
max_num_hidden_layers: 3
# Not referenced by train/val/test in this file; presumably consumed by
# another config or by code — verify before removing.
inr_model:
  _target_: nn.probe_features.INRPerLayer
  n_layers: 3
  up_scale: 16
  out_channels: 1

# Precomputed mean/std values for weights and biases (five entries per list,
# presumably one per layer position — confirm against statistics_path).
# NOTE(review): nothing in this file references `stats`; it is presumably read
# directly by consuming code.
stats:
  weights_mean:
    - -0.003178653074428439
    - 0.00825976300984621
    - 0.005337678827345371
    - 0.013196256943047047
    - -0.0034679118543863297
  weights_std:
    - 0.15346874296665192
    - 0.06397216022014618
    - 0.06516087800264359
    - 0.11304810643196106
    - 0.22598302364349365
  biases_mean:
    - 0.14812950789928436
    - 0.04664045199751854
    - 0.03937701880931854
    - 0.028652429580688477
    - 0.0029597203247249126
  biases_std:
    - 0.1204405426979065
    - 0.0870404914021492
    - 0.11048907786607742
    - 0.11734256148338318
    - 0.14044655859470367

# Training split: instantiated via Hydra from data.target with the shared
# settings interpolated in; augmentation is turned on here.
train:
  _target_: ${data.target}
  # Hydra's default is already recursive instantiation; stated explicitly here.
  _recursive_: true
  path: ${data.path}
  split: train
  normalize: ${data.normalize}
  augmentation: true
  statistics_path: ${data.statistics_path}
  max_kernel_size: ${data.max_kernel_size}
  linear_as_conv: ${data.linear_as_conv}
  flattening_method: ${data.flattening_method}
  max_num_hidden_layers: ${data.max_num_hidden_layers}
  # Disabled options, kept for reference:
  # mixup_augmentation:
  #   _target_: src.utils.data.augmentations.ReBasinMixUpAugmentation
  # num_classes: ${data.num_classes}

# Validation split: same shared settings as train, augmentation off.
val:
  _target_: ${data.target}
  path: ${data.path}
  split: val
  normalize: ${data.normalize}
  augmentation: false
  statistics_path: ${data.statistics_path}
  max_kernel_size: ${data.max_kernel_size}
  linear_as_conv: ${data.linear_as_conv}
  flattening_method: ${data.flattening_method}
  max_num_hidden_layers: ${data.max_num_hidden_layers}
  # Disabled option, kept for reference:
  # num_classes: ${data.num_classes}

# Test split: same shared settings as train, augmentation off.
test:
  _target_: ${data.target}
  path: ${data.path}
  split: test
  normalize: ${data.normalize}
  augmentation: false
  statistics_path: ${data.statistics_path}
  max_kernel_size: ${data.max_kernel_size}
  linear_as_conv: ${data.linear_as_conv}
  flattening_method: ${data.flattening_method}
  max_num_hidden_layers: ${data.max_num_hidden_layers}
  # Disabled option, kept for reference:
  # num_classes: ${data.num_classes}

