import kdtpp.experiments as exp
import kdai._logging
import torch
import polars as pl
import math
import polars as pl


# Model names swept by all_sweeps(), which runs them in list order.
# Each row groups related adjacent entries (e.g. a shared "-logmix",
# "-const", "-exp" or "-nn" suffix); the overall order is unchanged.
models = [
    "gptvar-6-4-32", "gptvar-2-4-16", "rnn-cat",
    "gpt-6-4-32-logmix64", "gpt-2-4-16-logmix64", "rnn-logmix",
    "gpt-2-4-16-const", "rnn-const",
    "gpt-2-4-16-exp", "rnn-exp",
    "gpt-2-4-16-nn", "rnn-nn",
    "zuo-thp-1", "zuo-thp-0",
]


def sweep(model_name, out_dir):
    """Run the learning-rate sweeps for one model and save the results.

    Common function called by local and slurm runs. Calls ``exp.lr_sweep``
    twice with identical sweep settings: first on the ``exp.RandProc`` specs
    for ``model_name``, then on the ``exp.Classic`` specs for the "nyc-taxi"
    dataset. The two result frames are concatenated vertically and written
    to ``out_dir / f"{model_name}_lrsweep.parquet"``.

    Args:
        model_name: Model identifier (e.g. an entry of ``models``).
        out_dir: Output directory. It must support ``/`` with a ``str``
            (e.g. a ``pathlib.Path``) and is also passed to ``exp.lr_sweep``.

    Returns:
        The vertically concatenated polars DataFrame of both sweeps.

    Side effects:
        Sets the process-global ``torch._dynamo.config.cache_size_limit``
        to 15.
    """
    n_runs = 8
    n_lr_steps = int(1e3)
    n_warm_up_steps = int(1e3)
    # We are hitting recompiles, seemingly caused by the varying batch size,
    # so allow more recompiles than there are batch sizes (currently 10).
    torch._dynamo.config.cache_size_limit = 15

    max_batch_size_log = 11  # Largest batch size (inclusive) is 2**11 = 2048.
    # Batch sizes 2048, 1024, ..., 4 — largest first.
    batch_sizes = [2**i for i in reversed(range(2, max_batch_size_log + 1))]
    # Factor of 2 adds a (generous) buffer for skipped lr steps.
    # As lr_find reuses the RunSpec infrastructure, which requires the training
    # length to divide the hard-coded num_samples, we round the train length
    # up to a power of two.
    train_len = 2 ** math.ceil(math.log2(max(batch_sizes) * n_lr_steps * 2))

    # Settings shared by both sweeps; keeping them in one place ensures the
    # RandProc and Classic sweeps stay directly comparable.
    sweep_kwargs = dict(
        n_lr_steps=n_lr_steps,
        lr_min=5e-6,
        lr_max=5e-2,
        batch_sizes=batch_sizes,
        warm_up_steps=n_warm_up_steps,
    )
    df_rand = exp.lr_sweep(
        exp.RandProc.for_model(model_name, [train_len]),
        out_dir,
        n_runs,
        **sweep_kwargs,
    )
    df_classic = exp.lr_sweep(
        # Only the nyc-taxi classic dataset is used: Stack Overflow is
        # excluded because its data may not be downloaded.
        exp.Classic.for_model_and_ds(model_name, "nyc-taxi", [train_len]),
        out_dir,
        n_runs,
        **sweep_kwargs,
    )
    df = pl.concat([df_rand, df_classic], how="vertical")
    df.write_parquet(str(out_dir / f"{model_name}_lrsweep.parquet"))
    return df


def all_sweeps():
    """Run ``sweep`` for every entry of ``models`` and save the combined result.

    The output directory is the one returned by ``exp.start_logging``, called
    with the labels from ``kdai._logging.version_labels_from_script_dir()``.
    Models are swept in list order. The concatenation of every model's
    results is written to ``all_lrsweep.parquet`` in that directory.

    Returns:
        The vertically concatenated polars DataFrame of all models' sweeps.
    """
    ver_parts = kdai._logging.version_labels_from_script_dir()
    out_dir = exp.start_logging(ver_parts)

    # sweep() writes its own per-model parquet file before returning, so
    # completed models' results are kept on disk even if a later model fails.
    dfs = [sweep(model_name, out_dir) for model_name in models]
    df = pl.concat(dfs, how="vertical")
    df.write_parquet(str(out_dir / "all_lrsweep.parquet"))
    return df


if __name__ == "__main__":
    all_sweeps()
