import argparse
import copy
import yaml

from lkb.benchmark import BenchmarkConfig, run_benchmark


def parse_list(s: str, typ=float):
    """Convert a comma-separated string into a list of ``typ`` values.

    Each item is whitespace-stripped before conversion. Empty items (from
    blank input, doubled commas, or a trailing comma) are skipped. Any error
    raised by ``typ`` on a bad item propagates to the caller.
    """
    values = []
    for raw in s.split(","):
        item = raw.strip()
        if item:
            values.append(typ(item))
    return values


def parse_algos(s: str):
    """Return the non-empty, whitespace-stripped names from a comma-separated string."""
    stripped = (name.strip() for name in s.split(","))
    return [name for name in stripped if name]


def main() -> None:
    """Run one-factor ablation sweeps over ``eta`` and/or ``n_users``.

    Loads a base ``BenchmarkConfig`` from a YAML file and optionally overrides
    its ``algos`` list. It then calls ``run_benchmark`` once per swept value,
    each time on an independent deep copy of the base config. The eta sweep
    (if requested) runs before the n_users sweep. Experiment names are
    ``{suffix}_eta_{value:g}`` and ``{suffix}_n_{value:d}``.

    All inputs are validated before any benchmark runs. The script exits via
    ``argparse`` error (status 2) in these cases:

    * a sweep value is malformed;
    * no sweep values were given;
    * ``--algos`` contains no names;
    * the base YAML is not a mapping.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_config", required=True, help="YAML base config (e.g., medium Kernel-A)")
    ap.add_argument("--vary_eta", type=str, default=None, help="Comma list for eta, e.g., '0,0.1,1,5,10'")
    ap.add_argument("--vary_n", type=str, default=None, help="Comma list for n_users, e.g., '20,50,100,200'")

    # Behavior: if --algos is omitted, keep whatever is in the base YAML.
    # If you want the benchmark's full default algorithm list, pass --use_default_algos.
    ap.add_argument(
        "--algos",
        type=str,
        default=None,
        help=("Comma-separated algorithms to include. "
              "If omitted, the algorithms from the base config are preserved."),
    )
    ap.add_argument(
        "--use_default_algos",
        action="store_true",
        help=("Ignore 'algos' from base config and use the benchmark's full default list "
              "(only if --algos is not provided)."),
    )

    ap.add_argument("--suffix", type=str, default="abl", help="Prefix for experiment names")
    args = ap.parse_args()

    # Parse every sweep up front. Otherwise a typo in --vary_n would only be
    # discovered after the entire (possibly long) eta sweep had finished.
    try:
        eta_values = parse_list(args.vary_eta, typ=float) if args.vary_eta else []
        n_values = parse_list(args.vary_n, typ=int) if args.vary_n else []
    except ValueError as err:
        ap.error(f"invalid sweep value: {err}")

    # Without this check the script would exit successfully having run nothing.
    if not eta_values and not n_values:
        ap.error("nothing to sweep: pass a non-empty --vary_eta and/or --vary_n")

    with open(args.base_config, "r", encoding="utf-8") as f:
        base = yaml.safe_load(f)

    # safe_load returns None for an empty file, or a list/scalar for non-mapping
    # YAML; BenchmarkConfig(**base) needs a mapping of keyword arguments.
    if not isinstance(base, dict):
        ap.error(
            f"base config {args.base_config!r} must be a YAML mapping, "
            f"got {type(base).__name__}"
        )

    # Build the base BenchmarkConfig from YAML
    cfg0 = BenchmarkConfig(**base)

    # --algos takes precedence; --use_default_algos only applies when --algos is absent.
    if args.algos is not None:
        algos = parse_algos(args.algos)
        if not algos:
            ap.error("--algos was given but contains no algorithm names")
        cfg0.algos = algos
    elif args.use_default_algos:
        # NOTE(review): assumes BenchmarkConfig/run_benchmark interpret algos=None
        # as "use the full default list" — confirm in lkb.benchmark.
        cfg0.algos = None

    # Sweep eta (each run gets its own copy so overrides never leak between runs)
    for v in eta_values:
        cfg = copy.deepcopy(cfg0)
        cfg.exp_name = f"{args.suffix}_eta_{v:g}"
        cfg.eta = float(v)
        run_benchmark(cfg)

    # Sweep n_users
    for v in n_values:
        cfg = copy.deepcopy(cfg0)
        cfg.exp_name = f"{args.suffix}_n_{v:d}"
        cfg.n_users = int(v)
        run_benchmark(cfg)


if __name__ == "__main__":
    # Entry point when executed as a script; importing the module runs nothing.
    # main() returns None, so SystemExit(None) exits with status 0 as before.
    raise SystemExit(main())
