import kdtpp.experiments as exp
import kdai._logging
import kdai.train
import kdtpp.prob
import torch
from pathlib import Path
import polars as pl
import logging

_logger = logging.getLogger(__name__)

# Output directory of the training run whose models are evaluated below.
# The path is relative, so it resolves against the current working directory.
train_dir = Path("./out/exp/0/2/0")
# NOTE(review): presumably derives version labels from this script's
# location on disk — confirm against kdai._logging.
ver_parts = kdai._logging.version_labels_from_script_dir()
# Runs at import time. The returned value is used below as the directory
# into which the metrics parquet files are written.
out_dir = exp.start_logging(ver_parts)


# Names of the models to evaluate, in evaluation order. Each name is passed
# to exp.Cyclic.for_model and is also used as the suffix of that model's
# per-model metrics file name.
model_names = [
    "gpt-6-4-32-f",
    "gpt-2-4-16-f",
    "gpt-6-4-32-logmix",
    "gpt-2-4-16-logmix",
    "gpt-2-4-16-const",
    "gpt-2-4-16-exp",
    "gpt-2-4-16-nn",
]


def eval(use_test_ds: bool) -> None:
    """Evaluate every model in ``model_names`` and save the metrics as parquet.

    For each model name, ``exp.eval`` is run against ``train_dir`` and its
    result is written to ``out_dir / "cyclic_{split}_metrics_{model}.parquet"``.
    Once all models are done, the per-model results are concatenated
    vertically and written to ``out_dir / "cyclic_{split}_metrics.parquet"``.

    Note: this function shadows the builtin ``eval`` inside this module; the
    name is kept so existing callers keep working.

    Args:
        use_test_ds: Truthy selects the ``"test"`` split label, falsy selects
            ``"val"``. The value is also forwarded unchanged to ``exp.eval``.
    """
    # A conditional expression instead of indexing a list with the flag:
    # same result for True/False, but does not raise on other truthy values.
    split = "test" if use_test_ds else "val"
    filename = f"cyclic_{split}_metrics"
    total = len(model_names)
    dfs = []
    for i, m in enumerate(model_names, start=1):
        df = exp.eval(
            train_dir,
            exp.Cyclic.for_model(m),
            use_test_ds=use_test_ds,
            batch_size=512,
        )
        # Lazy %-style arguments: the message is only formatted if emitted.
        _logger.info("(%d/%d) %s evaluation done", i, total, m)
        # Persist each model's metrics immediately so a later failure does
        # not lose the results already computed.
        df.write_parquet(out_dir / f"{filename}_{m}.parquet")
        dfs.append(df)
    combined = pl.concat(dfs, how="vertical")
    combined.write_parquet(out_dir / f"{filename}.parquet")


if __name__ == "__main__":
    # NOTE(review): logging.basicConfig does nothing when the root logger
    # already has handlers. exp.start_logging has already run at import time
    # by this point — confirm whether it installs handlers, in which case
    # this call has no effect.
    logging.basicConfig(level=logging.INFO)
    # Evaluate on the validation split; pass use_test_ds=True for the test split.
    eval(use_test_ds=False)
