
########################################################
#   General settings
########################################################

# -- Definitions of datasets --

# Dataset entries are selected by key name through `dataset.config_name`
# (and `eval_dataset.config_name` in the cross-domain entries).
# Image datasets carry a data type, a channel count and an image size.
omniglot:
    class: "omniglot"
    data_type: "image"
    num_channels: 1
    image_size: 28

cars:
    class: "cars"
    data_type: "image"
    num_channels: 3
    image_size: 84

cub:
    class: "cub"
    data_type: "image"
    num_channels: 3
    image_size: 84

cifarfs:
    class: "cifarfs"
    data_type: "image"
    num_channels: 3
    image_size: 32

vgg_flower:
    class: "vgg_flower"
    data_type: "image"
    num_channels: 3
    image_size: 32

aircraft:
    class: "aircraft"
    data_type: "image"
    num_channels: 3
    image_size: 32

# The key is "miniimagenet" while the class string is "mini-imagenet".
miniimagenet:
    class: "mini-imagenet"
    data_type: "image"
    num_channels: 3
    # No image size is set for this dataset.
    image_size: null

# Has no image fields; only the train/test task counts are configured.
sinusoid:
    class: "sinusoid"
    num_tasks_train: 300000
    num_tasks_test: 160000

# -- Definitions of networks --

# Network entries are selected by key name through `model.config_name`.

# MLP with an explicit list of sizes.
omniglot_mlp:
    class: "MLP"
    disable_bn: false
    sizes: [256, 128, 64, 64]

# Two WRN28 entries that differ only in widen_factor.
wrn28-5:
    class: "WRN28"
    widen_factor: 5

wrn28-10:
    class: "WRN28"
    widen_factor: 10

resnet12:
    class: "ResNet12"

sinusoid-mlp:
    class: "SinusoidMLP"
    hidden_dim: 40

# MLP entries whose `sizes` is a preset name instead of an explicit list.
# NOTE(review): presumably the hid1..hid4 values listed in the experiment
# grids provide the layer widths for these presets -- confirm in the model code.
3mlp:
    class: "MLP"
    disable_bn: false
    sizes: "3-layers"

4mlp:
    class: "MLP"
    disable_bn: false
    sizes: "4-layers"

5mlp:
    class: "MLP"
    disable_bn: false
    sizes: "5-layers"

# -- all options --

# Shared defaults. Experiment entries pull these in with `<<: *__default__`
# and write their own keys on top.
__default__: &__default__

    # --- General settings ---
    num_workers: 4
    use_cuda: true
    use_best: false
    output_dir: "__outputs__"
    dataset_dir: "__data__"
    sync_dir: "__sync__"
    checkpoint_epochs: []
    checkpoint_iterations: []
    seed: null
    seed_by_time: false
    dataset_download: true
    num_gpus: 1
    debug_max_iters: null
    load_checkpoint_path: null
    train_augmentation: true

    # Names of the dataset / network entries to use; left unset here and
    # filled in by each experiment entry.
    dataset.config_name: null
    model.config_name: null
    save_best_model: true
    print_train_loss: false
    factor: 1.0

    # --- Training hyperparameters ---
    iteration: null   # for meta-learning
    validation_freq: null   # for meta-learning
    optimizer: "SGD"
    lr: null
    weight_decay: null
    lr_scheduler: null
    warmup_iters: 0
    sgd_momentum: 0.9
    lr_milestones: null
    multisteplr_gamma: 0.1

    learning_framework: null

    batch_size: 128
    batch_size_eval: 512
    max_train_dataset_size: null
    bn_track_running_stats: false
    bn_affine: true
    bn_momentum: 0.1

    # --- Meta-learning hyperparameters ---
    valid_batch_size: 32
    test_batch_size: 1000
    inner_lr: 0.5
    no_update_keywords: []
    train_steps: 1
    train_ways: 5
    train_shots: 1
    test_steps: null # => Same as train_steps
    test_ways: null # => Same as train_ways
    test_shots: null # => Same as train_shots
    first_order: false

    # --- edge-popup hyperparameters ---
    ignore_params: null
    init_mode: null
    init_mode_linear: null
    init_mode_mask: "kaiming_uniform"
    init_scale: 1.0
    init_scale_score: 1.0
    init_sparsity: null # option for SparseModule2
    threshold_coeff: null # option for SparseModule3
    print_sparsity: false
    learnable_scale: false
    scale_lr: 0.0
    scale_delta_coeff: 1.0
    aux_optimizer: "AdamW"
    aux_lr: null
    aux_weight_decay: 0.0
    aux_params: []

    const_params_test: false
    const_params_train: false
    const_reset_scale: 1.0

    # --- IteRand hyperparameters ---
    rerand_freq: null
    rerand_rate: 1.0

    # --- Hyperparameter search settings ---
    parallel_grid: null
    parallel_command: null


# ----------------------------------
#       ResNet12 + MiniImagenet
# ----------------------------------
# Base entry of this section: "MAML" with the "resnet12" network on
# "miniimagenet", 5-way 1-shot. The other entries in this section merge it,
# directly or through another entry. A key written in an entry replaces the
# merged value as a whole (nested mappings such as parallel_grid are not
# deep-merged).
miniimagenet_1s5w_resnet12_maml: &miniimagenet_1s5w_resnet12_maml
    <<: *__default__
    dataset.config_name: 'miniimagenet'

    iteration: 30000 # [from BOIL paper/repos] 30000 (ResNet12) / 5000 (ConvNet)
    validation_freq: 1000
    batch_size: 4
    optimizer: "AdamW"
    weight_decay: 0.0
    model.config_name: "resnet12"
    bn_track_running_stats: False

    learning_framework: "MAML"
    valid_batch_size: 200
    # NOTE(review): parallel_grid below lists inner_lr 0.3; presumably the
    # grid value is the one used for runs -- confirm.
    inner_lr: 0.5
    train_steps: 1
    train_ways: 5
    train_shots: 1
    test_ways: 5
    test_shots: 1

    parallel_command: "meta_train"
    seed: 1

    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1,2,3]
# Base entry with "SGD", "CustomCosineLR" and 100 warmup iterations.
# Its grid lists a single seed, unlike the base entry's three.
miniimagenet_1s5w_resnet12_maml+sgd:
    <<: *miniimagenet_1s5w_resnet12_maml
    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    warmup_iters: 100
    weight_decay: 0.0
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1]
# Base entry plus no_update_keywords: ['classifier'].
miniimagenet_1s5w_resnet12_boil: &miniimagenet_1s5w_resnet12_boil
    <<: *miniimagenet_1s5w_resnet12_maml
    no_update_keywords: ['classifier']
# Base entry plus no_update_keywords listing layer1..layer4.
miniimagenet_1s5w_resnet12_anil: &miniimagenet_1s5w_resnet12_anil
    <<: *miniimagenet_1s5w_resnet12_maml
    no_update_keywords: ['layer1', 'layer2', 'layer3', 'layer4']
# "MetaTicket" variant of the base entry, trained with "SGD" and
# "CustomCosineLR".
miniimagenet_1s5w_resnet12_ticket: &miniimagenet_1s5w_resnet12_ticket
    <<: *miniimagenet_1s5w_resnet12_maml
    learning_framework: "MetaTicket"
    ignore_params: ['bn1','bn2','bn3','downsample.1','bias','classifier']
    init_mode: 'kaiming_normal'

    print_sparsity: true

    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    # NOTE(review): parallel_grid below lists warmup_iters 1000; presumably
    # the grid value is the one used for runs -- confirm.
    warmup_iters: 100
    weight_decay: 0.0

    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# Ticket entry plus no_update_keywords: ['classifier'].
# The anchor name ends in "ticket-boil" while the key ends in "ticket+boil".
# parallel_grid restates the same values as the merged ticket entry.
miniimagenet_1s5w_resnet12_ticket+boil: &miniimagenet_1s5w_resnet12_ticket-boil
    <<: *miniimagenet_1s5w_resnet12_ticket
    no_update_keywords: ['classifier']
    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# Ticket entry with 'classifier' removed from ignore_params; the grid is the
# one merged from the ticket entry.
miniimagenet_1s5w_resnet12_ticket+pruneout:
    <<: *miniimagenet_1s5w_resnet12_ticket
    ignore_params: ['bn1','bn2','bn3','downsample.1','bias']
 
# 5-shot variants: each entry merges its 1-shot counterpart, sets
# train_shots/test_shots to 5 and restates the counterpart's parallel_grid
# values.
miniimagenet_5s5w_resnet12_maml: &miniimagenet_5s5w_resnet12_maml
    <<: *miniimagenet_1s5w_resnet12_maml
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
miniimagenet_5s5w_resnet12_boil: &miniimagenet_5s5w_resnet12_boil
    <<: *miniimagenet_1s5w_resnet12_boil
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
miniimagenet_5s5w_resnet12_anil:
    <<: *miniimagenet_1s5w_resnet12_anil
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
miniimagenet_5s5w_resnet12_ticket: &miniimagenet_5s5w_resnet12_ticket
    <<: *miniimagenet_1s5w_resnet12_ticket
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# Merges the 1-shot ticket+boil entry through its "ticket-boil" anchor.
miniimagenet_5s5w_resnet12_ticket+boil:
    <<: *miniimagenet_1s5w_resnet12_ticket-boil
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# =================================

# ----------------------------------
#       MiniImagenet + WRN28-5
# ----------------------------------
# 5-way 5-shot MAML with the "wrn28-5" network: merges the 1-shot ResNet12
# MAML entry and replaces the network name and shot counts.
miniimagenet_5s5w_wrn28-5_maml: &miniimagenet_5s5w_wrn28-5_maml
    <<: *miniimagenet_1s5w_resnet12_maml
    model.config_name: "wrn28-5"
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
# The boil/anil entries differ from the MAML entry only in no_update_keywords.
miniimagenet_5s5w_wrn28-5_boil:
    <<: *miniimagenet_5s5w_wrn28-5_maml
    no_update_keywords: ['classifier']
miniimagenet_5s5w_wrn28-5_anil:
    <<: *miniimagenet_5s5w_wrn28-5_maml
    no_update_keywords: ['features']

# MetaTicket entry for "wrn28-5": merges the 1-shot ResNet12 ticket entry and
# replaces the network name, ignore_params and shot counts.
miniimagenet_5s5w_wrn28-5_ticket: &miniimagenet_5s5w_wrn28-5_ticket
    <<: *miniimagenet_1s5w_resnet12_ticket
    model.config_name: "wrn28-5"
    ignore_params: ['bn1','bn2','bias','classifier']

    train_shots: 5
    test_shots: 5

    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# Ticket entry plus no_update_keywords: ['classifier']; the grid restates the
# ticket entry's values.
miniimagenet_5s5w_wrn28-5_ticket+boil:
    <<: *miniimagenet_5s5w_wrn28-5_ticket
    no_update_keywords: ['classifier']
    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# =================================

# ----------------------------------
#       MiniImagenet + WRN28-10
# ----------------------------------
# Merges the wrn28-5 MAML entry and switches the network to "wrn28-10".
# The shot counts and parallel_grid restate the merged entry's values.
miniimagenet_5s5w_wrn28-10_maml: &miniimagenet_5s5w_wrn28-10_maml
    <<: *miniimagenet_5s5w_wrn28-5_maml
    model.config_name: "wrn28-10"
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
# The boil/anil entries differ from the MAML entry only in no_update_keywords.
miniimagenet_5s5w_wrn28-10_boil:
    <<: *miniimagenet_5s5w_wrn28-10_maml
    no_update_keywords: ['classifier']
miniimagenet_5s5w_wrn28-10_anil:
    <<: *miniimagenet_5s5w_wrn28-10_maml
    no_update_keywords: ['features']

# Merges the wrn28-5 ticket entry and switches the network to "wrn28-10".
# ignore_params, shot counts and parallel_grid restate the merged values.
miniimagenet_5s5w_wrn28-10_ticket: &miniimagenet_5s5w_wrn28-10_ticket
    <<: *miniimagenet_5s5w_wrn28-5_ticket
    model.config_name: "wrn28-10"
    ignore_params: ['bn1','bn2','bias','classifier']

    train_shots: 5
    test_shots: 5

    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
miniimagenet_5s5w_wrn28-10_ticket+boil:
    <<: *miniimagenet_5s5w_wrn28-10_ticket
    no_update_keywords: ['classifier']
    parallel_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# The pruneout entries drop 'classifier' from ignore_params; both merge the
# ticket entry directly and keep its grid.
miniimagenet_5s5w_wrn28-10_ticket+pruneout:
    <<: *miniimagenet_5s5w_wrn28-10_ticket
    ignore_params: ['bn1','bn2','bias']
miniimagenet_5s5w_wrn28-10_ticket+pruneout+boil:
    <<: *miniimagenet_5s5w_wrn28-10_ticket
    ignore_params: ['bn1','bn2','bias']
    no_update_keywords: ['classifier']
# =================================

# ----------------------------------
#      CIFAR-FS + MLPs
# ----------------------------------
# Base entry of this section: "MAML" with the "3mlp" network on "cifarfs",
# 5-way 1-shot. inner_lr is null at the top level and only given a value in
# parallel_grid. The hidN keys appear only inside the grids.
# NOTE(review): presumably hid1..hidN are the hidden-layer widths of the
# selected MLP preset -- confirm in the model code.
cifarfs_1s5w_3mlp_maml: &cifarfs_1s5w_3mlp_maml
    <<: *__default__
    dataset.config_name: 'cifarfs'
    model.config_name: "3mlp"

    iteration: 30000 # [from BOIL paper/repos] 30000 (ResNet12) / 5000 (ConvNet)
    validation_freq: 1000
    batch_size: 4
    optimizer: "AdamW"
    weight_decay: 0.0
    bn_track_running_stats: False

    learning_framework: "MAML"
    valid_batch_size: 200
    inner_lr: null
    train_steps: 1
    train_ways: 5
    train_shots: 1
    test_ways: 5
    test_shots: 1

    parallel_command: "meta_train"
    seed: 1

    init_mode: 'kaiming_normal'

    parallel_grid:
        lr: [0.001]
        hid1: [8192]
        hid2: [64, 128]
        inner_lr: [0.5]
        seed: [1,2,3]
# Same as the 3mlp entry with the "4mlp" network and a grid that lists
# hid1..hid3.
cifarfs_1s5w_4mlp_maml: &cifarfs_1s5w_4mlp_maml
    <<: *cifarfs_1s5w_3mlp_maml
    model.config_name: "4mlp"
    parallel_grid:
        lr: [0.001]
        hid1: [2048]
        hid2: [2048]
        hid3: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# Same as the 3mlp entry with the "5mlp" network and a grid that lists
# hid1..hid4.
cifarfs_1s5w_5mlp_maml: &cifarfs_1s5w_5mlp_maml
    <<: *cifarfs_1s5w_3mlp_maml
    model.config_name: "5mlp"
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The boil/anil entries add no_update_keywords to the 5mlp MAML entry and
# restate its grid values.
cifarfs_1s5w_5mlp_boil: &cifarfs_1s5w_5mlp_boil
    <<: *cifarfs_1s5w_5mlp_maml
    no_update_keywords: ['classifier']
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
cifarfs_1s5w_5mlp_anil: &cifarfs_1s5w_5mlp_anil
    <<: *cifarfs_1s5w_5mlp_maml
    no_update_keywords: ['features']
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# "MetaTicket" variant of cifarfs_1s5w_5mlp_maml (5-way 1-shot, "5mlp"
# network), trained with "SGD" and "CustomCosineLR".
cifarfs_1s5w_5mlp_ticket: &cifarfs_1s5w_5mlp_ticket
    <<: *cifarfs_1s5w_5mlp_maml
    learning_framework: "MetaTicket"
    ignore_params: ['classifier','normalize','.bias']
    init_mode: 'kaiming_normal'

    print_sparsity: true

    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    warmup_iters: 100
    weight_decay: 0.0

    # This mapping replaces the merged parallel_grid as a whole (merge keys
    # are shallow), so hid1..hid4 are listed here with the same values as the
    # other 5mlp grids in this file; they are not defined anywhere else.
    # NOTE(review): presumably the "5mlp" network needs hid1..hid4 -- confirm.
    parallel_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [1.0, 0.0]
        seed: [1, 2, 3]

# 5-shot variants: each entry merges its 1-shot counterpart, sets
# train_shots/test_shots to 5 and writes out its own parallel_grid.
cifarfs_5s5w_5mlp_maml: &cifarfs_5s5w_5mlp_maml
    <<: *cifarfs_1s5w_5mlp_maml
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
cifarfs_5s5w_5mlp_boil: &cifarfs_5s5w_5mlp_boil
    <<: *cifarfs_1s5w_5mlp_boil
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
cifarfs_5s5w_5mlp_anil: &cifarfs_5s5w_5mlp_anil
    <<: *cifarfs_1s5w_5mlp_anil
    train_shots: 5
    test_shots: 5
    parallel_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The grid here also lists init_mode and a single rerand_rate value (0.0).
# scale_lr restates the value already set in __default__.
cifarfs_5s5w_5mlp_ticket: &cifarfs_5s5w_5mlp_ticket
    <<: *cifarfs_1s5w_5mlp_ticket
    train_shots: 5
    test_shots: 5
    scale_lr: 0.0
    parallel_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        #init_mode: ['kaiming_uniform', 'signed_constant']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# Ticket entry plus no_update_keywords: ['classifier'].
# The anchor name ends in "ticket-boil" while the key ends in "ticket+boil".
cifarfs_5s5w_5mlp_ticket+boil: &cifarfs_5s5w_5mlp_ticket-boil
    <<: *cifarfs_5s5w_5mlp_ticket
    no_update_keywords: ['classifier']
    parallel_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        #init_mode: ['kaiming_uniform', 'signed_constant']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# ==================================

# ----------------------------------
#      VGG-Flower w/ MLP experiments 
# ----------------------------------
# Each entry merges the matching cifarfs 5-shot entry and changes only the
# dataset name to 'vgg_flower'; grids and all other keys come from the merge.
vggflower_5s5w_5mlp_maml:
    <<: *cifarfs_5s5w_5mlp_maml
    dataset.config_name: 'vgg_flower'
vggflower_5s5w_5mlp_boil:
    <<: *cifarfs_5s5w_5mlp_boil
    dataset.config_name: 'vgg_flower'
vggflower_5s5w_5mlp_anil:
    <<: *cifarfs_5s5w_5mlp_anil
    dataset.config_name: 'vgg_flower'
vggflower_5s5w_5mlp_ticket:
    <<: *cifarfs_5s5w_5mlp_ticket
    dataset.config_name: 'vgg_flower'
vggflower_5s5w_5mlp_ticket+boil:
    <<: *cifarfs_5s5w_5mlp_ticket-boil
    dataset.config_name: 'vgg_flower'
# ==================================

# ----------------------------------
#      Omniglot w/ MLP experiments 
# ----------------------------------
# Base entry of this section: "MAML" with the "omniglot_mlp" network on
# 'omniglot', 5-way 1-shot, with test_steps set to 3 while train_steps is 1.
# The grid sweeps four values of `factor`.
omniglot_1s5w_mlp_maml: &omniglot_1s5w_mlp_maml
    <<: *__default__
    dataset.config_name: 'omniglot'
    train_ways: 5
    train_shots: 1
    test_ways: null  # => same as train_ways
    test_shots: null  # => same as train_shots
    train_steps: 1
    test_steps: 3  # from MAML paper

    learning_framework: "MAML"

    iteration: 30000
    valid_batch_size: 200
    validation_freq: 1000
    batch_size: 4
    optimizer: "AdamW"
    weight_decay: 0.0
    model.config_name: "omniglot_mlp"
    init_mode: kaiming_normal
    bn_track_running_stats: False

    parallel_command: "meta_train"

    parallel_grid:
        lr: [0.001]
        inner_lr: [0.4]
        factor: [1.0, 2.0, 4.0, 8.0]
        seed: [1,2,3]
# "MetaTicket" variant of the base entry with "SGD", "CustomCosineLR" and
# 1000 warmup iterations. init_sparsity is set at the top level here rather
# than in the grid; the grid sweeps rerand_rate over 0.0 and 1.0.
omniglot_1s5w_mlp_ticket: &omniglot_1s5w_mlp_ticket
    <<: *omniglot_1s5w_mlp_maml
    learning_framework: "MetaTicket"

    ignore_params: ['classifier','normalize','.bias']
    init_mode: 'kaiming_normal'
    print_sparsity: true
    init_sparsity: 0.0

    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    warmup_iters: 1000
    weight_decay: 0.0

    parallel_grid:
        lr: [10.0]
        inner_lr: [0.4]
        factor: [1.0, 2.0, 4.0, 8.0]
        rerand_freq: [1000]
        rerand_rate: [0.0, 1.0]
        seed: [1,2,3]
# The remaining entries merge one of the two entries above, change only the
# way/shot counts and keep the merged grid. test_ways/test_shots stay null.
omniglot_5s5w_mlp_maml:
    <<: *omniglot_1s5w_mlp_maml
    train_shots: 5
    test_shots: null  # => same as train_shots
omniglot_5s5w_mlp_ticket:
    <<: *omniglot_1s5w_mlp_ticket
    train_shots: 5
    test_shots: null  # => same as train_shots
omniglot_1s20w_mlp_maml:
    <<: *omniglot_1s5w_mlp_maml
    train_ways: 20
    train_shots: 1
    test_ways: null  # => same as train_ways
    test_shots: null  # => same as train_shots
omniglot_1s20w_mlp_ticket:
    <<: *omniglot_1s5w_mlp_ticket
    train_ways: 20
    train_shots: 1
    test_ways: null  # => same as train_ways
    test_shots: null  # => same as train_shots
omniglot_5s20w_mlp_maml:
    <<: *omniglot_1s5w_mlp_maml
    train_ways: 20
    train_shots: 5
    test_ways: null  # => same as train_ways
    test_shots: null  # => same as train_shots
omniglot_5s20w_mlp_ticket:
    <<: *omniglot_1s5w_mlp_ticket
    train_ways: 20
    train_shots: 5
    test_ways: null  # => same as train_ways
    test_shots: null  # => same as train_shots

# "randexp" entries built on omniglot_1s5w_mlp_ticket with 5 training shots.
# The anchor name (randexp_omniglot_5s5w_mlp_ticket) is shorter than the key.
# Both const_params_* flags are enabled, and lr_scheduler, warmup_iters and
# rerand_rate are set to null.
randexp_omniglot_5s5w_mlp_ticket_const1-const1: &randexp_omniglot_5s5w_mlp_ticket
    <<: *omniglot_1s5w_mlp_ticket
    train_shots: 5
    test_shots: null  # => same as train_shots

    ignore_params: ['normalize','.bias'] # NOTE: only in randexp
    batch_size: 32  # NOTE: only in randexp
    optimizer: "AdamW" # NOTE: only in randexp
    train_steps: 3 # NOTE: only in randexp
    test_steps: 3  # from MAML paper
    lr_scheduler: null
    warmup_iters: null
    weight_decay: 0.0

    rerand_rate: null
    const_params_test: true
    const_params_train: true

    parallel_grid:
        iteration: [60000]
        lr: [0.1]  # NOTE: for train_steps=3
        inner_lr: [0.4]
        factor: [4.0]
        init_sparsity: [0.0]
        const_reset_scale: [1.0]
        seed: [1,2,3]
# Same settings with a grid of a single iteration, lr 0.0 and
# init_sparsity 0.5; validation_freq is 1.
randexp_omniglot_5s5w_mlp_ticket_randmask-const1:
    <<: *randexp_omniglot_5s5w_mlp_ticket
    validation_freq: 1
    parallel_grid:
        iteration: [1]
        lr: [0.0]
        inner_lr: [0.4]
        factor: [4.0]
        init_sparsity: [0.5]
        const_reset_scale: [1.0]
        seed: [1,2,3]


# ----------------------------------
#      Sinusoid experiments 
# ----------------------------------
# Base entry of this section: "SinusoidMAML" with the "sinusoid-mlp" network
# on the "sinusoid" dataset, 5 shots and 5 training steps. The sin_* keys are
# defined only here; __default__ has no counterpart for them.
sinusoid_5shot_mlp_maml: &sinusoid_5shot_mlp_maml
    <<: *__default__
    iteration: 30000
    validation_freq: 1000
    batch_size: 4
    optimizer: "AdamW"
    weight_decay: 0.0

    learning_framework: "SinusoidMAML"
    dataset.config_name: "sinusoid"
    model.config_name: "sinusoid-mlp"

    sin_amplitude_train: 5.0
    sin_amplitude_test: 5.0
    sin_phase_min_train: 0.0
    sin_phase_max_train: 3.141592
    sin_phase_min_test: 0.0
    sin_phase_max_test: 3.141592

    valid_batch_size: 200
    inner_lr: null
    train_shots: 5
    test_shots: null
    train_steps: 5
    test_steps: null

    parallel_command: "meta_train"

    parallel_grid:
        factor: [1.0]
        lr: [0.001]
        inner_lr: [0.01]
        seed: [1,2,3]
# "SinusoidMetaTicket" variant with "SGD" and "CustomCosineLR". It defines no
# parallel_grid of its own, so the grid merged from the MAML entry applies.
# ignore_params lists '.bias' while aux_params lists 'bias'.
sinusoid_5shot_mlp_ticket: &sinusoid_5shot_mlp_ticket
    <<: *sinusoid_5shot_mlp_maml
    learning_framework: "SinusoidMetaTicket"
    ignore_params: ['hidden1', 'hidden3','.bias']
    init_mode: 'kaiming_normal'

    print_sparsity: true

    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    weight_decay: 0.0
    aux_params: ['hidden1', 'hidden3', 'bias']
    aux_optimizer: 'AdamW'
# Ticket entry switched back to "AdamW" without a scheduler; its grid adds
# init_sparsity, rerand_freq, rerand_rate and aux_lr.
sinusoid_5shot_mlp_ticket+adam:
    <<: *sinusoid_5shot_mlp_ticket
    optimizer: "AdamW"
    weight_decay: 0.0
    lr_scheduler: null
    parallel_grid:
        factor: [1.0]
        lr: [0.001]
        inner_lr: [0.01]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [1.0]
        aux_lr: [0.001]
        seed: [1,2,3]
# ==================================

# ----------------------------------
#      Cross-domain test
# ----------------------------------

# Shared defaults for the cross-domain test entries, which pull them in with
# `<<: *__cross_test_default__`.
__cross_test_default__: &__cross_test_default__
    # These parameters override the ones in the target config.
    output_dir: "__outputs__"
    num_gpus: 1
    seed: 1
    use_cuda: true
    use_best: false

    # Left unset here; each cross-domain entry fills these in.
    eval_dataset.config_name: null
    target_name: null
    target_grid: null

# ================================
#     miniImagenet + ResNet12
# ================================

# Cross-domain entries with eval dataset "cub" and targets named after the
# miniimagenet 5s5w ResNet12 entries. target_grid lists the same keys and
# values as the parallel_grid of the named target entry.
# NOTE(review): presumably target_name/target_grid select which trained runs
# are evaluated -- confirm in the consumer code.
cross_miniimagenet2cub_5s5w_resnet12_maml: &cross_miniimagenet2cub_5s5w_resnet12_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "cub"
    target_name: "miniimagenet_5s5w_resnet12_maml"
    target_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_miniimagenet2cub_5s5w_resnet12_boil:
    <<: *cross_miniimagenet2cub_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_boil"
cross_miniimagenet2cub_5s5w_resnet12_anil:
    <<: *cross_miniimagenet2cub_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_anil"
# The ticket entries replace both target_name and target_grid.
cross_miniimagenet2cub_5s5w_resnet12_ticket: &cross_miniimagenet2cub_5s5w_resnet12_ticket
    <<: *cross_miniimagenet2cub_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_ticket"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_miniimagenet2cub_5s5w_resnet12_ticket+boil:
    <<: *cross_miniimagenet2cub_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_ticket+boil"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]

# Cross-domain entries with eval dataset "cars" and targets named after the
# miniimagenet 5s5w ResNet12 entries. target_grid lists the same keys and
# values as the parallel_grid of the named target entry.
cross_miniimagenet2cars_5s5w_resnet12_maml: &cross_miniimagenet2cars_5s5w_resnet12_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "cars"
    target_name: "miniimagenet_5s5w_resnet12_maml"
    target_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_miniimagenet2cars_5s5w_resnet12_boil:
    <<: *cross_miniimagenet2cars_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_boil"
cross_miniimagenet2cars_5s5w_resnet12_anil:
    <<: *cross_miniimagenet2cars_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_anil"
# The ticket entries replace both target_name and target_grid.
cross_miniimagenet2cars_5s5w_resnet12_ticket: &cross_miniimagenet2cars_5s5w_resnet12_ticket
    <<: *cross_miniimagenet2cars_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_ticket"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_miniimagenet2cars_5s5w_resnet12_ticket+boil:
    <<: *cross_miniimagenet2cars_5s5w_resnet12_maml
    target_name: "miniimagenet_5s5w_resnet12_ticket+boil"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]

# ================================
#     CIFAR-FS + 5MLP
# ================================

# Cross-domain entry with eval dataset "aircraft" and target
# "cifarfs_5s5w_5mlp_maml". The boil/anil entries that merge this one keep
# its target_grid.
cross_cifarfs2aircraft_5s5w_5mlp_maml: &cross_cifarfs2aircraft_5s5w_5mlp_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "aircraft"
    target_name: "cifarfs_5s5w_5mlp_maml"
    # Mirrors the parallel_grid of cifarfs_5s5w_5mlp_maml, which lists
    # lr 0.001 (as do the boil and anil targets).
    # NOTE(review): presumably target_grid must match the grid the target
    # runs were trained with -- confirm in the consumer code.
    target_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the target_grid
# merged from the maml entry.
cross_cifarfs2aircraft_5s5w_5mlp_boil:
    <<: *cross_cifarfs2aircraft_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_boil"
cross_cifarfs2aircraft_5s5w_5mlp_anil:
    <<: *cross_cifarfs2aircraft_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_anil"
# The ticket entries replace both target_name and target_grid; the grid lists
# the same keys and values as the parallel_grid of the named target entry.
cross_cifarfs2aircraft_5s5w_5mlp_ticket:
    <<: *cross_cifarfs2aircraft_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_ticket"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_cifarfs2aircraft_5s5w_5mlp_ticket+boil:
    <<: *cross_cifarfs2aircraft_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_ticket+boil"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]

# Cross-domain entry with eval dataset "vgg_flower" and target
# "cifarfs_5s5w_5mlp_maml". The boil/anil entries that merge this one keep
# its target_grid.
cross_cifarfs2vggflower_5s5w_5mlp_maml: &cross_cifarfs2vggflower_5s5w_5mlp_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "vgg_flower"
    target_name: "cifarfs_5s5w_5mlp_maml"
    # Mirrors the parallel_grid of cifarfs_5s5w_5mlp_maml, which lists
    # lr 0.001 (as do the boil and anil targets).
    # NOTE(review): presumably target_grid must match the grid the target
    # runs were trained with -- confirm in the consumer code.
    target_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the target_grid
# merged from the maml entry.
cross_cifarfs2vggflower_5s5w_5mlp_boil:
    <<: *cross_cifarfs2vggflower_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_boil"
cross_cifarfs2vggflower_5s5w_5mlp_anil:
    <<: *cross_cifarfs2vggflower_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_anil"
# The ticket entries replace both target_name and target_grid; the grid lists
# the same keys and values as the parallel_grid of the named target entry.
cross_cifarfs2vggflower_5s5w_5mlp_ticket:
    <<: *cross_cifarfs2vggflower_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_ticket"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_cifarfs2vggflower_5s5w_5mlp_ticket+boil:
    <<: *cross_cifarfs2vggflower_5s5w_5mlp_maml
    target_name: "cifarfs_5s5w_5mlp_ticket+boil"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]


# Cross-domain entries with eval dataset "aircraft" and targets named after
# the vggflower 5s5w 5mlp entries. The vggflower entries take their grids
# from the cifarfs 5s5w entries, and target_grid lists those same values.
cross_vggflower2aircraft_5s5w_5mlp_maml: &cross_vggflower2aircraft_5s5w_5mlp_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "aircraft"
    target_name: "vggflower_5s5w_5mlp_maml"
    target_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_vggflower2aircraft_5s5w_5mlp_boil:
    <<: *cross_vggflower2aircraft_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_boil"
cross_vggflower2aircraft_5s5w_5mlp_anil:
    <<: *cross_vggflower2aircraft_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_anil"
# The ticket entries replace both target_name and target_grid.
cross_vggflower2aircraft_5s5w_5mlp_ticket:
    <<: *cross_vggflower2aircraft_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_ticket"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        #init_mode: ['kaiming_uniform', 'signed_constant']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_vggflower2aircraft_5s5w_5mlp_ticket+boil:
    <<: *cross_vggflower2aircraft_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_ticket+boil"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]

# Cross-domain entries with eval dataset "cifarfs" and targets named after
# the vggflower 5s5w 5mlp entries. The vggflower entries take their grids
# from the cifarfs 5s5w entries, and target_grid lists those same values.
cross_vggflower2cifarfs_5s5w_5mlp_maml: &cross_vggflower2cifarfs_5s5w_5mlp_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "cifarfs"
    target_name: "vggflower_5s5w_5mlp_maml"
    target_grid:
        lr: [0.001]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        seed: [1,2,3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_vggflower2cifarfs_5s5w_5mlp_boil:
    <<: *cross_vggflower2cifarfs_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_boil"
cross_vggflower2cifarfs_5s5w_5mlp_anil:
    <<: *cross_vggflower2cifarfs_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_anil"
# The ticket entries replace both target_name and target_grid.
cross_vggflower2cifarfs_5s5w_5mlp_ticket:
    <<: *cross_vggflower2cifarfs_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_ticket"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        #init_mode: ['kaiming_uniform', 'signed_constant']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_vggflower2cifarfs_5s5w_5mlp_ticket+boil:
    <<: *cross_vggflower2cifarfs_5s5w_5mlp_maml
    target_name: "vggflower_5s5w_5mlp_ticket+boil"
    target_grid:
        lr: [10.0]
        hid1: [1024]
        hid2: [512]
        hid3: [256]
        hid4: [128]
        inner_lr: [0.5]
        init_mode: ['kaiming_normal']
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# ==================================

# ================================
#     miniImagenet + WRN28-10
# ================================

# Cross-domain entries with eval dataset "cub" and targets named after the
# miniimagenet 5s5w wrn28-10 entries. target_grid lists the same keys and
# values as the parallel_grid of the named target entry.
cross_miniimagenet2cub_5s5w_wrn28-10_maml: &cross_miniimagenet2cub_5s5w_wrn28-10_maml
    <<: *__cross_test_default__
    eval_dataset.config_name: "cub"
    target_name: "miniimagenet_5s5w_wrn28-10_maml"
    target_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_miniimagenet2cub_5s5w_wrn28-10_boil:
    <<: *cross_miniimagenet2cub_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_boil"
cross_miniimagenet2cub_5s5w_wrn28-10_anil:
    <<: *cross_miniimagenet2cub_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_anil"
# The ticket entries replace both target_name and target_grid.
cross_miniimagenet2cub_5s5w_wrn28-10_ticket:
    <<: *cross_miniimagenet2cub_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_ticket"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_miniimagenet2cub_5s5w_wrn28-10_ticket+boil:
    <<: *cross_miniimagenet2cub_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_ticket+boil"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]

# Cross-domain entries with eval dataset "cars" and targets named after the
# miniimagenet 5s5w wrn28-10 entries. Unlike the other cross-domain base
# entries, this one merges the "cub" wrn28-10 entry rather than
# __cross_test_default__, then replaces the eval dataset; target_name and
# target_grid restate the merged values.
cross_miniimagenet2cars_5s5w_wrn28-10_maml: &cross_miniimagenet2cars_5s5w_wrn28-10_maml
    <<: *cross_miniimagenet2cub_5s5w_wrn28-10_maml
    eval_dataset.config_name: "cars"
    target_name: "miniimagenet_5s5w_wrn28-10_maml"
    target_grid:
        lr: [0.0006] # from BOIL paper/repos
        inner_lr: [0.3]  # from BOIL paper/repos
        seed: [1, 2, 3]
# The boil/anil entries change only target_name and keep the merged grid.
cross_miniimagenet2cars_5s5w_wrn28-10_boil:
    <<: *cross_miniimagenet2cars_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_boil"
cross_miniimagenet2cars_5s5w_wrn28-10_anil:
    <<: *cross_miniimagenet2cars_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_anil"
# The ticket entries replace both target_name and target_grid.
cross_miniimagenet2cars_5s5w_wrn28-10_ticket:
    <<: *cross_miniimagenet2cars_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_ticket"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
cross_miniimagenet2cars_5s5w_wrn28-10_ticket+boil:
    <<: *cross_miniimagenet2cars_5s5w_wrn28-10_maml
    target_name: "miniimagenet_5s5w_wrn28-10_ticket+boil"
    target_grid:
        lr: [10.0]
        inner_lr: [0.3]  # from BOIL paper/repos
        warmup_iters: [1000]
        init_sparsity: [0.0]
        rerand_freq: [1000]
        rerand_rate: [0.0]
        seed: [1, 2, 3]
# ==================================


# Small iteration, validation and batch-size values, with sparsity printing
# enabled. No entry in this file merges this mapping.
# NOTE(review): presumably applied on top of an experiment entry when running
# in debug mode -- confirm in the consumer code.
__debug__:
    iteration: 200
    valid_batch_size: 32
    validation_freq: 10
    # Alternative, even smaller settings kept for reference:
    #iteration: 20
    #valid_batch_size: 32
    #validation_freq: 5
    test_batch_size: 128
    print_sparsity: true
