#!/usr/bin/env python3
"""
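launch_experiments.py

Expand an experiments YAML file (default: experiments.yaml) into one config per
parameter combination and run, then either:
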
  * run them locally one-by-one               : --local
  * dispatch a single config by task index    : --task_id N
  * submit the whole sweep as a SLURM array   : (default)

The heavy lifting is delegated to train_model.evaluate_expirement().
"""
import argparse
import copy
import itertools
import os
import shlex
import subprocess
import sys
import tempfile
from pathlib import Path

import numpy as np
import yaml

# ----------------------------------------------------------------------------- #
#  Settings                                                                     #
# ----------------------------------------------------------------------------- #
# Default value of -e/--experiments.
# NOTE(review): [""] only matches an experiment literally named "" (unnamed
# experiments are reported as "<unnamed>"), so running without -e selects no
# configs and the script exits early. Pass -e NAME ... or -e all to run something.
to_run = [""]        # default experiment filter
YAML_DEFAULT = "experiments.yaml"  # location of your config file (default for -f)

# ----------------------------------------------------------------------------- #
#  Helpers                                                                      #
# ----------------------------------------------------------------------------- #
def expand_values(val, constants):
    """Turn one ``config_update`` entry into the list of values to sweep over.

    Supported forms, checked in this order:

    * ``{"list": [...]}``  -> one value that is itself a list (not swept)
    * ``[a, b, ...]``      -> sweep over each element
    * ``{"start": s, "end": e, "steps": n}`` -> ``n`` evenly spaced values from
      ``s`` to ``e`` inclusive (``np.linspace``), as Python floats
    * a key of *constants* -> the constant's value, expanded by the rules above
      (constants are not looked up recursively)
    * anything else        -> the single value ``[val]``

    Args:
        val: Raw value from the YAML ``config_update`` mapping.
        constants: Mapping of named constants from the YAML file; may be
            ``None`` or empty.

    Returns:
        List of candidate values for this config key.
    """
    if isinstance(val, dict) and "list" in val:
        # Wrap so the whole list is ONE value rather than a sweep dimension.
        return [val["list"]]
    if isinstance(val, list):
        return val
    if isinstance(val, dict) and set(val) == {"start", "end", "steps"}:
        return np.linspace(val["start"], val["end"], val["steps"]).tolist()

    try:
        is_constant = bool(constants) and val in constants
    except TypeError:
        # Unhashable values (e.g. a nested mapping) cannot be constant names.
        is_constant = False
    if is_constant:
        # Expand the constant's value too: list constants are returned as-is,
        # while scalar constants become a one-element list instead of breaking
        # itertools.product (or, for strings, being split into characters).
        return expand_values(constants[val], None)
    return [val]


def _is_selected(name, selected_exps):
    """Return True if experiment *name* passes the ``-e/--experiments`` filter.

    An empty or ``None`` filter keeps every experiment. Otherwise an
    experiment is kept when the filter contains ``"all"``, contains its exact
    name, or contains ``"sections"`` and the name includes ``"section"``.
    """
    if not selected_exps or "all" in selected_exps or name in selected_exps:
        return True
    return "sections" in selected_exps and "section" in name


def load_and_expand(path=YAML_DEFAULT, selected_exps=None):
    """Load the experiments YAML and expand it into one dict per concrete run.

    The file must contain ``default_config``. The ``constants`` and
    ``experiments`` sections are optional. For every selected experiment, each
    key of its ``config_update`` is expanded with ``expand_values``, and the
    Cartesian product of those value lists is taken. Each combination is
    applied on top of a deep copy of ``default_config``.

    If the resulting config has no ``run_id``, it is replicated
    ``cfg["runs"]`` times with ``run_id`` = 0..runs-1. Otherwise it is emitted
    once. Every emitted config gets
    ``experiment_name = "run_<run_id>_<experiment name>"``.

    Args:
        path: Path to the experiments YAML file.
        selected_exps: Experiment-name filter (see ``_is_selected``).

    Returns:
        List of config dicts, in file order, then sweep order, then run order.

    Raises:
        ValueError: If the YAML top level is not a mapping (e.g. empty file).
        KeyError: If ``default_config`` is missing, or ``runs`` is needed
            but absent.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict):
        raise ValueError(f"{path}: expected a YAML mapping at the top level")

    default_cfg = data["default_config"]          # required: raise if missing
    # Optional sections; `or` also covers keys that are present but null.
    constants = data.get("constants") or {}
    exps = data.get("experiments") or []

    all_cfgs = []
    for exp in exps:
        name = exp.get("name", "<unnamed>")
        if not _is_selected(name, selected_exps):
            continue

        updates = exp.get("config_update") or {}
        keys = list(updates)
        grids = [expand_values(updates[k], constants) for k in keys]

        # One config per point of the Cartesian product of all swept keys.
        for combo in itertools.product(*grids):
            cfg = copy.deepcopy(default_cfg)
            cfg.update(zip(keys, combo))

            if "run_id" in cfg:
                # run_id pinned by the defaults or the sweep: emit a single run.
                cfg["experiment_name"] = f"run_{cfg['run_id']}_{name}"
                all_cfgs.append(cfg)
                continue

            # Otherwise replicate the config once per requested run.
            for r in range(cfg["runs"]):
                cfg_run = copy.deepcopy(cfg)
                cfg_run["run_id"] = r
                cfg_run["experiment_name"] = f"run_{r}_{name}"
                all_cfgs.append(cfg_run)

    return all_cfgs


# ----------------------------------------------------------------------------- #
#  Where the *real* training happens                                            #
# ----------------------------------------------------------------------------- #
def run_experiment(cfg):
    """
    Execute one expanded configuration.

    Prints the experiment name followed by every other entry of *cfg*, then
    hands the whole config to train_model.evaluate_expirement().
    """
    # Imported lazily so that merely expanding/submitting jobs stays cheap.
    import train_model

    header = f"\n>>> Running experiment '{cfg['experiment_name']}' with cfg diff:"
    print(header)
    shown = ((key, value) for key, value in cfg.items() if key != "experiment_name")
    for key, value in shown:
        print(f"    {key}: {value}")

    train_model.evaluate_expirement(cfg)


# ----------------------------------------------------------------------------- #
#  SLURM helper                                                                 #
# ----------------------------------------------------------------------------- #
def _slurm_resources(experiments):
    """Return ``(memory, walltime)`` for the sweep, sized by experiment names.

    Model size is inferred from substrings of the experiment names:
    "100m"/"1b" -> 60G / 2h, "50m" -> 40G / 2h, anything else -> 10G / 50min.
    """
    if any("100m" in exp or "1b" in exp for exp in experiments):
        return "60G", "02:00:00"   # very large models
    if any("50m" in exp for exp in experiments):
        return "40G", "02:00:00"   # large models
    return "10G", "00:50:00"       # smaller models


def write_and_submit_sbatch(num_tasks, script_name, yaml_file, experiments):
    """Write a SLURM array batch script for the sweep and submit it with sbatch.

    Each array task re-invokes *script_name* with ``--task_id`` set to its
    ``SLURM_ARRAY_TASK_ID``, so task ``i`` runs config ``i`` of the expansion.

    Args:
        num_tasks: Number of configs; the array spans ``0..num_tasks-1``.
        script_name: Script to run on the compute node. Relative paths are
            resolved against the current working directory, which the job
            ``cd``s into.
        yaml_file: Experiments YAML path, forwarded as ``-f``.
        experiments: Experiment-name filter, forwarded as ``-e``.

    Raises:
        ValueError: If ``num_tasks`` < 1 (there would be no array range).
        subprocess.CalledProcessError: If ``sbatch`` exits non-zero.
    """
    if num_tasks < 1:
        raise ValueError(f"num_tasks must be >= 1, got {num_tasks}")

    memory, walltime = _slurm_resources(experiments)

    # Quote everything interpolated into the shell script so paths and
    # experiment names containing spaces/shell metacharacters survive intact.
    exp_list = " ".join(shlex.quote(exp) for exp in experiments)

    sbatch_contents = f"""#!/bin/bash
#SBATCH -t {walltime}
#SBATCH --mem {memory}
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:a100:1
#SBATCH --job-name=hebb_sweep
#SBATCH --array=0-{num_tasks - 1}
#SBATCH --output=slurm_out/%x_%A_%a.out
#SBATCH --error=slurm_out/%x_%A_%a.err

echo "Launching ${{SLURM_ARRAY_TASK_ID}} on $(hostname)"
cd {shlex.quote(os.getcwd())}

python {shlex.quote(script_name)} -f {shlex.quote(yaml_file)} -e {exp_list} --task_id ${{SLURM_ARRAY_TASK_ID}}
"""

    # SLURM does not create the directory named in --output/--error. Relative
    # paths there resolve against the job's working directory, which by
    # default is the directory sbatch is run from (this cwd).
    Path("slurm_out").mkdir(exist_ok=True)

    fd, sbatch_path = tempfile.mkstemp(suffix=".sbatch", text=True)
    os.close(fd)
    Path(sbatch_path).write_text(sbatch_contents)

    # Only report success once sbatch has actually accepted the job.
    subprocess.run(["sbatch", sbatch_path], check=True)
    print(f">>> Submitted SLURM array 0–{num_tasks-1} via {sbatch_path}")


# ----------------------------------------------------------------------------- #
#  Main CLI                                                                     #
# ----------------------------------------------------------------------------- #
if __name__ == "__main__":
    ap = argparse.ArgumentParser(
        description="Expand experiments.yaml and run configs locally or on SLURM."
    )
    ap.add_argument("-e", "--experiments", nargs="+", default=to_run,
                    help='Names of experiments to run (or "all")')
    ap.add_argument("-f", "--file", default=YAML_DEFAULT,
                    help="Path to the experiments YAML file")
    ap.add_argument("-l", "--local", action="store_true",
                    help="Run all configs sequentially in this process")
    ap.add_argument("--task_id", type=int,
                    help="Run exactly this config index (for SLURM array tasks)")
    args = ap.parse_args()

    # 1) Expand YAML → list of configs. Each SLURM array task repeats this
    #    expansion and picks its own index, so the order must be deterministic.
    configs = load_and_expand(path=args.file, selected_exps=args.experiments)
    total   = len(configs)
    print(f">>> {total} configs generated")

    if total == 0:
        sys.exit("No configs match the filter – nothing to do.")

    # 2) task_id: just run that one config
    if args.task_id is not None:
        if not (0 <= args.task_id < total):
            sys.exit(f"--task_id {args.task_id} out of range (0–{total-1})")
        run_experiment(configs[args.task_id])
        sys.exit(0)

    # 3) --local: loop over all configs
    if args.local:
        for idx, cfg in enumerate(configs):
            print(f"\n=== Local run {idx+1}/{total} ===")
            run_experiment(cfg)
        sys.exit(0)

    # 4) Submit as SLURM array. Use an absolute script path: the batch script
    #    cd's into the current directory, so a bare basename would only be
    #    found if the launcher itself lives in that directory.
    script_name = os.path.abspath(__file__)

    write_and_submit_sbatch(
        num_tasks=total,
        script_name=script_name,
        yaml_file=args.file,
        experiments=args.experiments
    )