import kdtpp.experiments as exp
import kdai._logging
import polars as pl


# Model names to train, in training order. Each name is passed to both
# exp.RandProc.for_model and exp.Classic.for_model_and_ds inside train().
# Rows group entries that share a name suffix/family.
# fmt: off
models = [
    "gptvar-6-4-32", "gptvar-2-4-16", "rnn-cat",                # *var / *-cat
    "gpt-6-4-32-logmix64", "gpt-2-4-16-logmix64", "rnn-logmix",  # *logmix*
    "gpt-2-4-16-const", "rnn-const",                            # *-const
    "gpt-2-4-16-exp", "rnn-exp",                                # *-exp
    "gpt-2-4-16-nn", "rnn-nn",                                  # *-nn
    "zuo-thp-1", "zuo-thp-0",                                   # zuo-thp-*
]
# fmt: on


def train():
    """Run two training jobs for every model name in ``models``.

    Logging is started via ``exp.start_logging`` using the version labels
    returned by ``kdai._logging.version_labels_from_script_dir()``. The
    returned output directory is shared by every run.

    Learning rates come from ``exp.load_lr_map`` applied to the CSV at
    ``./data/lrmap/combined_lrmap.csv``. That path is resolved relative to
    the current working directory, so the script must be launched from the
    directory that contains ``data/``.

    For each model name, in order, ``exp.train`` is called twice:

    1. with ``exp.RandProc.for_model(model_name)``;
    2. with ``exp.Classic.for_model_and_ds(model_name, "nyc-taxi")``.

    Both calls pass ``"train-loss"`` as the third positional argument
    (presumably the monitored objective; confirm against ``exp.train``),
    enable early stopping, and use the loaded learning-rate map.
    """
    ver_parts = kdai._logging.version_labels_from_script_dir()
    out_dir = exp.start_logging(ver_parts)

    lr_map = exp.load_lr_map(pl.read_csv("./data/lrmap/combined_lrmap.csv"))

    # Arguments shared by every run, hoisted so the two calls cannot drift
    # apart when one of them is edited.
    loss_name = "train-loss"
    shared_kwargs = {"use_early_stopping": True, "lr_map": lr_map}

    for model_name in models:
        exp.train(
            out_dir,
            exp.RandProc.for_model(model_name),
            loss_name,
            **shared_kwargs,
        )
        # Stack Overflow is excluded because its data may not have been
        # downloaded. If it has been downloaded, call instead:
        #     exp.Classic.for_model_and_ds(model_name)
        exp.train(
            out_dir,
            exp.Classic.for_model_and_ds(model_name, "nyc-taxi"),
            loss_name,
            **shared_kwargs,
        )


# Launch the training sweep only when run as a script, not when imported.
if __name__ == "__main__":
    train()
