from pathlib import Path

import pytest
import torch

from src.datamodules.distribution_datamodule import (
    SKLearnDataModule,
    TorchDynDataModule,
    TwoDimDataModule,
)


@pytest.mark.parametrize("batch_size", [32, 128])
@pytest.mark.parametrize("train_val_test_split", [400, [1000, 100, 100]])
@pytest.mark.parametrize(
    "datamodule,system",
    [
        (SKLearnDataModule, "scurve"),
        (SKLearnDataModule, "moons"),
        (TorchDynDataModule, "gaussians"),
    ],
)
def test_single_datamodule(batch_size, train_val_test_split, datamodule, system):
    """Check the splits, loaders and batch layout of a single-distribution datamodule.

    The datamodule class is instantiated for every combination of batch size,
    split specification and system name, and the first training batch is
    inspected for shape and dtype.
    """
    module = datamodule(
        system=system,
        train_val_test_split=train_val_test_split,
        batch_size=batch_size,
    )

    # All three dataset splits must be populated after construction.
    assert module.data_train is not None
    assert module.data_val is not None
    assert module.data_test is not None

    # Each loader must be truthy.
    assert module.train_dataloader()
    assert module.val_dataloader()
    assert module.test_dataloader()

    # The same total is expected for both split settings; a scalar split is
    # presumably applied to each of the three splits -- confirm against the
    # datamodule implementation.
    total_points = sum(
        len(split)
        for split in (module.data_train, module.data_val, module.data_test)
    )
    assert total_points == 1200

    first_batch = next(iter(module.train_dataloader()))
    assert first_batch.dim() == 2
    assert first_batch.shape[0] == batch_size
    assert first_batch.shape[-1] == 2
    assert module.dim == 2
    assert first_batch.dtype == torch.float32


@pytest.mark.parametrize("batch_size", [32, 128])
@pytest.mark.parametrize("train_val_test_split", [300, [200, 50, 50]])
@pytest.mark.parametrize(
    "datamodule,system",
    [
        (TwoDimDataModule, "moon-8gaussians"),
    ],
)
def test_trajectory_datamodule(batch_size, train_val_test_split, datamodule, system):
    """Check the per-timepoint batch layout of a trajectory datamodule.

    The first training batch must have two entries, and the entry for every
    index in ``timepoint_data`` must be a ``(batch_size, 2)`` float32 tensor.
    """
    module = datamodule(
        system=system,
        train_val_test_split=train_val_test_split,
        batch_size=batch_size,
    )
    # NOTE(review): the loader truthiness check below is disabled in this test;
    # the reason is not recorded here.
    # assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()

    first_batch = next(iter(module.train_dataloader()))
    assert len(first_batch) == 2

    num_timepoints = len(module.timepoint_data)
    for index in range(num_timepoints):
        samples = first_batch[index]
        assert samples.dim() == 2
        assert samples.shape[0] == batch_size
        assert samples.shape[-1] == 2
        assert samples.dtype == torch.float32

    assert module.dim == 2
