import pytest
import kdai
import kdai._logging
import kdai.train
import kdtpp.experiments as experiments
import polars as pl

# Root directory handed to `kdai._logging.get_outdir` by the tests in this
# module.
OUTPUT_ROOT_DIR = "./out/perf_test/"

# Column order of the per-model NLL rows in `_NLL_ROWS`.
_DATASET_NAMES = (
    "stationary-poisson",
    "stationary-renewal",
    "self-correcting",
    "hawkes1",
    "hawkes2",
)

# fmt: off
# Reference NLL values, all taken from (Shchur et al, 2020).
# One row per model; columns follow `_DATASET_NAMES`.
_NLL_ROWS = {
    "rmtpp":         (0.99, 1.01, 0.78, 0.74, 0.69),
    "omi-nn":        (1.00, 0.28, 0.78, 0.55, 0.06),
    "omi-exp":       (0.99, 1.01, 0.94, 0.78, 0.69),  # listed as model "Exponential"
    "shchur-logmix": (0.99, 0.25, 0.78, 0.52, 0.02),
}
# fmt: on

# Flat set of (model, dataset, nll) triples built from the table above.
_paper_results = {
    (model, dataset, nll)
    for model, row in _NLL_ROWS.items()
    for dataset, nll in zip(_DATASET_NAMES, row)
}

# Lookup keyed by (model, dataset) -> reference NLL.
paper_results_map = {
    (model, dataset): nll for model, dataset, nll in _paper_results
}


@pytest.mark.parametrize(
    "model_name",
    ["rmtpp", "omi-nn", "omi-exp", "shchur-logmix"],
)
@pytest.mark.parametrize(
    "ds_name",
    [
        "stationary-poisson",
        # "nonstationary-poisson",
        "stationary-renewal",
        # "nonstationary-renewal",
        "self-correcting",
        "hawkes1",
        "hawkes2",
    ],
)
@pytest.mark.longrun
def test_rand_process_perf(model_name, ds_name):
    """Train one model on one random-process dataset and compare its NLL.

    The model is trained through `kdai.train.train`, then the smallest value
    in the "loss" column of the run's `metrics.csv` is compared against the
    reference value in `paper_results_map`, allowing a 10% relative
    deviation.

    Args:
        model_name: Key into `experiments.trainable_fns`.
        ds_name: Key into `experiments.rand_process_dataset_fns`.
    """
    run_dir = kdai._logging.get_outdir(
        OUTPUT_ROOT_DIR,
        labels=[model_name, ds_name],
    )

    # Build the trainable for this (dataset, model) pair.
    dataset_fn = experiments.rand_process_dataset_fns[ds_name]
    make_trainable = experiments.trainable_fns[model_name]
    trainable = make_trainable(dataset_fn, "train-loss")

    extra_train_args = experiments.get_train_args(ds_name, model_name)
    kdai.train.train(
        trainable,
        out_dir=run_dir,
        save_checkpoints=False,
        **extra_train_args,
    )

    # NOTE(review): metrics.csv is presumably written by kdai.train.train
    # into `run_dir` -- confirm against the trainer.
    losses = pl.read_csv(run_dir / "metrics.csv")["loss"]
    lowest_loss = losses.min()

    reference_nll = paper_results_map[(model_name, ds_name)]
    assert lowest_loss == pytest.approx(reference_nll, rel=0.1)

