import argparse
import itertools

# from launchers.utils import CHUNKED_DATASETS
from src.launch_utils import generate_base_command, generate_run_commands
import src.eval_tttlm as experiment

# Experiment name forwarded to every launched run as the ``name`` flag.
NAME = "Ablation_Study_Lambda"

# Model under evaluation. Alternatives: "gpt2-large", "google/gemma-2-2b",
# "microsoft/Phi-3-mini-4k-instruct". A model also needs an entry in
# RUN_CONFIG and must be supported by Config, otherwise launching fails.
MODEL = "gpt2"

# Per-model cluster resources; used as the CLI defaults of this launcher.
# NOTE(review): units of "mem"/"gpumem" are whatever generate_run_commands
# expects (presumably GB) — confirm.
RUN_CONFIG = {
    "gpt2": {
        "num-cpus": 4,
        "num-gpus": 1,
        "num-hours": 24,
        "mem": 20,
        "gpumem": 20,
    }
}

# Pile subsets grouped into chunks. Each chunk is passed to one run as a
# comma-joined ``dataset`` flag and gets its own server address file.
CHUNKED_DATASETS = [
    [
        "pile_arxiv",
        "pile_enron",
        "pile_europarl",
        "pile_freelaw",
        "pile_philpapers",
        "pile_pubmed-abstracts",
        "pile_uspto",
    ],
    [
        "pile_github",
        "pile_dm-mathematics",
        "pile_nih-exporter",
        "pile_hackernews",
    ],
    [
        "pile_pile-cc",
        "pile_pubmed-central",
    ],
    [
        "pile_wikipedia",
    ],
    [
        "pile_stackexchange",
    ],
]

# Forwarded unchanged to every run as the ``start_index`` flag.
START_INDEX = 0

# Sweep grid: one command is generated per combination of these values.
applicable_configs = {
    "k": [200],
    # Regularization strengths for the lambda ablation (log-spaced).
    "lambda": [
        1e-12,
        1e-8,
        1e-4,
        1e-2,
        1e-1,
        1.0,
        1e1,
        1e2,
        1e4,
    ],
    "algs": ["VTL"],
    # Indices into CHUNKED_DATASETS; derived so the sweep always covers
    # every chunk, even if chunks are added or removed above.
    "chunks": list(range(len(CHUNKED_DATASETS))),
}


class Config:
    """Per-model sequence-length and optimizer settings used by the sweep.

    Attributes:
        max_length: maximum sequence length passed to the eval run.
        stride: stride passed to the eval run.
        learning_rate: learning rate passed to the eval run.

    Raises:
        AttributeError: if ``model`` has no entry in ``_SETTINGS``.
    """

    # model name -> (max_length, stride, learning_rate)
    _SETTINGS = {
        "gpt2": (1024, 1024, 5e-5),
        "gpt2-large": (1024, 1024, 5e-5),
        "gptneo": (2048, 2048, 5e-6),
        # stride may be raised to 4096 (== max_length) if the GPU has
        # enough memory.
        "microsoft/Phi-3-mini-4k-instruct": (4096, 1024, 5e-5),
    }

    def __init__(self, model):
        try:
            max_length, stride, learning_rate = self._SETTINGS[model]
        except KeyError:
            # AttributeError is kept for backward compatibility with callers;
            # the message now names the offending model.
            raise AttributeError(f"Unsupported model: {model!r}") from None
        self.max_length = max_length
        self.stride = stride
        self.learning_rate = learning_rate


def main(args):
    """Build one eval command per sweep point and submit them as jobs.

    Sweeps the Cartesian product of ``applicable_configs`` chunks, k values,
    lambda values and acquisition algorithms (in that nesting order), builds
    a command for each via ``generate_base_command``, and hands the full list
    to ``generate_run_commands``.

    Args:
        args: parsed CLI namespace; reads ``num_cpus``, ``num_gpus``,
            ``num_hours``, ``mem``, ``gpumem`` and ``begin``.
    """
    # Model settings depend only on MODEL, so resolve them once up front.
    config = Config(MODEL)
    command_list = []
    sweep = itertools.product(
        applicable_configs["chunks"],
        applicable_configs["k"],
        applicable_configs["lambda"],
        applicable_configs["algs"],
    )
    for chunk, k, llambda, alg in sweep:
        flags = {
            "name": NAME,
            "seed": 1216,
            "fraction_of_test_set": 0.01,
            "num_neighbors": 50,
            "k": k,
            "gradient_steps": 1,
            "batch_size": 1,
            # Spelled "llambda" presumably because `lambda` is a Python keyword.
            "llambda": llambda,
            "acquisition_function": alg,
            "dataset": ",".join(CHUNKED_DATASETS[chunk]),
            "model": MODEL,
            "tokenizer": MODEL,
            "embedding_model_checkpoint": "models/roberta-large-pile-lr2e-5-bs16-8gpu/checkpoint-1700000",
            "max_length": config.max_length,
            "stride": config.stride,
            "learning_rate": config.learning_rate,
            # One address file per dataset chunk.
            "address_path": f"servers/lambda{chunk}_addresses.txt",
            "results_dir": "results/",
            "metric": "ABSIP",
            "normalized": True,
            "start_index": START_INDEX,
        }
        command_list.append(generate_base_command(experiment, flags=flags))

    generate_run_commands(
        command_list,
        num_cpus=args.num_cpus,
        num_gpus=args.num_gpus,
        mode="euler",
        num_hours=args.num_hours,
        # NOTE(review): keyword spelling must match generate_run_commands'
        # signature (likely "prompt" misspelled there) — verify before renaming.
        promt=True,
        mem=args.mem,
        gpumem=args.gpumem,
        begin=args.begin,
    )


if __name__ == "__main__":
    # Integer resource options default to this model's RUN_CONFIG entry.
    defaults = RUN_CONFIG[MODEL]
    cli = argparse.ArgumentParser()
    for option in ("num-cpus", "num-gpus", "num-hours", "mem", "gpumem"):
        cli.add_argument(f"--{option}", type=int, default=defaults[option])
    # Start time for the submitted jobs; forwarded to generate_run_commands.
    cli.add_argument("--begin", type=str, default="now")
    main(cli.parse_args())
