import kdtpp.experiments as exp
import kdai._logging
from pathlib import Path
import polars as pl
import logging

_logger = logging.getLogger(name=__name__)

# Output directory of the training run that every model is evaluated against
# (passed as the first argument to exp.eval). Resolved relative to the
# current working directory, not to this script's location.
train_dir = Path("out", "exp", "1", "2", "0")

# Version labels derived from this script's directory; start_logging turns
# them into out_dir, under which this script writes its metric parquet files.
# NOTE(review): start_logging's other side effects are not visible here.
ver_parts = kdai._logging.version_labels_from_script_dir()
out_dir = exp.start_logging(ver_parts)


# Model identifiers handed to exp.Baseline.for_model, in evaluation order.
# Each row lists the gpt(var)-6-4-32, gpt(var)-2-4-16 and rnn variants that
# share a suffix; the final row holds the two zuo-thp models.
models = [
    name
    for group in (
        ("gptvar-6-4-32", "gptvar-2-4-16", "rnn-cat"),
        ("gpt-6-4-32-logmix64", "gpt-2-4-16-logmix64", "rnn-logmix"),
        ("gpt-6-4-32-const", "gpt-2-4-16-const", "rnn-const"),
        ("gpt-6-4-32-exp", "gpt-2-4-16-exp", "rnn-exp"),
        ("gpt-6-4-32-nn", "gpt-2-4-16-nn", "rnn-nn"),
        ("zuo-thp-1", "zuo-thp-0"),
    )
    for name in group
]

def eval(use_test_ds, *, model_names=None, batch_size=512):
    """Evaluate each model on the validation or test split and save metrics.

    For every model name, ``exp.eval`` is run against ``train_dir`` and the
    returned metrics frame is written to
    ``out_dir / "rand_<split>_metrics_<model>.parquet"``. All per-model frames
    are then stacked vertically and written to
    ``out_dir / "rand_<split>_metrics.parquet"``.

    Note: this function shadows the builtin ``eval`` within this module; the
    name is kept for compatibility with existing callers.

    Args:
        use_test_ds: If truthy, evaluate on the test split (``"test"`` in file
            names); otherwise on the validation split (``"val"``). Forwarded
            unchanged to ``exp.eval``.
        model_names: Iterable of model identifiers passed to
            ``exp.Baseline.for_model``. Defaults to the module-level ``models``.
        batch_size: Batch size forwarded to ``exp.eval``.

    Returns:
        The vertically concatenated metrics frame that was written to disk.

    Raises:
        ValueError: If there are no models to evaluate.
    """
    # Materialize so generators work and emptiness can be checked up front.
    model_names = list(models if model_names is None else model_names)
    if not model_names:
        raise ValueError("no models to evaluate")

    # A conditional is robust to any truthy/falsy flag; indexing a two-element
    # list with it only works for exactly 0/1/False/True.
    split = "test" if use_test_ds else "val"
    filename = f"rand_{split}_metrics"
    n_models = len(model_names)

    dfs = []
    for i, m in enumerate(model_names, start=1):
        df = exp.eval(
            train_dir,
            exp.Baseline.for_model(m),
            use_test_ds=use_test_ds,
            batch_size=batch_size,
        )
        # Write each model's metrics right away so results for completed
        # models are on disk even if a later model fails.
        df.write_parquet(out_dir / f"{filename}_{m}.parquet")
        _logger.info("(%d/%d) %s evaluation done", i, n_models, m)
        dfs.append(df)

    combined = pl.concat(dfs, how="vertical")
    combined.write_parquet(out_dir / f"{filename}.parquet")
    return combined


if __name__ == "__main__":
    # Request INFO-level output on the root logger so per-model progress is
    # shown. basicConfig does nothing if the root logger already has handlers
    # (NOTE(review): start_logging may have installed some; unverified).
    logging.basicConfig(level="INFO")
    # Validation split only; passing a truthy flag selects the test split.
    eval(False)
