# measured in train/epochs - unfortunately this is different across tasks

# Evaluation metric key associated with each task name.
# Most tasks map to accuracy; "cola" and "stsb" map to correlation metrics.
TASK_METRICS = {
    "cola": "eval/matthews_correlation",
    "mnli": "eval/accuracy",
    "sst2": "eval/accuracy",
    "stsb": "eval/pearson",
    "qnli": "eval/accuracy",
    "qqp": "eval/accuracy",
    "rte": "eval/accuracy",
    "mrpc": "eval/accuracy",
}

# Columns used both for grouping and for sorting unless an entry overrides them.
_GLUE_KEY_COLS = ("experiment_name", "task_name", "lora_r", "learning_rate")


def _glue_config(experiment_names, key_cols=_GLUE_KEY_COLS):
    """Build one GLUE configuration entry.

    Every entry shares the same table settings; only the experiment names and,
    for a few entries, the grouping/sorting columns differ.

    Args:
        experiment_names: Experiment names selected by this configuration.
        key_cols: Columns placed in both "groupby" and "sortby_col".

    Returns:
        A new dict with the keys "experiment_name" and "table_kwargs". All
        contained lists are fresh objects, so mutating one entry never affects
        another entry.
    """
    return {
        "experiment_name": list(experiment_names),
        "table_kwargs": {
            "step": "max",
            "metric_col": ["eval/accuracy"],
            "groupby": list(key_cols),
            "sortby_col": list(key_cols),
            "add_avg": False,
            "transpose": False,
            "drop_idx": True,
            "round": 3,
        },
    }


# Named configurations: which experiments to select and how to tabulate them.
# NOTE(review): "metric_col" is "eval/accuracy" for every entry, although
# TASK_METRICS lists different metrics for "cola" and "stsb" - confirm intended.
GLUE = {
    "lora_random_vs_pca": _glue_config(["lora_lr_grid_randinit"]),
    # A narrower selection containing only "vera_lr_grid_pcainit_scaled_trainB"
    # was previously kept here as a commented-out alternative.
    "vera_random_vs_pca": _glue_config([
        "vera_lr_grid",
        "vera_lr_grid_pcainit",
        "vera_lr_grid_pcainit_scaled",
        "vera_lr_grid_pcainit_scaled_trainB",
        "vera_lr_grid_pcainit_trainB",
    ]),
    "dora_random_vs_pca": _glue_config([
        "dora_init_rand",
        "dora_init_pca",
        "dora_init_pca_scaled",
        "dora_lr_grid_randinit",
        "dora_lr_grid_pcainit",
        "dora_lr_grid_pcainit_scaled",
    ]),
    "lora_pca++": _glue_config([
        "lora_lr_grid_pcainit_adapt_redist",
        "lora_grid_pca++",
        "lora_pca++_scaled_from_scratch",
    ]),
    "lora_pca_prebn": _glue_config([
        "lora_prebn_lr_grid_pcainit_v2",
        "lora_prebn_lr_grid_pcainit",
        "lora_postln_lr_grid_pcainit",
    ]),
    "lora_grid_v2": _glue_config(["lora_grid_v2", "lora_large_bs_small_lr"]),
    "lora_grid_pcainit_v2": _glue_config(["lora_grid_pcainit_v2"]),
    # Additionally grouped/sorted by the "loram_lr_ratio" column.
    "loram_init_pca": _glue_config(
        ["loram_init_pca"],
        key_cols=_GLUE_KEY_COLS + ("loram_lr_ratio",),
    ),
    "roberta_large_random": _glue_config(["roberta_large_randinit"]),
    "roberta_large_pca": _glue_config(["roberta_large_pcainit"]),
    "roberta_large_pcainit_redistribute": _glue_config(["roberta_large_pcainit_redistribute"]),
    "roberta_large_pcainit_scaled": _glue_config(["roberta_large_pcainit_scaled"]),
    "roberta_base_grid_pca++_raw": _glue_config(["roberta_base_grid_pca++_raw"]),
    "roberta_base_grid_pca++_ratio": _glue_config(["roberta_base_grid_pca++_ratio"]),
    "lr_grid_pcainit_whitened": _glue_config(["lr_grid_pcainit_whitened"]),
    "roberta_base_redistributed": _glue_config([
        "roberta_base_grid_pca++_sum",
        "roberta_base_grid_pca++_ratio",
        "roberta_base_grid_pca++_max",
        "roberta_base_grid_pca++_raw",
    ]),
    "roberta_base_redistributed_init_cls": _glue_config(["roberta_base_grid_pca++_ratio_init_cls"]),
    "roberta_large_redistributed": _glue_config([
        "roberta_large_grid_pca++_max",
        "roberta_large_grid_pca++_ratio",
        "roberta_large_grid_pca++_raw",
        "roberta_large_grid_pca++_sum",
    ]),
    "roberta_large_pca_whitened": _glue_config(["roberta_large_grid_pcainit_whitened"]),
    "roberta_large_pca_whitened_redistributed": _glue_config(["roberta_large_grid_pca++_sum_whitened"]),
    # Grouped/sorted by "lora_alpha" in place of "learning_rate".
    "alpha_grid_roberta_base_random": _glue_config(
        ["alpha_grid_randominit"],
        key_cols=("experiment_name", "task_name", "lora_r", "lora_alpha"),
    ),
    "debertav3_randinit": _glue_config(["debertav3_randinit"]),
    "roberta_large_dora_randinit": _glue_config(["roberta_large_dora_randinit"]),
    "roberta_large_eva": _glue_config(["roberta_large_grid_pca++_ratio"]),
    "debertav3_dora_randinit": _glue_config(["debertav3_dora_randinit"]),
}

# Configurations kept apart from the GLUE group; "test" is an empty entry
# that defines no experiment names and no table settings.
NEXT_EXP = dict(test={})

# Combined lookup table of every configuration. GLUE entries come first; a
# NEXT_EXP entry with the same name replaces the GLUE one.
ALL = dict(GLUE)
ALL.update(NEXT_EXP)


def load_exp_config(exp_name):
    """Return the settings registered under ``exp_name`` in ``ALL``.

    Args:
        exp_name: Key of the wanted configuration in ``ALL``.

    Returns:
        A tuple ``(exp_names, table_kwargs, table_filters)`` where
        ``exp_names`` is the entry's "experiment_name" value (an empty list
        when the entry has none), ``table_kwargs`` is its "table_kwargs" dict
        (an empty dict when absent) and ``table_filters`` is its
        "table_filters" value (``None`` when absent).

    Raises:
        AssertionError: If ``exp_name`` is not a key of ``ALL``.
    """
    if exp_name not in ALL:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept so that
        # existing callers see the same error as before.
        raise AssertionError("Unknown experiment configuration: %r" % (exp_name,))

    config = ALL[exp_name]
    # Experiment names are lists everywhere else, so default to an empty list.
    exp_names = config.get("experiment_name", [])
    table_kwargs = config.get("table_kwargs", {})
    table_filters = config.get("table_filters")
    return exp_names, table_kwargs, table_filters
