import pytest
import kdai
import kdai._logging
import kdai.train
import kdai.datasets
import kdtpp.experiments as experiments
import kdtpp.datasets as ds
import polars as pl
import numpy as np
import kdtpp.cyclicgroup as cg
from pathlib import Path

# Root directory handed to kdai._logging.get_outdir() when the perf test
# below asks for its output directory.
OUTPUT_ROOT_DIR = "./out/perf_test/"


def gen_fixed_pattern(model_in_len):
    """Build a dataset from a fixed, fully deterministic event pattern.

    The pattern is the sorted merge of two arithmetic progressions (the
    multiples of 110 and of 27), truncated to the smallest 100 000 values.
    Every event is paired with a log-probability of zero.

    Args:
        model_in_len: Forwarded unchanged to ``ds.RandProcessDatasets``.

    Returns:
        The ``ds.RandProcessDatasets`` wrapping the pattern.
    """
    n_events = 100_000
    strides = (110, 27)
    base = np.arange(n_events)
    # One increasing sequence per stride, merged into a single sorted array.
    merged = np.concatenate([base * stride for stride in strides])
    merged.sort()
    # Only the first n_events of the 2 * n_events merged values are kept.
    event_times = merged[:n_events]
    zero_log_probs = np.zeros_like(event_times, dtype=float)
    return ds.RandProcessDatasets(
        event_times,
        zero_log_probs,
        t_max=None,
        model_in_len=model_in_len,
    )

def gen_cyclic(model_in_len):
    """Build a dataset from a single sequence produced by ``cg.gen``.

    The generator is called with p=1021, n=100 000, one start position (41)
    and one velocity (10); its second return value is discarded. Every event
    is paired with a log-probability of zero.

    Args:
        model_in_len: Forwarded unchanged to ``ds.RandProcessDatasets``.

    Returns:
        The ``ds.RandProcessDatasets`` wrapping the generated sequence.
    """
    event_times, _ = cg.gen(p=1021, n=100_000, start_pos=[41], vel=[10])
    zero_log_probs = np.zeros_like(event_times, dtype=float)
    return ds.RandProcessDatasets(
        event_times,
        zero_log_probs,
        t_max=None,
        model_in_len=model_in_len,
    )

def gen_cyclic22(model_in_len):
    """Build train/val/test datasets from two ``cg.gen`` sequences.

    Two sequences are generated (same p and n, different start position and
    velocity), stacked as rows, prefixed with a zero column and differenced
    along the last axis. The resulting deltas are split 7:2:1 via
    ``kdai.datasets.split``.

    Args:
        model_in_len: Forwarded unchanged to ``ds.EventSeqListDatasets``.

    Returns:
        The ``ds.EventSeqListDatasets`` holding the three splits.
    """
    # (start position, velocity) for each of the two generated sequences.
    configs = ((41, 10), (17, 56))
    rows = []
    for start, velocity in configs:
        seq, _ = cg.gen(p=1021, n=100_000, start_pos=[start], vel=[velocity])
        rows.append(seq)
    stacked = np.stack(rows)
    # Prepend a zero to each row so the first delta equals the first value.
    leading_zeros = np.zeros([2, 1])
    padded = np.concatenate([leading_zeros, stacked], axis=1)
    deltas = np.diff(padded)
    train_part, val_part, test_part = kdai.datasets.split(deltas, (7, 2, 1))
    return ds.EventSeqListDatasets(
        train_part,
        val_part,
        test_part,
        model_in_len=model_in_len,
    )

@pytest.fixture(
    params=["easy_1dim_1obj_1024event.npz", "easier_1dim_1obj_1024event.npz"],
    ids=["easy", "easier"])
def gen_cyclic_fn(resource_dir, request):
    """Fixture yielding a dataset factory for each pre-generated ``.npz`` file.

    The ``"ts"`` array of the parametrized archive is loaded, prefixed with a
    zero column, differenced along the last axis, and split 8:1:1 via
    ``kdai.datasets.split``. The split is computed once per fixture
    instantiation; the returned callable only wraps it in a dataset object.

    Args:
        resource_dir: Fixture providing the directory that contains the
            ``.npz`` resources (must support the ``/`` operator).
        request: pytest request object; ``request.param`` is the file name.

    Returns:
        A callable ``to_data_mgr(model_in_len)`` returning a
        ``ds.EventSeqListDatasets`` built with ``full_y=True``.
    """
    path = resource_dir / request.param
    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the array has been read into memory.
    with np.load(path) as archive:
        ts = archive["ts"]
    # Two-value unpack doubles as a check that ts is 2-D (sequences x events).
    n_seq, _ = ts.shape
    # Prepend a zero to each sequence so the first delta is the first value.
    ts = np.concatenate([np.zeros([n_seq, 1]), ts], axis=1)
    dts = np.diff(ts)
    split_ratio = (8, 1, 1)
    train_seqs, val_seqs, test_seqs = kdai.datasets.split(dts, split_ratio)

    def to_data_mgr(model_in_len):
        """Wrap the pre-split sequences in an ``EventSeqListDatasets``."""
        return ds.EventSeqListDatasets(
            train_seqs,
            val_seqs,
            test_seqs,
            model_in_len=model_in_len,
            full_y=True,
        )

    return to_data_mgr


@pytest.mark.longrun
def test_transformer_perf(gen_cyclic_fn):
    """Train the "ours-discrete" trainable and check it reaches near-zero NLL.

    The trainable is built from the ``gen_cyclic_fn`` dataset factory, trained
    for 20 epochs without checkpoints, and the minimum ``mean_nll`` recorded
    in ``metrics.csv`` must fall below 0.01.

    Args:
        gen_cyclic_fn: Fixture providing a ``model_in_len -> datasets``
            factory.
    """
    out_dir = kdai._logging.get_outdir(OUTPUT_ROOT_DIR, labels=["transformer"])
    # NOTE: the module-level generators (e.g. gen_fixed_pattern) take the same
    # model_in_len argument and can be passed here instead of the fixture.
    trainable = experiments.trainable_fns["ours-discrete"](gen_cyclic_fn)
    train_kwargs = experiments.get_train_args(
        ds_name=None, trainable_name="ours-discrete"
    )
    train_kwargs["n_epochs"] = 20
    # Was `1**20`, which evaluates to 1 (a single sample per epoch); the
    # intent is taken to be 2**20 samples per epoch.
    samples_per_epoch = 2**20
    kdai.train.train(
        trainable,
        out_dir=out_dir,
        save_checkpoints=False,
        samples_per_epoch=samples_per_epoch,
        **train_kwargs
    )
    # Path(...) makes the join work whether get_outdir returned a str or Path.
    metrics_df = pl.read_csv(Path(out_dir) / "metrics.csv")
    best_nll = metrics_df["mean_nll"].min()
    # min() yields None for an empty/all-null column; fail with a clear
    # message instead of a TypeError in the comparison below.
    assert best_nll is not None, "No mean_nll values were recorded."
    assert (
        best_nll < 0.01
    ), "Model should be practically perfect on this dataset."
