import time

import wandb

# Per-dataset sweep settings, keyed by dataset name.
#   metric   - name reported as the sweep's optimisation metric (see generate_config).
#   epochs   - optional; when present, format_command appends `--epochs <epochs>`
#   patience - and `--num-patience-steps <patience>` to the training command.
# Datasets without an "epochs" entry get neither flag appended.
_DATASET_TO_PARAMETERS = {
    "cola": dict(metric="matthews_correlation"),
    "mrpc": dict(metric="accuracy_f1"),
    "rte": dict(metric="accuracy", epochs=200, patience=20),
    "sst2": dict(metric="accuracy"),
    "mnli": dict(metric="accuracy", epochs=5, patience=2),
    "qnli": dict(metric="accuracy", epochs=10, patience=2),
    "qqp": dict(metric="accuracy_f1", epochs=5, patience=2),
    "stsb": dict(metric="pearson_spearman"),
    # The datasets below, except "boolq", carry explicit epochs/patience.
    "boolq": dict(metric="accuracy"),
    "cb": dict(metric="accuracy_f1", epochs=500, patience=100),
    "copa": dict(metric="accuracy", epochs=500, patience=100),
    "multirc": dict(metric="accuracy_f1", epochs=10, patience=4),
    "record": dict(metric="accuracy_f1", epochs=1, patience=1),
    "wic": dict(metric="accuracy", epochs=500, patience=20),
    "wsc": dict(metric="accuracy", epochs=500, patience=100),
}

# Hyperparameter grids that generate_config combines with the per-training-type
# entries of _SPECIFIC_PARAMETERS. Learning rates are kept as strings.

# Default grid: used when the model name does not contain "deberta" and the
# dataset is not "rte".
_BASE_SHARED_PARAMETERS = {
    "batch-size": dict(distribution="categorical", values=[16, 64]),
    "lr": dict(values=["1e-4", "5e-4", "5e-3", "1e-3"]),
    "seed": dict(distribution="categorical", values=[42, 43, 44, 45, 46]),
}

# Wider grid: used for the "rte" dataset (non-deberta models).
_LARGE_SHARED_PARAMETERS = {
    "batch-size": dict(distribution="categorical", values=[16, 32, 64, 128]),
    "lr": dict(
        values=["1e-5", "5e-5", "1e-4", "5e-4", "5e-3", "1e-3", "2e-3", "1e-2"],
    ),
    "seed": dict(distribution="categorical", values=[42, 43, 44, 45, 46]),
}

# Grid used whenever the model name contains "deberta"; single seed only.
_DEBERTA_SHARED_PARAMETERS = {
    "batch-size": dict(distribution="categorical", values=[8, 16]),
    "lr": dict(
        values=["1e-5", "5e-5", "1e-4", "3e-4", "5e-4", "1e-3", "2e-3", "5e-3"],
    ),
    "seed": dict(distribution="categorical", values=[42]),
}

# Extra sweep parameters per training type. generate_config merges these on top
# of the shared grid, so a key repeated here (e.g. "lr" for "full") replaces the
# shared entry, and an empty dict (e.g. "bitfit") adds nothing.
_SPECIFIC_PARAMETERS = {
    # Training types swept over "prompt-rank".
    "zero": {
        "prompt-rank": dict(distribution="categorical", values=[5, 10, 25, 30, 50]),
    },
    "zero_fc": {
        "prompt-rank": dict(distribution="categorical", values=[32, 64, 128, 256, 512]),
    },
    "adapter": {
        "prompt-rank": dict(distribution="categorical", values=[16, 32, 64, 128, 256]),
    },
    "lora": {
        "prompt-rank": dict(distribution="categorical", values=[2, 4, 16, 32, 64]),
    },
    # Training types swept over "prompt-length".
    "prefix": {
        "prompt-length": dict(distribution="categorical", values=[5, 10, 20, 50, 100]),
    },
    "prompt": {
        "prompt-length": dict(distribution="categorical", values=[5, 10, 20, 50, 100]),
    },
    # No training-type-specific parameters.
    "bitfit": {},
    # Replaces the shared learning-rate grid.
    "full": {
        "lr": dict(values=["1e-5", "5e-5", "1e-4", "5e-4", "5e-3"]),
    },
}


def format_command(dataset, training_type, model):
    """Assemble the training command list stored in a sweep config.

    The list starts with the literal ``${env}`` / ``${args}`` placeholders
    around ``python3 tools/train.py`` (presumably expanded by the W&B agent
    when it launches a run -- confirm against the W&B docs), followed by the
    fixed flags and any dataset-specific flags.

    Args:
        dataset: Key into ``_DATASET_TO_PARAMETERS``; also passed as ``--name``.
        training_type: Value passed as ``--training-type``.
        model: Value passed as ``--pretrain``.

    Returns:
        A list mixing strings and numbers, one element per command token.

    Raises:
        KeyError: If ``dataset`` is not a key of ``_DATASET_TO_PARAMETERS``.
    """
    settings = _DATASET_TO_PARAMETERS[dataset]

    command = ["${env}", "python3", "tools/train.py", "${args}"]
    command += ["--name", dataset]
    command += ["--pretrain", model]
    command += ["--training-type", training_type]
    command += ["--fp16", "--hidden-dropout", 0.1]

    # Epoch count and patience are only overridden for datasets that define them.
    if "epochs" in settings:
        command += ["--epochs", settings["epochs"]]
        command += ["--num-patience-steps", settings["patience"]]

    # These datasets are run with an explicit sequence length.
    if dataset in ("boolq", "cb", "multirc", "copa", "record", "wsc", "wic"):
        command += ["--seq-len", 384]

    return command


def generate_config(dataset, training_type, model):
    """Build the grid-search sweep configuration for one combination.

    Args:
        dataset: Key into ``_DATASET_TO_PARAMETERS``.
        training_type: Key into ``_SPECIFIC_PARAMETERS``.
        model: Model name; used in the sweep name and the training command.

    Returns:
        A dict with the keys ``command``, ``method``, ``name``, ``project``,
        ``metric`` and ``parameters``.

    Raises:
        KeyError: If ``dataset`` or ``training_type`` is unknown.
    """
    # Built first so an unknown dataset fails before anything else is looked up.
    command = format_command(dataset, training_type, model)

    # Pick the shared grid: deberta models have their own, "rte" gets the
    # wider one, everything else the base one.
    # NOTE(review): the "large" grid is selected by dataset, not by model
    # size -- confirm this is intended.
    if "deberta" in model:
        shared = _DEBERTA_SHARED_PARAMETERS
    elif dataset == "rte":
        shared = _LARGE_SHARED_PARAMETERS
    else:
        shared = _BASE_SHARED_PARAMETERS

    # Training-type-specific entries win over shared ones with the same key.
    parameters = dict(shared)
    parameters.update(_SPECIFIC_PARAMETERS[training_type])

    metric = {
        "name": _DATASET_TO_PARAMETERS[dataset]["metric"],
        "goal": "maximize",
    }

    return {
        "command": command,
        "method": "grid",
        "name": f"{model}-{dataset}-{training_type}",
        "project": "aot",
        "metric": metric,
        "parameters": parameters,
    }


def main():
    """Create the selected W&B sweeps and print the commands to run their agents.

    One sweep is registered per (model, dataset, training type) combination
    that passes the filters below, in model -> dataset -> training-type order
    with datasets and training types sorted by name. The value returned by
    each ``wandb.sweep`` call is used as the sweep id in the printed
    ``wandb agent <entity>/<project>/<sweep id>`` commands.
    """
    # The sweeps are created under this entity/project, so the printed agent
    # commands must point at the same pair; previously they referenced a
    # different entity/project and could not resolve to the sweeps made here.
    entity = "aot"
    project = "aot"

    models = ["roberta-base", "roberta-large"]
    # Only these training types are swept by this script.
    selected_training_types = ("bitfit",)
    # Datasets skipped by this script. "rte" is not listed (it was commented
    # out of the original exclusion list), so it is swept.
    skipped_datasets = (
        "boolq",
        "cb",
        "copa",
        "multirc",
        "wic",
        "wsc",
        "record",
    )

    sweep_ids = []
    for model in models:
        for dataset in sorted(_DATASET_TO_PARAMETERS):
            if dataset in skipped_datasets:
                continue
            for training_type in sorted(_SPECIFIC_PARAMETERS):
                if training_type not in selected_training_types:
                    continue
                config = generate_config(dataset, training_type, model)
                sweep_ids.append(wandb.sweep(config, project=project, entity=entity))
                time.sleep(2)  # to make sweeps appear in wandb in right order

    # Join with "; \<newline>" so the output can be pasted into a shell as one
    # continued command line. The backslash must be the last character on the
    # line: a space after it would escape the space instead of the newline.
    print(
        " ; \\\n".join(
            f"wandb agent {entity}/{project}/{sweep_id}" for sweep_id in sweep_ids
        )
    )


if __name__ == "__main__":
    main()
