import kdtpp.experiments as exp
import kdai._logging


# Model identifiers handed one at a time to ``exp.Spikes2.for_model`` by
# ``train``. They are trained in list order: every "logmix64" variant first,
# then every "discrete" variant, each in the "2-4-16" then "6-4-32" form.
models = [
    f"tf-{shape}-{variant}"
    for variant in ("logmix64", "discrete")
    for shape in ("2-4-16", "6-4-32")
]

def train():
    """Train every entry of the module-level ``models`` list on Spikes2 data.

    Logging is started once via ``exp.start_logging``, using the labels from
    ``kdai._logging.version_labels_from_script_dir()``. Its return value is
    passed to every ``exp.train`` call as the output directory. The
    recordings, their statistics and the split are loaded a single time and
    shared by every model run.
    """
    log_dir = exp.start_logging(kdai._logging.version_labels_from_script_dir())

    # Load the dataset once up front; each model below reuses it.
    recordings = exp.Spikes2.load_recs()
    rec_stats = exp.Spikes2.stats(recordings)
    data_split = exp.Spikes2.splits(recordings)

    for model_name in models:
        spec = exp.Spikes2.for_model(model_name, data_split, rec_stats)
        # NOTE(review): early stopping is driven by "train-loss" rather than a
        # held-out metric. Confirm against exp.train that this is intended.
        exp.train(log_dir, spec, "train-loss", use_early_stopping=True)

if __name__ == "__main__":
    # Start the training run only when executed as a script, not on import.
    train()
