import kdtpp.experiments as exp
import kdai._logging
import kdai.train
import torch
from pathlib import Path
import polars as pl
import math
import polars as pl


# Model identifiers to sweep. ``sweep`` processes them in exactly this order
# (which is also the row order of the combined parquet), and each one is
# handed to ``exp.Cyclic.for_model`` by ``single_sweep``.
model_names = [
    "gpt-6-4-32-f",
    "gpt-2-4-16-f",
    "gpt-6-4-32-logmix",
    "gpt-2-4-16-logmix",
    "gpt-2-4-16-const",
    "gpt-2-4-16-exp",
    "gpt-2-4-16-nn",
]


def _pow2_train_len(max_batch_size, n_lr_steps, samples_per_seq, buffer=2):
    """Return the smallest power of two covering the sequences a sweep needs.

    The sweep consumes ``max_batch_size * n_lr_steps * buffer`` samples and
    each training sequence yields ``samples_per_seq`` samples. All arguments
    are expected to be positive ints. Pure integer arithmetic keeps the
    result exact (no float ``log2`` rounding) and guarantees a positive
    ``int`` even when fewer than one sequence would be needed.
    """
    samples_needed = max_batch_size * n_lr_steps * buffer
    # Ceiling division, clamped so at least one sequence is requested.
    seqs_needed = max(1, -(-samples_needed // samples_per_seq))
    # Smallest 2**k >= seqs_needed.
    return 1 << (seqs_needed - 1).bit_length()


def single_sweep(model_name, out_dir, version: str):
    """Run an LR sweep for a single model and save the results.

    Args:
        model_name: Model identifier passed to ``exp.Cyclic.for_model``.
        out_dir: Output directory; joined with a file name via ``/``.
        version: Version tag embedded in the output file name.

    Returns:
        The DataFrame produced by ``exp.lr_sweep``. It is also written to
        ``out_dir / f"{model_name}_lrsweep_{version}.parquet"``.
    """
    n_runs = 8
    n_lr_steps = int(1e3)
    n_warm_up_steps = int(1e3)

    # NOTE(review): process-global torch.compile setting; presumably raised so
    # that compiling for the many batch sizes below does not hit dynamo's
    # recompile limit — confirm.
    torch._dynamo.config.cache_size_limit = 15

    # Batch sizes 2**11 down to 2**2, largest first.
    max_batch_size_log = 11
    batch_sizes = [2**i for i in reversed(range(2, max_batch_size_log + 1))]

    # Samples obtainable from one training sequence (1024 minus the model's
    # input length).
    model_in_len = 128
    samples_per_seq = 1024 - model_in_len

    # As lr_find reuses the RunSpec infrastructure, which requires training
    # length to divide the hard-coded num_samples, we make our train length
    # a power of two. The factor of 2 (``buffer``) adds a (generous) margin
    # for skipped lr steps.
    train_len = _pow2_train_len(
        max(batch_sizes), n_lr_steps, samples_per_seq, buffer=2
    )
    df = exp.lr_sweep(
        exp.Cyclic.for_model(model_name, n_train_seqs=[train_len]),
        out_dir,
        n_runs,
        n_lr_steps=n_lr_steps,
        lr_min=5e-6,
        lr_max=5e-2,
        batch_sizes=batch_sizes,
        warm_up_steps=n_warm_up_steps,
    )
    df.write_parquet(str(out_dir / f"{model_name}_lrsweep_{version}.parquet"))
    return df


def sweep():
    """Sweep learning rates for every entry in ``model_names``.

    A single logging directory is created for the whole run. Each model's
    sweep is saved individually by ``single_sweep``. Afterwards, all
    per-model frames are stacked vertically and saved as one combined
    parquet file in the same directory.
    """
    # Optionally select a GPU before starting, e.g.:
    # torch.cuda.set_device(kdai.train.gpus_ordered_by_mem_used(exclude={1})[0])
    labels = kdai._logging.version_labels_from_script_dir()
    log_dir = exp.start_logging(labels)
    labels.append(log_dir.name)
    tag = "_".join(labels)

    per_model = [single_sweep(name, log_dir, tag) for name in model_names]
    combined = pl.concat(per_model, how="vertical")
    combined.write_parquet(log_dir / f"all_lrsweep_{tag}.parquet")


if __name__ == "__main__":
    # Script entry point: run the full multi-model sweep.
    sweep()
