import logging
from pathlib import Path
from typing import Iterable, Optional, Union

import polars as pl

import kdai._logging
import kdtpp.experiments as exp

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


def train(
    *,
    model_names: Optional[Iterable[str]] = None,
    lr_map_csv: Union[str, Path] = "./data/lrmap/combined_lrmap.csv",
) -> None:
    """Train a sequence of models into a single experiment output directory.

    Starts experiment logging with version labels taken from this script's
    directory. Then loads the learning-rate map from ``lr_map_csv`` and calls
    ``exp.train`` once per model, in order, with early stopping enabled.
    Models run sequentially. An exception raised while training one model
    stops the sweep, so later models are not trained.

    Args:
        model_names: Model identifiers to train, in order. ``None`` (the
            default) trains the full set of ``gpt-*`` configurations this
            experiment was written for.
        lr_map_csv: Path to the CSV read with ``pl.read_csv`` and passed to
            ``exp.load_lr_map``. A relative path is resolved against the
            current working directory, not this script's location.
    """
    if model_names is None:
        model_names = [
            "gpt-6-4-32-f",
            "gpt-2-4-16-f",
            "gpt-6-4-32-logmix64",
            "gpt-2-4-16-logmix64",
            "gpt-2-4-16-const",
            "gpt-2-4-16-exp",
            "gpt-2-4-16-nn",
        ]
    else:
        # Materialize so one-shot iterables work and the total count is known.
        model_names = list(model_names)

    ver_parts = kdai._logging.version_labels_from_script_dir()
    out_dir = exp.start_logging(ver_parts)
    # NOTE(review): ver_parts is not read again in this function. The append
    # only matters if exp.start_logging keeps a reference to the list — confirm.
    ver_parts.append(out_dir.name)

    lr_map = exp.load_lr_map(pl.read_csv(lr_map_csv))

    total = len(model_names)
    for index, model_name in enumerate(model_names, start=1):
        _logger.info("Training model %s (%d/%d)", model_name, index, total)
        exp.train(
            out_dir,
            exp.Cyclic.for_model(model_name),
            # presumably the metric monitored for early stopping — confirm
            "train-loss",
            use_early_stopping=True,
            lr_map=lr_map,
        )


if __name__ == "__main__":
    # Script entry point: run train() with its default settings.
    train()
