"""Test the playback dataloader
"""

from pathlib import Path

from konductor.data._pytorch.dataloader import DataloaderV1Config
from konductor.init import ExperimentInitConfig

import src  # load things into registry
from src.dataset.chasing_targets import (
    GymPredictConfig,
    OccupancyParams,
    SampleRange,
    Split,
)


def test_config_loading(tmp_path):
    """Check that base_cfg.yml parses and its first model is the motion-perceiver.

    ``tmp_path`` is the pytest temporary-directory fixture. It is passed as the
    first argument to ``ExperimentInitConfig.from_config``.
    """
    # base_cfg.yml lives two directory levels above this test file.
    repo_dir = Path(__file__).parent.parent
    config_path = repo_dir / "base_cfg.yml"

    exp_cfg = ExperimentInitConfig.from_config(tmp_path, config_path)

    first_model = exp_cfg.model[0]
    assert first_model.type == "motion-perceiver"


def test_targets_loading():
    """Without an occupancy config, a training batch has no "agents_occ" key.

    The same dataloader settings are used for both the train and val loaders.
    Only the first batch of the TRAIN split is inspected.
    """
    loader_cfg = DataloaderV1Config(workers=4, batch_size=8)
    dataset_cfg = GymPredictConfig(train_loader=loader_cfg, val_loader=loader_cfg)

    train_dataloader = dataset_cfg.get_instance(Split.TRAIN)
    first_batch = next(iter(train_dataloader))

    assert "agents_occ" not in first_batch


def test_occupancy_loading():
    """With an occupancy config, a training batch includes an "agents_occ" key.

    The OccupancyParams arguments are passed positionally. Their meaning is
    defined in src.dataset.chasing_targets (not visible here).
    """
    loader_cfg = DataloaderV1Config(workers=4, batch_size=8)
    occupancy_cfg = OccupancyParams(256, [0, 5, 10, 50], 3, SampleRange(0, 49))
    dataset_cfg = GymPredictConfig(
        occupancy=occupancy_cfg,
        train_loader=loader_cfg,
        val_loader=loader_cfg,
    )

    train_dataloader = dataset_cfg.get_instance(Split.TRAIN)
    first_batch = next(iter(train_dataloader))

    assert "agents_occ" in first_batch
