# Per-task hyper-parameter table, keyed by task name ("listops", "image",
# "pathfinder32", "retrieval", "text"). NOTE(review): the task set looks like
# the Long Range Arena benchmark — confirm against the data-loading code.
#
# Every task has the same five sections:
#   "dataset"           -- one integer per split ("train" / "dev" / "test");
#                          presumably the number of examples in that split.
#   "model"             -- architecture settings. Several tasks keep
#                          alternative "model" dicts commented out next to the
#                          active one; only the uncommented "model" is used.
#   "training"          -- optimisation schedule: batch size, learning rate,
#                          warmup, lr decay, weight decay, eval cadence, and
#                          train/eval step counts.
#   "gpu_memory"        -- one integer per attention variant. NOTE(review):
#                          presumably the per-device micro-batch size used to
#                          derive gradient-accumulation steps — confirm in the
#                          training script.
#   "extra_attn_config" -- extra options per attention variant: gradient
#                          checkpointing, random-projection size "rp_dim",
#                          feature "kernel_type", MIMO grid "mimo_M"/"mimo_N",
#                          and (mimoformer only) "MIMO_warmup", presumably in
#                          training steps — confirm. Its keys are meant to
#                          line up with the keys of "gpu_memory".
config = {
    "listops":{
        "dataset":{
            "train":96000,
            "dev":2000,
            "test":2000,
        },
        # Active "model": wide model (1 layer, 48 heads). Note that
        # num_head * head_dim = 48 * 64 = 3072, which differs from
        # embedding_dim (512). NOTE(review): presumably the attention projects
        # to num_head * head_dim internally — confirm in the attention module.
        "model":{ # wide model
            "learn_pos_emb":True,
            "tied_weights":False,
            "embedding_dim":512,
            "transformer_hidden_dim":2048,
            "head_dim":64,
            "num_head":48,
            "num_layers":1,
            "vocab_size":32,
            "max_seq_len":2000,
            "dropout_prob":0.1,
            "attention_dropout":0.1,
            "pooling_mode":"MEAN",
            "num_classes":10,
            "ortho_regularization":  0.0
        },
        # Inactive alternative: deep model (6 layers, 8 heads), kept for reference.
        # "model":{ # deep model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim":512,
        #     "transformer_hidden_dim":2048,
        #     "head_dim":64,
        #     "num_head":8,
        #     "num_layers":6,
        #     "vocab_size":32,
        #     "max_seq_len":2000,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes":10,
        #     "ortho_regularization": 0.0
        # },
        "training":{
            "mixed_precision": False,
            "batch_size":64,
            "lr_decay": "linear",
            "learning_rate":0.0001,
            "warmup": 1000,
            "weight_decay":0,
            "eval_frequency":50,
            "num_train_steps":20000,
            # NOTE(review): 62 steps * batch_size 64 = 3968 examples, which is
            # more than the 2000-example dev/test split. This is fine if the
            # eval loader cycles; otherwise confirm the intended value.
            "num_eval_steps":62,
        },
        "gpu_memory":{
            "softmax":16,
            "performer-256":64,
            "mimoformer": 16,
            "blockformer": 16
        },
        "extra_attn_config":{
            "softmax":{"attention_grad_checkpointing":True},
            "performer-256":{"attention_grad_checkpointing":False,
                             "rp_dim":256,
                             "kernel_type":"relu"},
            "mimoformer":{"attention_grad_checkpointing":False,
                          "rp_dim":256,
                          "kernel_type": "relu",
                          "mimo_M":4,
                          "mimo_N":4,
                          "MIMO_warmup":3000}, # att. (attention)
            "blockformer":{"attention_grad_checkpointing":False,
                           "rp_dim":256,
                           "kernel_type": "relu",
                           "mimo_M":2,
                           "mimo_N":2} # att.+MLP (attention + MLP)
        }
    },
    "image":{
        "dataset":{
            "train":45000,
            "dev":5000,
            "test":10000,
        },
        # Active "model": wide model (1 layer, 12 heads).
        "model":{ # wide model
            "learn_pos_emb":True,
            "tied_weights":False,
            "embedding_dim":64,
            "transformer_hidden_dim":128,
            "head_dim":64,
            "num_head":12,
            "num_layers":1,
            "vocab_size":512,
            "max_seq_len":1024,
            "dropout_prob":0.1,
            "attention_dropout":0.1,
            "pooling_mode":"MEAN",
            "num_classes": 10,
            "ortho_regularization": 0.0
        },
        # Inactive alternative: deep model (3 layers, 4 heads), kept for reference.
        # "model":{ # deep model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim":64,
        #     "transformer_hidden_dim":128,
        #     "head_dim":64,
        #     "num_head":4,
        #     "num_layers":3,
        #     "vocab_size":512,
        #     "max_seq_len":1024,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes": 10,
        #     "ortho_regularization":0.0
        # },
        "training":{
            "batch_size":256,
            "learning_rate":0.0001,
            "warmup":175,
            "lr_decay":"linear",
            "weight_decay":0,
            "eval_frequency":175,
            "num_train_steps":70000,
            "num_eval_steps":20
        },
        "gpu_memory":{
            "softmax":128,
            "performer-256": 128,
            "mimoformer":128,
            "blockformer": 128
        },
        "extra_attn_config":{
            "softmax":{"attention_grad_checkpointing":False},
            "performer-256":{"attention_grad_checkpointing":False,
                             "rp_dim":256,
                             "kernel_type":"relu"},
            "mimoformer":{"attention_grad_checkpointing":False,
                          "rp_dim":256,
                          "kernel_type": "relu",
                          "mimo_M":4,
                          "mimo_N":4,
                          "MIMO_warmup": 10000}, # att. (attention)
            "blockformer":{"attention_grad_checkpointing":False,
                           "rp_dim":256,
                           "kernel_type": "relu",
                           "mimo_M":2,
                           "mimo_N":2} # att.+MLP (attention + MLP)
        }
    },
    "pathfinder32":{
        "dataset":{
            "train":160000,
            "test":20000,
            "dev":20000
        },
        # Active "model": wide model (1 layer, 32 heads).
        "model":{ # Wide model
            "learn_pos_emb":True,
            "tied_weights":False,
            "embedding_dim":128,
            "transformer_dim":128,
            "transformer_hidden_dim":128,
            "head_dim":128,
            "num_head":32,
            "num_layers":1,
            "vocab_size":512,
            "max_seq_len":1024,
            "dropout_prob":0.1,
            "attention_dropout":0.1,
            "pooling_mode":"MEAN",
            "num_classes": 2,
            "ortho_regularization": 0.0
        },
        # Inactive alternative: deep model (4 layers, 8 heads), kept for reference.
        # "model":{ # deep model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim":128,
        #     "transformer_dim":128,
        #     "transformer_hidden_dim":128,
        #     "head_dim":128,
        #     "num_head":8,
        #     "num_layers":4,
        #     "vocab_size":512,
        #     "max_seq_len":1024,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes": 2,
        #     "ortho_regularization":0.0
        # },
        "training":{
            "batch_size":256,
            "learning_rate":0.0001,
            "warmup":312,
            "lr_decay":"linear",
            "weight_decay":0,
            "eval_frequency":312,
            "num_train_steps":124800,
            "num_eval_steps":312,
        },
        "gpu_memory":{
            "softmax":128,
            "performer-256":128,
            "blockformer": 128,
            "mimoformer": 128
        },
        "extra_attn_config":{
            "softmax":{"attention_grad_checkpointing":True},
            "performer-256":{"attention_grad_checkpointing":False,
                             "rp_dim":256,
                             "kernel_type":"relu"},
            "mimoformer":{"attention_grad_checkpointing":False,
                          "rp_dim":256,
                          "kernel_type": "relu",
                          "mimo_M":4,
                          "mimo_N":4,
                          "MIMO_warmup":20800}, # att. (attention)
            "blockformer":{"attention_grad_checkpointing":False,
                           "rp_dim":256,
                           "kernel_type": "relu",
                           "mimo_M":2,
                           "mimo_N":2} # att.+MLP (attention + MLP)
        }
    },
    "retrieval":{
        "dataset":{
            "train":147086,
            "dev":18090,
            "test":17437,
        },
        # Inactive alternative: wide model (1 layer, 16 heads), kept for reference.
        # "model":{ # wide model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim":128,
        #     "transformer_dim":128,
        #     "transformer_hidden_dim":512,
        #     "head_dim":32,
        #     "num_head":16,
        #     "num_layers":1,
        #     "vocab_size":512,
        #     "max_seq_len":4000,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes": 2,
        #     "ortho_regularization": 0.0 # 1e-4
        # },
        # Inactive alternative: deep model (4 layers, 4 heads), kept for reference.
        # "model":{ # deep model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim":128,
        #     "transformer_dim":128,
        #     "transformer_hidden_dim":512,
        #     "head_dim":32,
        #     "num_head":4,
        #     "num_layers":4,
        #     "vocab_size":512,
        #     "max_seq_len":4000,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes": 2,
        #     "ortho_regularization": 0.0
        # },
        # Active "model": a small model (2 layers, 2 heads, 32-dim embeddings).
        # This is NOT the "deep model" commented out above.
        "model":{ # small model
            "learn_pos_emb":True,
            "tied_weights":False,
            "embedding_dim":32,
            "transformer_dim":32,
            "transformer_hidden_dim":64,
            "head_dim":32,
            "num_head":2,
            "num_layers":2,
            "vocab_size":512,
            "max_seq_len":4000,
            "dropout_prob":0.1,
            "attention_dropout":0.1,
            "pooling_mode":"MEAN",
            "num_classes": 2,
            "ortho_regularization": 0.0
        },
        "training":{
            "batch_size":32,
            "learning_rate":0.0001,
            "warmup":800,
            "lr_decay": "linear",
            "weight_decay":0,
            "eval_frequency":300,
            "num_train_steps":60000,
            "num_eval_steps":565,
        },
        "gpu_memory":{
            "softmax":64,
            "performer-256":64,
            "blockformer":64,
            "mimoformer": 64
        },
        "extra_attn_config":{
            "softmax":{"attention_grad_checkpointing":True},
            "performer-256":{"attention_grad_checkpointing":False,
                             "rp_dim":256,
                             "kernel_type":"relu"},
            "mimoformer":{"attention_grad_checkpointing":False,
                          "rp_dim":256,
                          "kernel_type": "relu",
                          "mimo_M":4,
                          "mimo_N":4,
                          "MIMO_warmup":10000}, # att. (attention)
            "blockformer":{"attention_grad_checkpointing":False,
                           "rp_dim":256,
                           "kernel_type": "relu",
                           "mimo_M":4,
                           "mimo_N":4} # att.+MLP (attention + MLP)
        }
    },
    "text":{
        "dataset":{
            "train":25000,
            "dev":25000,
            "test":25000,
        },
        # Inactive alternative: wide model (1 layer, 48 heads), kept for reference.
        # "model":{ # wide model
        #     "learn_pos_emb":True,
        #     "tied_weights":False,
        #     "embedding_dim": 512,
        #     "transformer_dim":64,
        #     "transformer_hidden_dim": 2048,
        #     "head_dim": 64,
        #     "num_head": 48,
        #     "num_layers": 1,
        #     "vocab_size":512,
        #     "max_seq_len":4000,
        #     "dropout_prob":0.1,
        #     "attention_dropout":0,  # 0.1,
        #     "pooling_mode":"MEAN",
        #     "num_classes": 2,
        #     "ortho_regularization": 0.0
        # },
        # Active "model": deep model (6 layers, 8 heads).
        # NOTE(review): embedding_dim (512) and transformer_dim (64) disagree
        # here, unlike pathfinder32/retrieval where they are equal. Confirm
        # which key the model code actually reads.
        "model":{ # deep model
            "learn_pos_emb":True,
            "tied_weights":False,
            "embedding_dim": 512,
            "transformer_dim":64,
            "transformer_hidden_dim": 2048,
            "head_dim": 64,
            "num_head": 8,
            "num_layers": 6,
            "vocab_size":512,
            "max_seq_len":4000,
            "dropout_prob":0.1,
            "attention_dropout":0,  # disabled for this task; the commented-out wide model records an earlier value of 0.1
            "pooling_mode":"MEAN",
            "num_classes": 2,
            "ortho_regularization": 0.0
        },
        "training":{
            "mixed_precision": False,
            "batch_size":32,
            "learning_rate":0.0001,
            "warmup":8000,
            "lr_decay":"linear",
            "weight_decay":0,
            "eval_frequency":500,
            "num_train_steps": 40000,
            "num_eval_steps": 781,
        },
        # NOTE(review): this task lists "nystrom-*", "linformer-256",
        # "reformer-2" and "linear", but none of them has an entry in
        # "extra_attn_config" below. Selecting one of them will fail if the
        # runner indexes extra_attn_config directly — confirm.
        "gpu_memory":{
            "softmax":32,
            "mimoformer":32,
            "blockformer":32,
            "nystrom-32":32,
            "nystrom-64":32,
            "nystrom-128":32,
            "nystrom-256":32,
            "linformer-256":32,
            "reformer-2":32,
            "performer-256":32,
            "linear":32,
        },
        "extra_attn_config":{
            "softmax":{"attention_grad_checkpointing":True},
            "mimoformer":{"attention_grad_checkpointing":False,
                          "rp_dim":256,
                          "kernel_type": "relu",
                          "mimo_M":4,
                          "mimo_N":4,
                          "MIMO_warmup":6500}, # att. (attention)
            "performer-256":{"attention_grad_checkpointing":False,
                             "rp_dim":256,
                             "kernel_type":"relu"},
            "blockformer":{"attention_grad_checkpointing":False,
                           "rp_dim":256,
                           "kernel_type": "relu",
                           "mimo_M":2,
                           "mimo_N":2} # att.+MLP (attention + MLP)
        }
    }
}

# The Pathfinder curvature / contour-length variants reuse the base
# "pathfinder32" hyper-parameters. Each variant gets its own deep copy.
# Plain assignment would bind all four keys to one shared dict, so an in-place
# edit of one task's settings (e.g. updating its "model" dict) would silently
# change the other three as well.
import copy

config["pathfinder32-curv_baseline"] = copy.deepcopy(config["pathfinder32"])
config["pathfinder32-curv_contour_length_9"] = copy.deepcopy(config["pathfinder32"])
config["pathfinder32-curv_contour_length_14"] = copy.deepcopy(config["pathfinder32"])