import pytest
import numpy as np
import random as rand
import torch
import torchvision as tv
import kdai.train
from pathlib import Path

# Scratch directory for downloaded datasets: "_tmp_data" located next to this
# file (not the process's working directory). Created at import time.
TMP_DIR = Path(__file__).parent.joinpath("_tmp_data")
TMP_DIR.mkdir(exist_ok=True)


@pytest.fixture
def out_root():
    """Provide the relative output-root path string used by tests.

    Presumably where tests place generated artifacts -- confirm with callers.
    """
    root = "out/kdai-test"
    return root


@pytest.fixture
def seed_random():
    """Seed Python's ``random``, NumPy's global RNG and torch with 123."""
    seed = 123
    # Same order as before: stdlib, NumPy legacy global state, then torch.
    for seed_fn in (rand.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


@pytest.fixture
def np_rng():
    """Return a fresh NumPy ``Generator`` seeded with 123."""
    rng = np.random.default_rng(seed=123)
    return rng


@pytest.fixture
def mnist_ds():
    """Load MNIST as ``(train_ds, val_ds, test_ds)``.

    The 60k-image MNIST training set is split deterministically (generator
    seeded with 123) into 50k train / 10k validation subsets; the 10k-image
    test set is returned as-is. ``download=True`` fetches the files into
    ``TMP_DIR`` when they are not already there.

    Images pass through ``ToTensor``, so each sample is a float tensor of
    shape (1, 28, 28) with values in [0, 1].

    Returns:
        tuple: ``(train_ds, val_ds, test_ds)``; the first two are the
        ``Subset`` objects produced by ``random_split``.
    """
    img_transform = tv.transforms.Compose([tv.transforms.ToTensor()])
    train_ds = tv.datasets.MNIST(
        root=TMP_DIR, train=True, download=True, transform=img_transform
    )
    test_ds = tv.datasets.MNIST(
        root=TMP_DIR, train=False, download=True, transform=img_transform
    )
    assert len(train_ds) == 60000 and len(test_ds) == 10000

    # Split off a validation set with a fixed seed so the split is reproducible.
    train_size, val_size = 50000, 10000
    train_ds, val_ds = torch.utils.data.random_split(
        train_ds,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(123),
    )
    assert len(train_ds) == train_size and len(val_ds) == val_size

    # Sanity-check a few samples.
    for i in range(10):
        img, label = train_ds[i]
        assert img.shape == (1, 28, 28)
        assert 0 <= label < 10
        # ToTensor rescales uint8 pixels [0, 255] to floats in [0.0, 1.0].
        # An upper bound of 255 could never fail and would not detect a
        # missing rescale, so check the actual [0, 1] range.
        assert torch.all(img >= 0) and torch.all(img <= 1)
    return train_ds, val_ds, test_ds


@pytest.fixture
def mnist_ds_mgr(mnist_ds):
    """Wrap the MNIST splits in a ``kdai.train.BasicDatasetManager``.

    Summary statistics (mean, standard deviation, min, max) are computed over
    the training split only, so validation images never influence them. They
    are passed to the manager as ``train_ds_attrs``.
    """
    train_ds, val_ds, test_ds = mnist_ds
    # random_split() returns Subsets that only re-index the parent dataset, so
    # reading the training images through them would mean iterating every
    # sample. Instead, take the raw image tensor from the parent dataset and
    # index it with the Subset's indices -- this reads only the training
    # portion and never the validation portion.
    orig_train_ds = train_ds.dataset
    train_imgs = orig_train_ds.data[train_ds.indices].float()
    assert train_imgs.shape == (50000, 28, 28)
    # NOTE(review): `.data` holds raw pixel values (presumably uint8, 0-255),
    # whereas the samples served through ToTensor are scaled to [0, 1]. These
    # stats are therefore on a different scale than the images themselves --
    # confirm how BasicDatasetManager uses them before normalizing with them.
    img_mean = torch.mean(train_imgs)
    img_sd = torch.std(train_imgs)
    img_min = torch.min(train_imgs)
    img_max = torch.max(train_imgs)
    return kdai.train.BasicDatasetManager(
        train_ds,
        val_ds,
        test_ds,
        train_ds_attrs={
            "mean": img_mean,
            "sd": img_sd,
            "min": img_min,
            "max": img_max,
        },
    )


@pytest.fixture
def fashion_mnist_ds():
    """Load Fashion-MNIST as ``(train_ds, val_ds, test_ds)``.

    The 60k-image training set is split deterministically (generator seeded
    with 123) into 50k train / 10k validation subsets; the 10k-image test set
    is returned as-is. ``download=True`` fetches the files into ``TMP_DIR``
    when they are not already there.

    Images pass through ``ToTensor``, so each sample is a float tensor of
    shape (1, 28, 28) with values in [0, 1].

    Returns:
        tuple: ``(train_ds, val_ds, test_ds)``; the first two are the
        ``Subset`` objects produced by ``random_split``.
    """
    img_transform = tv.transforms.Compose([tv.transforms.ToTensor()])
    train_ds = tv.datasets.FashionMNIST(
        root=TMP_DIR, train=True, download=True, transform=img_transform
    )
    test_ds = tv.datasets.FashionMNIST(
        root=TMP_DIR, train=False, download=True, transform=img_transform
    )
    assert len(train_ds) == 60000 and len(test_ds) == 10000

    # Split off a validation set with a fixed seed so the split is reproducible.
    train_size, val_size = 50000, 10000
    train_ds, val_ds = torch.utils.data.random_split(
        train_ds,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(123),
    )
    assert len(train_ds) == train_size and len(val_ds) == val_size

    # Sanity-check a few samples.
    for i in range(10):
        img, label = train_ds[i]
        assert img.shape == (1, 28, 28)
        assert 0 <= label < 10
        # ToTensor rescales uint8 pixels [0, 255] to floats in [0.0, 1.0].
        # An upper bound of 255 could never fail and would not detect a
        # missing rescale, so check the actual [0, 1] range.
        assert torch.all(img >= 0) and torch.all(img <= 1)
    return train_ds, val_ds, test_ds


@pytest.fixture
def fashion_mnist_ds_mgr(fashion_mnist_ds):
    """Wrap the Fashion-MNIST splits in a ``kdai.train.BasicDatasetManager``.

    Summary statistics (mean, standard deviation, min, max) are computed over
    the training split only, so validation images never influence them. They
    are passed to the manager as ``train_ds_attrs``.
    """
    train_ds, val_ds, test_ds = fashion_mnist_ds
    # random_split() returns Subsets that only re-index the parent dataset;
    # pull the raw image tensor from the parent and index it with the Subset's
    # indices so only the training portion is read.
    orig_train_ds = train_ds.dataset
    train_imgs = orig_train_ds.data[train_ds.indices].float()
    assert train_imgs.shape == (50000, 28, 28)
    # NOTE(review): `.data` holds raw pixel values (presumably uint8, 0-255),
    # whereas the samples served through ToTensor are scaled to [0, 1]. These
    # stats are therefore on a different scale than the images themselves --
    # confirm how BasicDatasetManager uses them before normalizing with them.
    img_mean = torch.mean(train_imgs)
    img_sd = torch.std(train_imgs)
    img_min = torch.min(train_imgs)
    img_max = torch.max(train_imgs)
    return kdai.train.BasicDatasetManager(
        train_ds,
        val_ds,
        test_ds,
        train_ds_attrs={
            "mean": img_mean,
            "sd": img_sd,
            "min": img_min,
            "max": img_max,
        },
    )
