import pathlib

import kdtpp.experiments as exp
import kdai._logging
import polars as pl


# Model identifiers swept by ``all_sweeps`` (in this order); each one is
# resolved via ``exp.Baseline.for_model``.
# NOTE(review): entries appear grouped in threes by name (two "gpt"-style
# configurations followed by an "rnn" one), plus two "zuo-thp" entries at
# the end — grouping inferred from the names only.
models = [
    "gptvar-6-4-32", "gptvar-2-4-16", "rnn-cat",
    "gpt-6-4-32-logmix64", "gpt-2-4-16-logmix64", "rnn-logmix",
    "gpt-6-4-32-const", "gpt-2-4-16-const", "rnn-const",
    "gpt-6-4-32-exp", "gpt-2-4-16-exp", "rnn-exp",
    "gpt-6-4-32-nn", "gpt-2-4-16-nn", "rnn-nn",
    "zuo-thp-1", "zuo-thp-0",
]


def sweep(
    model_name,
    out_dir,
    *,
    n_runs=8,
    n_lr_steps=1_000,
    n_warm_up_steps=1_000,
    lr_min=5e-6,
    lr_max=5e-2,
    batch_sizes=None,
):
    """Run a learning-rate sweep for one model and save it as parquet.

    Common function called by local and slurm runs.

    Args:
        model_name: Identifier passed to ``exp.Baseline.for_model``.
        out_dir: Output directory (``str`` or path-like). It is forwarded
            unchanged to ``exp.lr_sweep``; the results file
            ``{model_name}_lrsweep.parquet`` is written inside it.
        n_runs: Forwarded positionally to ``exp.lr_sweep``.
        n_lr_steps: Forwarded to ``exp.lr_sweep`` as ``n_lr_steps``.
        n_warm_up_steps: Forwarded to ``exp.lr_sweep`` as ``warm_up_steps``.
        lr_min: Forwarded to ``exp.lr_sweep`` as ``lr_min``.
        lr_max: Forwarded to ``exp.lr_sweep`` as ``lr_max``.
        batch_sizes: Iterable of batch sizes, forwarded as a ``list``.
            Defaults to ``[32, 64, 128]``.

    Returns:
        The frame returned by ``exp.lr_sweep`` (also written to disk).
    """
    # Build the output path before the (expensive) sweep so that an
    # unusable ``out_dir`` fails immediately rather than after all runs.
    # ``pathlib.Path`` also lets callers pass a plain string.
    out_path = pathlib.Path(out_dir) / f"{model_name}_lrsweep.parquet"

    # We don't need many batch sizes as we won't be using many training lengths.
    if batch_sizes is None:
        batch_sizes = [32, 64, 128]
    else:
        batch_sizes = list(batch_sizes)

    df = exp.lr_sweep(
        exp.Baseline.for_model(model_name),
        out_dir,
        n_runs,
        n_lr_steps=n_lr_steps,
        lr_min=lr_min,
        lr_max=lr_max,
        batch_sizes=batch_sizes,
        warm_up_steps=n_warm_up_steps,
    )
    df.write_parquet(str(out_path))
    return df


def all_sweeps(model_names=None):
    """Sweep every model and save the combined results.

    A versioned logging directory is created via ``exp.start_logging``;
    ``sweep`` writes one parquet file per model into it, and the
    concatenation of all results is written to ``all_lrsweep.parquet``.

    Args:
        model_names: Iterable of model identifiers to sweep, in order.
            Defaults to the module-level ``models`` list.

    Returns:
        The vertically concatenated results of every per-model sweep.

    Raises:
        ValueError: If there are no model names to sweep (``pl.concat``
            cannot combine an empty list of frames).
    """
    names = list(models if model_names is None else model_names)
    if not names:
        # Fail before creating a logging directory or running anything.
        raise ValueError("all_sweeps requires at least one model name")

    ver_parts = kdai._logging.version_labels_from_script_dir()
    out_dir = exp.start_logging(ver_parts)

    # ``sweep`` saves each model's results as soon as it finishes, so a
    # failure part-way through keeps the models already completed.
    dfs = [sweep(name, out_dir) for name in names]

    # NOTE(review): how="vertical" requires every sweep frame to share one
    # schema; confirm the rows identify which model produced them.
    combined = pl.concat(dfs, how="vertical")
    combined.write_parquet(str(out_dir / "all_lrsweep.parquet"))
    return combined


if __name__ == "__main__":
    all_sweeps()
