import argparse
import os
import subprocess
import sys

from metric import Metric, str_to_metric
from pile_client import get_address_path
import src.pile_server as server
from src.launch_utils import generate_base_command


def _sbatch_argv(args, wrapped_cmd):
    """Build the ``sbatch`` argument vector that submits ``wrapped_cmd``.

    Args:
        args: Parsed command-line namespace; reads ``num_hours``,
            ``mem_per_cpu``, ``num_cpus``, ``num_gpus`` and ``gpumem``.
        wrapped_cmd: Command string handed to ``sbatch`` via ``--wrap``.

    Returns:
        A list of strings suitable for ``subprocess.run`` without a shell.
    """
    argv = [
        "sbatch",
        f"--time={args.num_hours}:00:00",
        f"--mem-per-cpu={args.mem_per_cpu}G",
        "-n",
        str(args.num_cpus),
    ]
    # GPU resources are only requested when at least one GPU is asked for.
    if args.num_gpus != 0:
        argv.append(f"--gpus={args.num_gpus}")
        argv.append(f"--gres=gpumem:{args.gpumem}g")
    # Passed as a single argv entry, so the wrapped command needs no shell
    # quoting and cannot be broken by quotes or metacharacters it contains.
    argv.append(f"--wrap={wrapped_cmd}")
    return argv


def main(args):
    """Submit one ``sbatch`` job per data file, each wrapping a server command.

    The address file is truncated first, then all commands are built, and
    only afterwards are they submitted, so an error while building any
    command prevents every submission. A submission that fails is reported
    on stderr and the remaining submissions are still attempted.

    Args:
        args: Parsed command-line namespace; reads ``metric``, ``prefix``,
            ``normalized``, ``num``, ``num_servers``, ``optimized`` and the
            resource options consumed by ``_sbatch_argv``.
    """
    metric = str_to_metric(args.metric)
    prefix = f"{args.prefix}_" if args.prefix != "" else ""
    address_path = get_address_path(
        f"servers/{prefix}addresses.txt", metric, args.normalized
    )

    # Clear addresses of previous servers
    with open(address_path, "w"):
        pass

    submissions = []
    for i in range(args.num):
        flags = {
            "address_path": address_path,
            "data_file": f"{i:02d}.jsonl",
            "num_servers": args.num_servers,
            "normalized": args.normalized,
            "metric": metric.value,
            "optimized": args.optimized,
        }
        # NOTE(review): assumes generate_base_command returns the command as
        # a single string -- confirm against src.launch_utils.
        base_cmd = generate_base_command(server, flags=flags)
        submissions.append(_sbatch_argv(args, base_cmd))

    for argv in submissions:
        try:
            result = subprocess.run(argv, check=False)
        except OSError as err:
            # e.g. sbatch is not installed / not on PATH.
            print(f"Could not run sbatch: {err}", file=sys.stderr)
            continue
        if result.returncode != 0:
            print(
                f"sbatch exited with status {result.returncode}: {argv}",
                file=sys.stderr,
            )


if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they are registered with the parser.
    option_specs = [
        ("--num", {"type": int, "default": 30, "help": "Number of files to serve"}),
        ("--num-servers", {"type": int, "default": 6}),
        ("--normalized", {"action": "store_true"}),
        ("--metric", {"type": str, "default": Metric.ABSIP.value}),
        ("--prefix", {"type": str, "default": ""}),
        ("--optimized", {"action": "store_true", "help": "Enable RAM optimizations"}),
        ("--num-cpus", {"type": int, "default": 8}),
        ("--num-hours", {"type": int, "default": 24}),
        ("--mem-per-cpu", {"type": int, "default": 20}),
        ("--num-gpus", {"type": int, "default": 0}),
        ("--gpumem", {"type": int, "default": 10}),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())
