import argparse

# from launchers.utils import CHUNKED_DATASETS
from src.launch_utils import generate_base_command, generate_run_commands
import src.eval_tttlm as experiment

# Experiment name; forwarded to every generated command as the ``name`` flag.
NAME = "Baselines"

# Default cluster resources per job. Each key doubles as the name of the
# matching ``--<key>`` command-line option in the ``__main__`` section.
RUN_CONFIG = {
    "num-cpus": 4,
    "num-gpus": 1,
    "num-hours": 24,
    "mem": 20,
    "gpumem": 24,
}

# Datasets evaluated by each command; joined with "," into one ``dataset`` flag.
SELECTED_DATASETS = [
    "pile_arxiv",
    "pile_enron",
    "pile_github",
    "pile_nih-exporter",
]

# Forwarded unchanged as the ``start_index`` flag.
START_INDEX = 0

# Sweep grid: one command is generated per (model, algorithm) pair.
# Commented-out entries are alternative models that can be toggled back on.
applicable_configs = dict(
    algs=[
        "VTL",
        "NearestNeighbour",
        "ONN",
        "UncertaintySampling",
    ],
    models=[
        # "gpt2",
        # "gpt2-large",
        # "google/gemma-2-2b",
        "microsoft/Phi-3-mini-4k-instruct",
        # "microsoft/Phi-3.5-mini-instruct",
    ],
)


class Config:
    """Per-model context-length, stride and learning-rate settings.

    Attributes:
        max_length: value used for the ``max_length`` flag of a launched run.
        stride: value used for the ``stride`` flag.
        learning_rate: value used for the ``learning_rate`` flag.

    Raises:
        AttributeError: if ``model`` is not one of the supported names. (This
            type is kept for backward compatibility with existing callers.)
    """

    # model name -> (max_length, stride, learning_rate)
    _SETTINGS = {
        "gpt2": (1024, 1024, 5e-5),
        "gpt2-large": (1024, 1024, 5e-5),
        "gptneo": (2048, 2048, 5e-6),
        "microsoft/Phi-3-mini-4k-instruct": (4096, 4096, 5e-5),
        "microsoft/Phi-3.5-mini-instruct": (4096, 4096, 5e-5),
        "google/gemma-2-2b": (1024, 1024, 5e-5),
    }

    def __init__(self, model):
        try:
            settings = self._SETTINGS[model]
        except (KeyError, TypeError):
            # TypeError covers unhashable input, which the original
            # list-membership checks also rejected as unsupported.
            raise AttributeError(f"Unsupported model: {model!r}") from None
        self.max_length, self.stride, self.learning_rate = settings


def main(args):
    """Build one experiment command per (model, algorithm) pair and submit them.

    Each command evaluates every dataset in ``SELECTED_DATASETS``, joined with
    "," into a single ``dataset`` flag. Per-model hyper-parameters come from
    ``Config``. The collected commands are handed to ``generate_run_commands``
    with mode ``"euler"``.

    Args:
        args: parsed command-line namespace; reads ``num_cpus``, ``num_gpus``,
            ``num_hours``, ``mem``, ``gpumem`` and ``begin``.

    Raises:
        AttributeError: propagated from ``Config`` for an unsupported model.
    """
    command_list = []
    for chunk, model in enumerate(applicable_configs["models"]):
        # Settings depend only on the model, so resolve them once per model
        # rather than once per algorithm (this also fails fast on bad models).
        config = Config(model)
        for alg in applicable_configs["algs"]:
            # Key order is kept stable in case the command builder relies on it.
            flags = {
                "name": NAME,
                "seed": 1216,
                "fraction_of_test_set": 0.01,
                "absolute_test_set_size": 10,
                "num_neighbors": 50,
                "k": 200,
                "gradient_steps": 1,
                "batch_size": 1,
                "llambda": 0.1,
                "acquisition_function": alg,
                "dataset": ",".join(SELECTED_DATASETS),
                "model": model,
                "tokenizer": model,
                "embedding_model_checkpoint": "models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000",
                "max_length": config.max_length,
                "stride": config.stride,
                "learning_rate": config.learning_rate,
                # Indexed by the model's position in the sweep; presumably one
                # address file per model — confirm against the server setup.
                "address_path": f"servers/models{chunk}_addresses.txt",
                "results_dir": "results/",
                "metric": "ABSIP",
                "normalized": True,
                "start_index": START_INDEX,
            }
            command_list.append(generate_base_command(experiment, flags=flags))

    generate_run_commands(
        command_list,
        num_cpus=args.num_cpus,
        num_gpus=args.num_gpus,
        mode="euler",
        num_hours=args.num_hours,
        # NOTE(review): ``promt`` (sic) is kept as-is; presumably it matches the
        # parameter name of generate_run_commands — verify before renaming.
        promt=True,
        mem=args.mem,
        gpumem=args.gpumem,
        begin=args.begin,
    )


if __name__ == "__main__":
    # Resource options are integers whose defaults come from RUN_CONFIG; the
    # dictionary key is the flag name without its leading "--".
    cli = argparse.ArgumentParser()
    for option in ("--num-cpus", "--num-gpus", "--num-hours", "--mem", "--gpumem"):
        cli.add_argument(option, type=int, default=RUN_CONFIG[option[2:]])
    cli.add_argument("--begin", type=str, default="now")
    cli_args = cli.parse_args()
    main(cli_args)
