"""
Run this script to train models and save their checkpoints under
`./out/test/trainable_factory`. The weights are used in a few tests;
`conftest.py` has some factories that create trainables using these weights.
Models whose output directory already exists are skipped.
"""

import kdai._logging
import kdai.train
import kdtpp.experiments as exp
import logging
from pathlib import Path

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Root directory under which `train` places each model's output directory.
# The path is relative, so it resolves against the current working directory.
OUT_DIR = Path("./out/test/trainable_factory")


def train(exp_class, model_name, ds_name, train_len=None):
    """Train one model for a single epoch and save it under ``OUT_DIR``.

    Args:
        exp_class: Experiment class providing ``default_train_lens`` and
            ``for_model_and_ds``.
        model_name: Key into ``exp.trainable_fns``; also passed to
            ``exp.get_train_args``.
        ds_name: Name of the dataset to train on.
        train_len: Training length to use. When ``None``, the first entry
            of ``exp_class.default_train_lens(ds_name)`` is used.

    Returns:
        The trained trainable, or ``None`` when the model directory
        already exists and training is skipped.
    """
    if train_len is None:
        # Train on a single length only: the first default, which is
        # presumably the largest one -- confirm against default_train_lens.
        train_len = exp_class.default_train_lens(ds_name)[0]

    # Only the first run generated for this model/dataset pair is used.
    spec, make_ds_mgr = list(
        exp_class.for_model_and_ds(model_name, ds_name, [train_len])
    )[0]

    model_dir = OUT_DIR / spec.model_dir
    if model_dir.exists():
        _logger.info(f"Skipping {model_dir}. Model dir already exists.")
        return

    trainable = exp.trainable_fns[model_name](make_ds_mgr, eval_mode="train-loss")
    # Re-check after building the trainable: the directory must still be absent.
    assert not model_dir.exists()
    _logger.info(f"Training {model_name} on {ds_name}")

    train_args = exp.get_train_args(ds_name, model_name)
    # Override the defaults: a single epoch is enough for test weights
    # (instead of spec.n_epochs); batch size and eval cadence follow the spec.
    train_args.update(
        n_epochs=1,
        batch_size=spec.batch_size,
        steps_til_eval=spec.steps_til_eval,
    )

    stopper = kdai.train.EarlyStopper(min_steps=5000, step_patience=50)
    kdai.train.train(
        trainable,
        out_dir=model_dir,
        early_stopper=stopper,
        **train_args,
    )
    return trainable


def train_all():
    """Train every model/dataset combination listed here, in order."""
    model_name = "zuo-thp-0"
    for ds_name in ("so-badges", "so-badges-hours"):
        train(exp.Classic, model_name, ds_name)


if __name__ == "__main__":
    # Set up logging at INFO before training; the skip/progress messages
    # logged by `train` are emitted at that level.
    kdai._logging.setup_logging(logging.INFO)
    train_all()
