
########################################################
#   General settings
########################################################

# -- Definitions of datasets --

# CIFAR-10 dataset entry: image data, 3 channels, image_size 32, 10 classes.
cifar10:
    class: "CIFAR10"
    data_type: "image"
    num_channels: 3
    image_size: 32
    num_classes: 10
    specify_classes: null  # unset in this entry

# MNIST dataset entry: image data, 1 channel, image_size 28, 10 classes.
mnist:
    class: "MNIST"
    data_type: "image"
    num_channels: 1
    image_size: 28
    num_classes: 10
    specify_classes: null  # unset in this entry

# ImageNet dataset entry: image data, 3 channels, image_size 224, 1000 classes.
imagenet:
    class: "ImageNet"
    data_type: "image"
    num_channels: 3
    image_size: 224
    num_classes: 1000
    specify_classes: null  # unset in this entry

# Stanford Cars dataset entry: image data, 3 channels, image_size 224,
# 196 classes.
cars:
    class: "StanfordCars"
    data_type: "image"
    num_channels: 3
    image_size: 224
    num_classes: 196
    specify_classes: null  # unset in this entry

# CUB dataset entry (class 'Cub2011'): image data, 3 channels, image_size 224,
# 200 classes.
cub:
    class: "Cub2011"
    data_type: "image"
    num_channels: 3
    image_size: 224
    num_classes: 200
    specify_classes: null  # unset in this entry

# CIFAR-100 dataset entry: image data, 3 channels, image_size 32, 100 classes.
cifar100:
    class: "CIFAR100"
    data_type: "image"
    num_channels: 3
    image_size: 32
    num_classes: 100
    specify_classes: null  # unset in this entry

# -- Definitions of networks --

# Two MLP entries sharing class 'MLP'; they differ only in default_width.
mlp:
    class: "MLP"
    default_width: 64
mlp1024:
    class: "MLP"
    default_width: 1024

# Convolutional network entries; each one only names its class.
conv4:
    class: "Conv4"
conv6:
    class: "Conv6"
conv8:
    class: "Conv8"

# ResNet-18 entry: class 'ResNet' with BasicBlock and num_blocks [2, 2, 2, 2].
resnet18: &__resnet18__
    class: "ResNet"
    block_class: "BasicBlock"
    num_blocks: [2, 2, 2, 2]

# ResNet-34 entry: class 'ResNet' with BasicBlock and num_blocks [3, 4, 6, 3].
resnet34: &__resnet34__
    class: "ResNet"
    block_class: "BasicBlock"
    num_blocks: [3, 4, 6, 3]

# ResNet-50 entry: class 'ResNet' with Bottleneck and num_blocks [3, 4, 6, 3].
resnet50:
    class: "ResNet"
    block_class: "Bottleneck"
    num_blocks: [3, 4, 6, 3]

# ResNet-101 entry: class 'ResNet' with Bottleneck and num_blocks [3, 4, 23, 3].
resnet101:
    class: "ResNet"
    block_class: "Bottleneck"
    num_blocks: [3, 4, 23, 3]


# -- all options --

# Baseline option set. Experiment entries in this file pull it in with
# `<<: *__default__` and override individual keys. Booleans are written in
# canonical lowercase form throughout.
__default__: &__default__

    # General Setting
    num_workers: 4
    use_cuda: true
    output_dir: '__outputs__'
    dataset_dir: '__data__'
    sync_dir: '__sync__'
    checkpoint_epochs: []
    seed: null
    seed_by_time: false
    dataset_download: true
    num_gpus: 1
    debug_max_iters: null
    allowed_commands: null
    output_prefix_hashing: true

    # Checkpoint loading; all unset by default.
    load_checkpoint_by_path: null
    load_checkpoint_by_hparams: null
    seed_checkpoint: null
    initialize_head: false

    # Dataset / model selection; the names used by experiment entries match
    # the dataset and network keys defined earlier in this file.
    dataset.config_name: null
    train_val_split: 0.1
    model.config_name: null
    save_best_model: true
    print_train_loss: false

    # Hyperparameters for Training
    epoch: null
    optimizer: "SGD"
    lr: null
    weight_decay: null
    lr_scheduler: null
    warmup_epochs: 0
    finetuning_epochs: 0
    finetuning_lr: null
    sgd_momentum: 0.9
    lr_milestones: null
    multisteplr_gamma: 0.1
    padding_before_crop: false
    train_augmentation: true
    learning_framework: "ImageClassification"

    batch_size: 128
    batch_size_eval: 512
    max_train_dataset_size: null
    bn_track_running_stats: true
    bn_affine: true
    bn_momentum: 0.1
    width_factor: 1.0

    # Hyperparameter Search Setting
    hparams_grid: null

    # `*_ft` options; the transfer entries set them through hparams_grid.
    epoch_ft: null
    lr_ft: null
    weight_decay_ft: null
    optimizer_ft: -1  # undefined
    lr_scheduler_ft: -1  # undefined
    use_best_for_ft: true

    # Options used by the `transfer` entries.
    gm_iters: 1
    bn_stats_iters: 0
    bn_stats_batch: null
    base_matching_coeff: 0.0
    num_splits: null
    normalized_matching: false
    baseline_no_perm: false
    baseline_use_target: false
    split_scheduling_fn: null
    fast_perm_optim: false
    wm_epsilon: 1.0e-7
    wm_use_prev_perm: false

    num_perms: 1
    branch_at_best: true

    source_target_seeds: null
    source_target_seeds_pretrained: null


########################################################
#   Default settings for training on CIFAR-10 and ImageNet
########################################################

# Shared SGD settings for CIFAR-10 runs, layered on top of __default__.
# model.config_name and weight_decay are left null for derived entries to set.
cifar10_sgd: &cifar10_sgd
    <<: *__default__
    dataset.config_name: 'cifar10'
    padding_before_crop: true

    epoch: 100
    batch_size: 128
    optimizer: "SGD"
    lr_scheduler: "CustomCosineLR"
    warmup_epochs: 0

    # Override these options
    lr: 0.1
    model.config_name: null
    weight_decay: null  # 0.0001 for convs, 0.0005 for resnets

# Shared SGD settings for ImageNet runs, layered on top of __default__.
imagenet_sgd: &imagenet_sgd
    <<: *__default__
    dataset.config_name: "imagenet"
    model.config_name: null  # set by derived entries (e.g. imagenet_resnet18_sgd)
    num_workers: 8

    epoch: 100
    batch_size: 128
    sgd_momentum: 0.9
    weight_decay: 0.0001
    optimizer: 'SGD'

    lr: 0.1
    lr_scheduler: 'CustomCosineLR'
    warmup_epochs: 5
    finetuning_epochs: 0

########################################################
#   Pre-Training
########################################################

# ResNet-18 on ImageNet. Only the `train` command is allowed; hparams_grid
# lists six seeds.
imagenet_resnet18_sgd: &imagenet_resnet18_sgd
    <<: *imagenet_sgd
    num_gpus: 1
    model.config_name: 'resnet18'
    allowed_commands:
        - train

    # Stored as a string expression rather than a YAML list.
    checkpoint_epochs: 'list(range(-1, 100))'

    hparams_grid:
        seed: [101, 102, 201, 202, 301, 302]

# Stanford Cars run that loads an imagenet_resnet18_sgd checkpoint
# (load_checkpoint_by_hparams) with initialize_head enabled.
imagenet_to_cars_resnet18_sgd: &imagenet_to_cars_resnet18_sgd
    <<: *imagenet_resnet18_sgd
    dataset.config_name: "cars"
    allowed_commands:
        - train

    initialize_head: true
    load_checkpoint_by_hparams:
        _exp_name: imagenet_resnet18_sgd
        seed: null  # should be set by `seed_checkpoint` option

    lr: 0.1  # chosen from {0.1, 0.001, 0.0001}
    lr_scheduler: 'CustomCosineLR'
    warmup_epochs: 0
    finetuning_epochs: 0

    weight_decay: 0.0001
    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch: [30]
        # Alternative seed set, currently commented out:
        #seed: [777]
        #seed_checkpoint: [101, 201, 301]
        seed: [888]
        seed_checkpoint: [102, 202, 302]

# CUB run that loads an imagenet_resnet18_sgd checkpoint
# (load_checkpoint_by_hparams) with initialize_head enabled.
imagenet_to_cub_resnet18_sgd: &imagenet_to_cub_resnet18_sgd
    <<: *imagenet_resnet18_sgd
    dataset.config_name: "cub"
    allowed_commands:
        - train

    initialize_head: true
    load_checkpoint_by_hparams:
        _exp_name: imagenet_resnet18_sgd
        seed: null  # should be set by `seed_checkpoint` option

    lr: 0.1  # chosen from {0.1, 0.001, 0.0001}
    lr_scheduler: 'CustomCosineLR'
    warmup_epochs: 0
    finetuning_epochs: 0

    weight_decay: 0.0001
    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch: [30]
        seed: [777]
        seed_checkpoint: [101, 201, 301]
        # Alternative seed set, currently commented out:
        #seed: [888]
        #seed_checkpoint: [102, 202, 302]

# Conv8 on CIFAR-10. Only the `train` command is allowed; the grid fixes
# epoch, lr and width_factor to single values and lists nine seeds.
cifar10_conv8_sgd: &cifar10_conv8_sgd
    <<: *cifar10_sgd
    model.config_name: 'conv8'
    allowed_commands:
        - train

    # Stored as a string expression rather than a YAML list.
    checkpoint_epochs: 'list(range(-1, 60))'
    weight_decay: 0.0001
    hparams_grid:
        epoch: [60]
        lr: [0.05]
        width_factor: [1.0]
        seed: [101, 102, 103, 201, 202, 203, 301, 302, 303]

# CIFAR-100 run that loads a cifar10_conv8_sgd checkpoint
# (load_checkpoint_by_hparams) with initialize_head enabled.
cifar10_to_cifar100_conv8_sgd: &cifar10_to_cifar100_conv8_sgd
    <<: *cifar10_conv8_sgd
    dataset.config_name: 'cifar100'
    padding_before_crop: true

    allowed_commands: [train]

    initialize_head: true
    # These values mirror the single-valued grid of cifar10_conv8_sgd.
    load_checkpoint_by_hparams:
        _exp_name: cifar10_conv8_sgd
        epoch: 60
        lr: 0.05
        width_factor: 1.0
        seed: null  # should be set by `seed_checkpoint` option

    checkpoint_epochs: "list(range(-1,30))"
    weight_decay: 0.0001
    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch: [30]
        lr: [0.05]
        #seed: [777]
        #seed_checkpoint: [101, 201, 301]
        seed: [888]
        seed_checkpoint: [102, 202, 302]
# Same settings as cifar10_to_cifar100_conv8_sgd, but with checkpoint loading
# and initialize_head switched off.
# NOTE(review): the inherited hparams_grid still lists seed_checkpoint values;
# confirm they have no effect when load_checkpoint_by_hparams is null.
cifar100_conv8_sgd:
    <<: *cifar10_to_cifar100_conv8_sgd
    initialize_head: false
    load_checkpoint_by_hparams: null

# MLP on MNIST. Only the `train` command is allowed; the grid fixes epoch,
# batch_size, lr and width_factor to single values and lists six seeds.
mnist_mlp_sgd: &mnist_mlp_sgd
    <<: *__default__
    allowed_commands:
        - train
    # Stored as a string expression rather than a YAML list.
    checkpoint_epochs: 'list(range(-1, 100))'

    dataset.config_name: "mnist"
    model.config_name: "mlp"

    optimizer: 'SGD'
    lr_scheduler: 'CustomCosineLR'
    warmup_epochs: 0
    weight_decay: 0.0001

    hparams_grid:
        epoch: [15]
        batch_size: [64]
        lr: [0.05]
        width_factor: [8.0]  # dim=512
        seed: [101, 102, 201, 202, 301, 302]

########################################################
#   Command: transfer
########################################################

# `transfer` configuration for the MNIST MLP runs. Source and target both
# reference mnist_mlp_sgd with the same hyperparameters. Their seeds are null
# here, while hparams_grid.source_target_seeds lists the seed pairs
# (presumably used to fill them in; confirm against the consumer).
transfer_init-mnist_mlp: &transfer_init-mnist_mlp
    <<: *mnist_mlp_sgd
    allowed_commands:
        - transfer
    source_exp_name: mnist_mlp_sgd
    source_hparams:
        epoch: 15
        batch_size: 64
        lr: 0.05
        width_factor: 8.0  # dim=512
        seed: null
    target_exp_name: mnist_mlp_sgd
    target_hparams_grid:
        epoch: [15]
        batch_size: [64]
        lr: [0.05]
        width_factor: [8.0]  # dim=512
        seed: [null]
    interpolation_search: false
    diff_grad_matching: false
    output_prefix_hashing: true
    faster_version: true
    fast_perm_optim: false
    # Replaces (does not extend) the hparams_grid merged in from mnist_mlp_sgd.
    hparams_grid:
        epoch_ft: [15]
        lr_ft: [0.05]
        optimizer_ft: ['SGD']
        lr_scheduler_ft: ['CustomCosineLR']
        weight_decay_ft: [0.0001]

        base_matching_coeff: [0.0]
        normalized_matching: [false]
        gm_iters: [1]
        checkpoint_epochs: [[-1, 14]]
        num_perms: [1]
        num_splits: [5]
        split_scheduling_fn: [null]
        branch_at_best: [true]
        source_target_seeds:
            - [101, 102]
            - [201, 202]
            - [301, 302]
# Variants of transfer_init-mnist_mlp; each differs from the base only in the
# flags set below.
# _fast: turns fast_perm_optim on.
transfer_init-mnist_mlp_fast:
    <<: *transfer_init-mnist_mlp
    fast_perm_optim: true
# _noperm: turns baseline_no_perm on.
transfer_init-mnist_mlp_noperm: &transfer_init-mnist_mlp_noperm
    <<: *transfer_init-mnist_mlp
    baseline_no_perm: true
# _usetarget: merges the _noperm entry, so baseline_no_perm is reset to false
# before baseline_use_target is turned on.
transfer_init-mnist_mlp_usetarget:
    <<: *transfer_init-mnist_mlp_noperm
    baseline_no_perm: false
    baseline_use_target: true

# `transfer` configuration for the CIFAR-10 Conv8 runs. Source and target both
# reference cifar10_conv8_sgd with the same hyperparameters. Their seeds are
# null here, while hparams_grid.source_target_seeds lists the seed pairs
# (presumably used to fill them in; confirm against the consumer).
transfer_init-cifar10_conv8: &transfer_init-cifar10_conv8
    <<: *cifar10_conv8_sgd
    allowed_commands:
        - transfer
    checkpoint_epochs: 'list(range(-1, 60))'
    source_exp_name: cifar10_conv8_sgd
    source_hparams:
        epoch: 60
        lr: 0.05
        width_factor: 1.0
        seed: null
    target_exp_name: cifar10_conv8_sgd
    target_hparams_grid:
        epoch: [60]
        lr: [0.05]
        width_factor: [1.0]
        seed: [null]
    interpolation_search: false
    diff_grad_matching: false
    output_prefix_hashing: true
    faster_version: true
    fast_perm_optim: false
    # Replaces (does not extend) the hparams_grid merged in from
    # cifar10_conv8_sgd.
    hparams_grid:
        epoch_ft: [60]
        lr_ft: [0.01]
        weight_decay_ft: [0.0001]
        optimizer_ft: ["SGD"]
        lr_scheduler_ft: ['CustomCosineLR']

        base_matching_coeff: [0.0]
        normalized_matching: [false]
        gm_iters: [2]
        checkpoint_epochs: [[-1, 59]]
        num_perms: [1]
        num_splits: [30]

        branch_at_best: [true]
        source_target_seeds:
            - [101, 102]
            - [201, 202]
            - [301, 302]
        #source_target_seeds: [[101, 102], [101, 202], [101, 302]]  # **For Ensemble Evaluation**
# Variants of transfer_init-cifar10_conv8; each differs from the base only in
# the flags set below.
# _fast: turns fast_perm_optim on.
transfer_init-cifar10_conv8_fast:
    <<: *transfer_init-cifar10_conv8
    fast_perm_optim: true
# _noperm: turns baseline_no_perm on.
transfer_init-cifar10_conv8_noperm: &transfer_init-cifar10_conv8_noperm
    <<: *transfer_init-cifar10_conv8
    baseline_no_perm: true
# _usetarget: merges the _noperm entry, so baseline_no_perm is reset to false
# before baseline_use_target is turned on.
transfer_init-cifar10_conv8_usetarget:
    <<: *transfer_init-cifar10_conv8_noperm
    baseline_no_perm: false
    baseline_use_target: true


# `transfer` configuration whose source and target both reference
# cifar10_to_cifar100_conv8_sgd (source seed 777, target seed 888).
# NOTE(review): this entry merges transfer_init-cifar10_conv8, which carries
# dataset.config_name 'cifar10' from cifar10_conv8_sgd; it is not overridden
# to 'cifar100' here — confirm whether the consumer takes the dataset from the
# source/target experiment instead.
# NOTE(review): cifar10_to_cifar100_conv8_sgd's own grid uses epoch [30] and
# lr [0.05] with checkpoint_epochs up to 29, while this entry asks for
# epoch 60 / lr 0.01 and checkpoint epoch 59 — confirm which side is current.
transfer_cifar10-cifar100_conv8: &transfer_cifar10-cifar100_conv8
    <<: *transfer_init-cifar10_conv8
    source_exp_name: cifar10_to_cifar100_conv8_sgd
    source_hparams:
        epoch: 60
        lr: 0.01
        seed: 777
        seed_checkpoint: null
    target_exp_name: cifar10_to_cifar100_conv8_sgd
    target_hparams_grid:
        epoch: [60]
        lr: [0.01]
        seed: [888]
        seed_checkpoint: [null]
    interpolation_search: false
    diff_grad_matching: false
    output_prefix_hashing: true
    faster_version: true

    fast_perm_optim: false

    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch_ft: [60]
        lr_ft: [0.01]
        weight_decay_ft: [0.0001]
        optimizer_ft: ['SGD']
        lr_scheduler_ft: ["CustomCosineLR"]

        normalized_matching: [false]
        base_matching_coeff: [0.0]
        gm_iters: [2]

        checkpoint_epochs: [[-1, 59]]
        #split_scheduling_fn: ['1+math.cos((t/T)*math.pi)']
        split_scheduling_fn: [null]
        num_perms: [1]
        num_splits: [15]

        # NOTE(review): l2_coeff_to_prev_perm has no entry in __default__ —
        # confirm the consumer accepts it.
        l2_coeff_to_prev_perm: [0.01]
        branch_at_best: [true]
        source_target_seeds_pretrained: [[101, 102], [201, 202], [301, 302]]
        #source_target_seeds_pretrained: [[101, 102], [101, 202], [101, 302]] # **For Ensemble Evaluation**
# Variants of transfer_cifar10-cifar100_conv8; each differs from the base only
# in the flags set below.
# _fast: turns fast_perm_optim on.
transfer_cifar10-cifar100_conv8_fast:
    <<: *transfer_cifar10-cifar100_conv8
    fast_perm_optim: true
# _noperm: turns baseline_no_perm on.
transfer_cifar10-cifar100_conv8_noperm:
    <<: *transfer_cifar10-cifar100_conv8
    baseline_no_perm: true
# _usetarget: merges the base entry directly (unlike the other _usetarget
# entries, which merge their _noperm sibling) and turns baseline_use_target on.
transfer_cifar10-cifar100_conv8_usetarget:
    <<: *transfer_cifar10-cifar100_conv8
    baseline_no_perm: false
    baseline_use_target: true

# `transfer` configuration for the ImageNet ResNet-18 runs. Source and target
# both reference imagenet_resnet18_sgd. Their seeds are null here, while
# hparams_grid.source_target_seeds lists the seed pairs (presumably used to
# fill them in; confirm against the consumer).
transfer_init-imagenet_resnet18: &transfer_init-imagenet_resnet18
    # Single merge key: a repeated `<<` is a duplicate mapping key.
    <<: *imagenet_resnet18_sgd
    allowed_commands: [transfer]
    source_exp_name: imagenet_resnet18_sgd
    source_hparams:
        seed: null
    target_exp_name: imagenet_resnet18_sgd
    target_hparams_grid:
        seed: [null]
    interpolation_search: false
    diff_grad_matching: false

    output_prefix_hashing: true
    fast_perm_optim: false
    l2_reg_type: 'diff'  # {'grad', 'diff'}

    # Replaces (does not extend) the hparams_grid merged in from
    # imagenet_resnet18_sgd.
    hparams_grid:
        epoch_ft: [100]
        lr_ft: [0.001]
        weight_decay_ft: [0.0001]
        optimizer_ft: ['SGD']
        lr_scheduler_ft: ["CustomCosineLR"]

        batch_size: [128]
        bn_stats_iters: [20]  # for batch norms
        bn_stats_batch: [128]  # for batch norms
        gm_iters: [100]  # for ImageNet

        checkpoint_epochs: [[-1, 69]]
        num_splits: [40]

        num_perms: [1]
        branch_at_best: [true]

        source_target_seeds: [[101, 102], [201, 202], [301, 302]]
# Variants of transfer_init-imagenet_resnet18; each differs from the base only
# in the flags set below.
# _fast: turns fast_perm_optim on.
transfer_init-imagenet_resnet18_fast:
    <<: *transfer_init-imagenet_resnet18
    fast_perm_optim: true
# _noperm: turns baseline_no_perm on.
transfer_init-imagenet_resnet18_noperm: &transfer_init-imagenet_resnet18_noperm
    <<: *transfer_init-imagenet_resnet18
    baseline_no_perm: true
# _usetarget: merges the _noperm entry, so baseline_no_perm is reset to false
# before baseline_use_target is turned on.
transfer_init-imagenet_resnet18_usetarget:
    <<: *transfer_init-imagenet_resnet18_noperm
    baseline_no_perm: false
    baseline_use_target: true

# `transfer` configuration whose source and target both reference
# imagenet_to_cars_resnet18_sgd (source seed 777, target seed 888).
# seed_checkpoint is null on both sides, while
# hparams_grid.source_target_seeds_pretrained lists the checkpoint seed pairs
# (presumably used to fill it in; confirm against the consumer).
transfer_imagenet-cars_resnet18: &transfer_imagenet-cars_resnet18
    <<: *imagenet_to_cars_resnet18_sgd
    allowed_commands:
        - transfer
    source_exp_name: imagenet_to_cars_resnet18_sgd
    source_hparams:
        epoch: 30
        seed: 777
        seed_checkpoint: null
    target_exp_name: imagenet_to_cars_resnet18_sgd
    target_hparams_grid:
        epoch: [30]
        seed: [888]
        seed_checkpoint: [null]

    output_prefix_hashing: true
    fast_perm_optim: false

    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch_ft: [30]
        lr_ft: [0.05]  # for Cars
        optimizer_ft: ["SGD"]
        lr_scheduler_ft: ['CustomCosineLR']
        weight_decay_ft: [0.0001]

        batch_size: [128]
        bn_stats_iters: [2]  # for batch norms
        bn_stats_batch: [128]  # for batch norms
        gm_iters: [5]

        checkpoint_epochs: [[-1, 29]]
        num_splits: [15]
        num_perms: [1]

        branch_at_best: [true]
        source_target_seeds_pretrained:
            - [101, 102]
            - [201, 202]
            - [301, 302]
        #source_target_seeds_pretrained: [[101, 102], [101, 202], [101, 302]] # **For Ensemble Evaluation**
# _fast variant of transfer_imagenet-cars_resnet18: turns fast_perm_optim on,
# matching the other `_fast` entries in this file (with `false` it would be
# identical to its base entry).
transfer_imagenet-cars_resnet18_fast:
    <<: *transfer_imagenet-cars_resnet18
    fast_perm_optim: true
# Baseline variants of transfer_imagenet-cars_resnet18.
# _noperm: turns baseline_no_perm on.
transfer_imagenet-cars_resnet18_noperm: &transfer_imagenet-cars_resnet18_noperm
    <<: *transfer_imagenet-cars_resnet18
    baseline_no_perm: true
# _usetarget: merges the _noperm entry, so baseline_no_perm is reset to false
# before baseline_use_target is turned on.
transfer_imagenet-cars_resnet18_usetarget:
    <<: *transfer_imagenet-cars_resnet18_noperm
    baseline_no_perm: false
    baseline_use_target: true

# `transfer` configuration whose source and target both reference
# imagenet_to_cub_resnet18_sgd (source seed 777, target seed 888).
# seed_checkpoint is null on both sides, while
# hparams_grid.source_target_seeds_pretrained lists the checkpoint seed pairs
# (presumably used to fill it in; confirm against the consumer).
transfer_imagenet-cub_resnet18: &transfer_imagenet-cub_resnet18
    <<: *imagenet_to_cub_resnet18_sgd
    allowed_commands:
        - transfer
    source_exp_name: imagenet_to_cub_resnet18_sgd
    source_hparams:
        epoch: 30
        seed: 777
        seed_checkpoint: null
    target_exp_name: imagenet_to_cub_resnet18_sgd
    target_hparams_grid:
        epoch: [30]
        seed: [888]
        seed_checkpoint: [null]

    output_prefix_hashing: true
    fast_perm_optim: false

    # Replaces (does not extend) the hparams_grid merged in from the parent.
    hparams_grid:
        epoch_ft: [30]
        lr_ft: [0.01]  # for CUB
        optimizer_ft: ["SGD"]
        lr_scheduler_ft: ['CustomCosineLR']
        weight_decay_ft: [0.0001]

        batch_size: [128]
        bn_stats_iters: [2]  # for batch norms
        bn_stats_batch: [128]  # for batch norms
        gm_iters: [5]

        checkpoint_epochs: [[-1, 29]]
        num_splits: [15]
        num_perms: [1]

        branch_at_best: [true]
        source_target_seeds_pretrained:
            - [101, 102]
            - [201, 202]
            - [301, 302]
        #source_target_seeds_pretrained: [[101, 102], [101, 202], [101, 302]] # **For Ensemble Evaluation**
# _fast variant of transfer_imagenet-cub_resnet18: turns fast_perm_optim on,
# matching the other `_fast` entries in this file (with `false` it would be
# identical to its base entry).
transfer_imagenet-cub_resnet18_fast:
    <<: *transfer_imagenet-cub_resnet18
    fast_perm_optim: true
# Baseline variants of transfer_imagenet-cub_resnet18.
# _noperm: turns baseline_no_perm on.
transfer_imagenet-cub_resnet18_noperm: &transfer_imagenet-cub_resnet18_noperm
    <<: *transfer_imagenet-cub_resnet18
    baseline_no_perm: true
# _usetarget: merges the _noperm entry, so baseline_no_perm is reset to false
# before baseline_use_target is turned on.
transfer_imagenet-cub_resnet18_usetarget:
    <<: *transfer_imagenet-cub_resnet18_noperm
    baseline_no_perm: false
    baseline_use_target: true


