import kdtpp.experiments as exp
import kdai._logging
from pathlib import Path
import polars as pl
import logging

_logger = logging.getLogger(__name__)

# Output directory of the training run whose models are evaluated below.
# (Equivalent to Path("./out/exp/2/2/0"); pathlib collapses the leading ".".)
train_dir = Path("out", "exp", "2", "2", "0")

# Directory this evaluation run writes its metrics into. It is obtained from
# exp.start_logging, which is called at import time.
# NOTE(review): importing this module therefore triggers start_logging's side
# effects; confirm that is intended if eval() is reused elsewhere.
ver_parts = kdai._logging.version_labels_from_script_dir()
out_dir = exp.start_logging(ver_parts)

# Model identifiers passed to exp.Spikes2.for_model, evaluated in this order.
models = [
    "tf-2-4-16-logmix64",
    "tf-6-4-32-logmix64",
    "tf-2-4-16-discrete",
    "tf-6-4-32-discrete",
]


def eval(use_test_ds):
    """Evaluate every model in ``models`` and write the combined metrics.

    The name shadows the builtin ``eval`` inside this module; it is kept as-is
    so existing callers keep working.

    Args:
        use_test_ds: Forwarded to ``exp.eval_for_spikesv2``. When truthy the
            output file is ``spikes_test_metrics.parquet``, otherwise
            ``spikes_val_metrics.parquet``. Presumably this selects the test vs.
            validation split — confirm against ``eval_for_spikesv2``.

    Returns:
        The per-model metrics frames concatenated vertically (the same frame
        that is written to ``out_dir``).
    """
    # Load the recordings, their stats and splits once; shared by all models.
    recs = exp.Spikes2.load_recs()
    stats = exp.Spikes2.stats(recs)
    split = exp.Spikes2.splits(recs)

    n_models = len(models)
    dfs = []
    for i, m in enumerate(models, start=1):
        df = exp.eval_for_spikesv2(
            train_dir,
            exp.Spikes2.for_model(m, split, stats),
            pred_stride=80,
            sigma_ms=60,
            use_test_ds=use_test_ds,
        )
        _logger.info("(%d/%d) %s evaluation done", i, n_models, m)
        dfs.append(df)

    metrics = pl.concat(dfs, how="vertical")
    # A conditional expression works for any truthy/falsy flag, unlike
    # indexing a two-element list with it (which only accepts bool/0/1).
    split_name = "test" if use_test_ds else "val"
    metrics.write_parquet(out_dir / f"spikes_{split_name}_metrics.parquet")
    return metrics


if __name__ == "__main__":
    # Configure root logging at INFO so the per-model progress messages from
    # eval() are shown.
    # NOTE(review): exp.start_logging() already ran at import time; if it
    # installed handlers on the root logger, basicConfig is a no-op — confirm.
    logging.basicConfig(level=logging.INFO)
    # Run on the non-test data; eval() writes spikes_val_metrics.parquet.
    eval(use_test_ds=False)
