# Smoke tests for task sampling: `sample_task_batch` and `OmniglotDataset.sample_tasks`.
from pathlib import Path

import jax
import pytest

from maml.data import sample_task_batch, OmniglotDataset


# Location of the resized Omniglot images used by the dataset test below.
# This is a relative path, so it is resolved against the current working
# directory (where pytest is launched), not against this test file.
default_omniglot_path = Path("omniglot_resized") / "all_images"


def test_sample_task_batch_call():
    """Smoke test: ``sample_task_batch`` runs and its result can be star-unpacked.

    No values are inspected. The trailing element of the result is bound but
    unused; presumably it is an updated PRNG key (TODO confirm against
    ``maml.data``). The meaning of the positional arguments ``25`` and ``10``
    is defined by ``sample_task_batch`` and is not asserted here.
    """
    rng = jax.random.PRNGKey(0)
    batch = sample_task_batch(rng, 25, 10)
    # Star-unpacking requires an iterable with at least one element.
    *_, rng = batch


@pytest.mark.skipif(not default_omniglot_path.exists(), reason="Omniglot dataset not found")
@pytest.mark.parametrize("rotation_augmentation", [True, False])
def test_omniglot_dataset(rotation_augmentation):
    """Check the shapes of the first two arrays returned by ``sample_tasks``.

    Runs once with and once without rotation augmentation, and is skipped when
    ``default_omniglot_path`` is missing. Only the first two returned arrays
    (named here as the training inputs/outputs) are checked. The last two are
    discarded.
    """
    num_tasks = 25
    omniglot = OmniglotDataset(rotation_augmentation=rotation_augmentation)
    omniglot.load_characters_labels()
    sampled = omniglot.sample_tasks(jax.random.PRNGKey(0), num_tasks)
    train_inputs, train_outputs, _, _ = sampled
    # NOTE(review): 50 examples per task and a label width of 5 presumably
    # follow from the sampler's way/shot configuration; confirm in maml.data.
    assert train_inputs.shape == (num_tasks, 50, 28, 28, 1)
    assert train_outputs.shape == (num_tasks, 50, 5)
