##########
# Config #
##########
# Order of Diheral group is half of total number of elements in the group 
# Order of Cyclic group is the total number of elements in the group 

base_config: &BASE_CONFIG
    ## SSL parameters
    random_seed: 42
    config: " "
    model: ''
    wandb_log: !!bool False
    wandb_name: ''
    wandb_group: ''
    wandb_project: ''
    wandb_entity: ''
    wandb_log_interval: 1
    debuging: !!bool False
    classification: !!bool True

model_config: &MODEL_CONFIG
    <<: *BASE_CONFIG
    ## CNN parameters
    model: 'cnn'
    # model_architecture:
    num_layers: 3
    # num channels should start with #input channels
    num_channels: [1, 16, 32, 64 ] 
    out_feature: 64
    kernel_sizes: [5, 5, 5]
    spatial_subsampling_factors: [2, 1, 1]
    
    
    data_set: 'mnist'
    fully_convolutional: !!bool False
    ### data_confg:
    batch_size: 128
    test_batch: 200
    num_workers: 1
    data_dir: './data'
    padding: 4
    test_rotation: 60
    test_flip: !!bool False
    train_augmentation: !!bool True
    train_discard_classes: [9,2,4]
    test_discard_classes: [9,2,4]
    num_classes: 10
    ntrain: 5000
    val_ratio: 0.98

    ##  train_config:
    optimizer: 'adam'
    optimizer_kwargs: {}
    lr: 0.002
    weight_decay: 0.00000001

    epochs: 15
    dropout_rate: 0.0

    scheduler: 'step'
    scheduler_kwargs: {'step_size': 20, 'gamma': 0.7}

    loss: 'cross_entropy'

    save_model: !!bool True
    save_model_path: './weights'
    save_model_name: 'cnn_mnist.pth'
    weight_saving_interval: 5

    clip_gradient: !!bool False
    gradient_clip_value: 5.0
    
    ### equivariance_config:
    in_group_type: 'dihedral'
    in_order: 12
    in_feature: 1
    in_representation: 'trivial'
    out_group_type: 'dihedral'
    out_order: 12
    
    out_representation: 'trivial'

    n_tests: 500
    device: 'cuda:0'

g_cnn: &G_CNN
    <<: *MODEL_CONFIG
    ## Graph CNN parameters
    model: 'g_cnn'
    wandb_log: !!bool False

    ## data
    test_flip: !!bool False
    train_augmentation: !!bool False

    ###model_architecture:
    weight_decay: 0.00000001
    num_layers: 3
    num_channels: [1, 8, 16, 32 ] 
    kernel_sizes: [5, 5, 5] 

    # first group is the input group
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[4,4], [4,4], [4,4]] 
    subsampling_factors: [1, 1, 1] 
    domain: 2
    layer_kwargs: {'dilation': 2}
    apply_antialiasing: !!bool True
    cannonicalize: !!bool False
    
    pooling_type: 'max'

    ## equivariance_config:
    in_group_type: 'cycle'
    in_order: 4
    in_feature: 1
    in_representation: 'trivial'
    out_group_type: 'cycle'
    out_order: 4
    out_feature: 32
    out_representation: 'regular'
    
    device: 'cuda:0'
    save_model_name: 'g_cnn_mnist.pth'

g_cnn_d24: &G_CNN_D24
    <<: *G_CNN
    model: 'g_cnn'
    ## Graph CNN parameters

    test_flip: !!bool True

    test_rotation: 60

    ## model_architecture:
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,12], [12,12]] #, [12,12]]

    ## equivariance_config:
    in_group_type: 'dihedral'
    in_feature: 1
    in_representation: 'trivial'
    out_group_type: 'dihedral'
    out_representation: 'regular'

    sample_type: 'sample'

    in_order: 36
    out_order: 36
    save_model_name: 'g_cnn_d24_mnist.pth'

g_cnn_c24: &G_CNN_C24
    <<: *G_CNN_D24
    model: 'g_cnn'

    test_rotation: 60
    test_flip: !!bool False

    ## model_architecture:

    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']]
    dwn_orders: [[24,24], [24,24], [24,24]] 
    subsampling_factors: [1, 1, 1] #, 1]

    ## equivariance_config:
    in_group_type: 'cycle'
    in_order: 18
    in_feature: 1
    in_representation: 'trivial'
    out_group_type: 'cycle'
    out_order: 18
    out_representation: 'regular'

    sample_type: 'sample'
    save_model_name: 'g_cnn_c24_mnist.pth'    


g_cnn_dwn_d24_1: &G_CNN_DWN_D24_1
    <<: *G_CNN_D24
    model: 'g_cnn_dwn'
    ## Graph CNN parameters


    test_rotation: 60

    ## model_architecture:

    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,6], [6,3]] 
    subsampling_factors: [1, 2, 2]
    spatial_subsampling_factors: [1, 1, 1]

    apply_antialiasing: !!bool True
    antialiasing_kwargs: {'smooth_operator': 'adjacency', 'mode': 'linear_optim','iterations': 500,
                            'smoothness_loss_weight': 5.0,'threshold': 0.0,
                            'equi_constraint': !!bool True,
                            'equi_correction': !!bool True}
    cannonicalize: !!bool False
    
    ## equivariance_config:

    out_group_type: 'dihedral'
    out_order: 6
    out_representation: 'regular'

    sample_type: 'sample'
    save_model_name: 'g_cnn_dwn_d24_1_mnist.pth'


g_cnn_dwn_d24_f2: &G_CNN_DWN_D24_F2
    <<: *G_CNN_DWN_D24_1
    data_set: 'mnist'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,6], [6,6]]
    subsampling_factors: [1, 2, 1] 
    spatial_subsampling_factors: [2, 1, 1]

    out_group_type: 'dihedral'
    out_order: 6
    save_model_name: 'g_cnn_dwn_d24_F2_mnist.pth'


g_cnn_dwn_d24_f4: &G_CNN_DWN_D24_F4
    <<: *G_CNN_DWN_D24_F2
    data_set: 'mnist'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,3], [3,3]]
    subsampling_factors: [1, 4, 1]

    out_group_type: 'dihedral'
    out_order: 3
    save_model_name: 'g_cnn_dwn_d24_F4_mnist.pth'

g_cnn_dwn_d24_f3: &G_CNN_DWN_D24_F3
    <<: *G_CNN_DWN_D24_F2
    data_set: 'mnist'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,4], [4,4]]
    subsampling_factors: [1, 3, 1]

    antialiasing_kwargs: {'smooth_operator': 'adjacency', 'mode': 'linear_optim','iterations': 500,
                            'smoothness_loss_weight': 8.0,'threshold': 0.0,
                            'equi_constraint': !!bool True,
                            'equi_correction': !!bool True}

    out_group_type: 'dihedral'
    out_order: 4
    save_model_name: 'g_cnn_dwn_d24_F3_mnist.pth'

g_cnn_dwn_C24_f2: &G_CNN_DWN_C24_F2
    <<: *G_CNN_DWN_D24_F2
    model: 'g_cnn_dwn'
    ## Graph CNN parameters

    ## test config

    test_flip: !!bool False


    ## model_architecture:

    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,12], [12,12]] 
    subsampling_factors: [1, 2, 1]

    
    ## equivariance_config:
    in_group_type: 'cycle'
    in_order: 18

    out_group_type: 'cycle'
    out_order: 18

    sample_type: 'sample'
    save_model_name: 'g_cnn_dwn_C24_f2_mnist.pth'

g_cnn_dwn_C24_f3: &G_CNN_DWN_C24_F3
    <<: *G_CNN_DWN_C24_F2
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,8], [8,8]] 
    subsampling_factors: [1, 3, 1] 

    out_order: 8
    save_model_name: 'g_cnn_dwn_C24_f3_mnist.pth'


g_cnn_dwn_C24_f4: &G_CNN_DWN_C24_F4
    <<: *G_CNN_DWN_C24_F2
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,6], [6,6]] 
    subsampling_factors: [1, 4, 1] 

    out_order: 6
    save_model_name: 'g_cnn_dwn_C24_f4_mnist.pth'
      


g_cnn_cifar: &G_CNN_CIFAR
    <<: *G_CNN
    model: 'g_cnn'
    wandb_project: ''
    data_set: 'cifar10'
    num_channels: [3, 4, 16, 64]
    kernel_sizes: [5, 5, 5]
    spatial_subsampling_factors: [2, 1, 1]
    num_classes: 10
    epochs: 50
    batch_size: 256
    test_batch: 512
    padding: 7
    dropout_rate: 0.3
    weight_decay: 0.00000001
    lr: 0.005
    scheduler_kwargs: {'step_size': 25, 'gamma': 0.7}
    ntrain: 59999
    val_ratio: 0.98
    test_flip: !!bool False
    train_augmentation: !!bool False
    train_discard_classes: []
    test_discard_classes: []
    in_feature: 3
    out_feature: 64
    save_model_name: 'g_cnn_cifar.pth'

g_cnn_cifar_d24: &G_CNN_CIFAR_D24
    <<: *G_CNN_CIFAR
    data_set: 'cifar10'

    test_flip: !!bool True

    test_rotation: 30

    ## model_architecture:
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,12], [12,12]]
    subsampling_factors: [1, 1, 1] 

    ## equivariance_config:
    in_group_type: 'dihedral'
    in_order: 36
    in_feature: 3
    in_representation: 'trivial'
    out_group_type: 'dihedral'
    out_order: 36
    out_feature: 64
    out_representation: 'regular'
    sample_type: 'sample'
    
    save_model_name: 'g_cnn_cifar_d24.pth'

g_cnn_cifar_c24: &G_CNN_CIFAR_C24
    <<: *G_CNN_CIFAR_D24
    test_flip: !!bool False

    ## model_architecture:

    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,24], [24,24]] 
    subsampling_factors: [1, 1, 1] 

    ## equivariance_config:
    in_group_type: 'cycle'
    in_feature: 3
    in_representation: 'trivial'
    out_group_type: 'cycle'
    out_feature: 64
    out_representation: 'regular'
    sample_type: 'sample'
    save_model_name: 'g_cnn_cifar_c24.pth'

g_cnn_cifar_dwn_d24_f2: &G_CNN_CIFAR_DWN_D24_F2
    <<: *G_CNN_CIFAR_D24
    model: 'g_cnn_dwn'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,6], [6,6]]
    subsampling_factors: [1, 2, 1] 

    antialiasing_kwargs: {'smooth_operator': 'adjacency', 'mode': 'linear_optim','iterations': 500,
                            'smoothness_loss_weight': 8.0,'threshold': 0.0,
                            'equi_constraint': !!bool True,
                            'equi_correction': !!bool True}

    cannonicalize: !!bool False
    apply_antialiasing: !!bool True

    out_order: 6
    save_model_name: 'g_cnn_dwn_d24_F2_cifar.pth'

g_cnn_cifar_dwn_d24_f3: &G_CNN_CIFAR_DWN_D24_F3
    <<: *G_CNN_CIFAR_DWN_D24_F2
    model: 'g_cnn_dwn'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,4], [4,4]]
    subsampling_factors: [1, 3, 1]
    out_order: 4
    save_model_name: 'g_cnn_dwn_d24_F3_cifar.pth'

g_cnn_cifar_dwn_d24_f4: &G_CNN_CIFAR_DWN_D24_F4
    <<: *G_CNN_CIFAR_DWN_D24_F2
    model: 'g_cnn_dwn'
    dwn_group_types: [['dihedral', 'dihedral'], ['dihedral', 'dihedral'],['dihedral', 'dihedral']] 
    dwn_orders: [[12,12], [12,3], [3,3]]
    subsampling_factors: [1, 4, 1]
    antialiasing_kwargs: {'smooth_operator': 'adjacency', 'mode': 'linear_optim','iterations': 2000,
                            'smoothness_loss_weight': 8.0,'threshold': 0.0,
                            'equi_constraint': !!bool True,
                            'equi_correction': !!bool True}
    out_group_type: 'dihedral'
    out_order: 3
    save_model_name: 'g_cnn_dwn_d24_F4_cifar.pth'

g_cnn_cifar_dwn_c24_f2: &G_CNN_CIFAR_DWN_C24_F2
    <<: *G_CNN_CIFAR_C24
    model: 'g_cnn_dwn'
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,12], [12,12]] 
    subsampling_factors: [1, 2, 1]

    antialiasing_kwargs: {'smooth_operator': 'adjacency', 'mode': 'linear_optim','iterations': 500,
                            'smoothness_loss_weight': 8.0,'threshold': 0.0,
                            'equi_constraint': !!bool True,
                            'equi_correction': !!bool True}

    cannonicalize: !!bool False
    apply_antialiasing: !!bool True

    out_order: 12
    save_model_name: 'g_cnn_dwn_C24_f2_cifar.pth'

g_cnn_cifar_dwn_c24_f3: &G_CNN_CIFAR_DWN_C24_F3
    <<: *G_CNN_CIFAR_DWN_C24_F2
    model: 'g_cnn_dwn'
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,8], [8,8]] 
    subsampling_factors: [1, 3, 1] 

    out_order: 8
    save_model_name: 'g_cnn_dwn_C24_f3_cifar.pth'

g_cnn_cifar_dwn_c24_f4: &G_CNN_CIFAR_DWN_C24_F4
    <<: *G_CNN_CIFAR_DWN_C24_F2
    model: 'g_cnn_dwn'
    dwn_group_types: [['cycle', 'cycle'], ['cycle', 'cycle'],['cycle', 'cycle']] 
    dwn_orders: [[24,24], [24,6], [6,6]] 
    subsampling_factors: [1, 4, 1] 

    out_order: 6
    save_model_name: 'g_cnn_dwn_C24_f4_cifar.pth'  