import tensorflow_datasets as tfds
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

def download_and_prepare_mnist(limit=10):
    """
    Downloads the MNIST dataset using TensorFlow Datasets and returns a limited number of samples
    from the training and test data as JAX numpy arrays.

    Args:
        limit (int, optional): Number of samples to read from each split. Default is 10.
            If None, every sample of the split is read (same convention as
            ``download_and_filter_mnist``).

    Returns:
        train_images (jax.numpy.array): Training images as a JAX numpy array.
        train_labels (jax.numpy.array): Training labels as a JAX numpy array.
        test_images (jax.numpy.array): Test images as a JAX numpy array.
        test_labels (jax.numpy.array): Test labels as a JAX numpy array.
    """
    # Load MNIST from TensorFlow Datasets. The dataset metadata is not needed here,
    # so `with_info` is left at its default and only the two splits are returned.
    train_data, test_data = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

    def _process_tfds_dataset(tf_dataset, limit):
        """
        Converts a TensorFlow Dataset into JAX numpy arrays, limited to a specific number of samples.

        Args:
            tf_dataset (tf.data.Dataset): TensorFlow Dataset to convert.
            limit (int or None): Number of samples to process; None processes all samples.

        Returns:
            images (jax.numpy.array): Images as JAX numpy arrays.
            labels (jax.numpy.array): Labels as JAX numpy arrays.
        """
        images, labels = [], []
        for i, (image, label) in enumerate(tfds.as_numpy(tf_dataset)):
            # Stop as soon as `limit` samples were collected; a None limit never stops early.
            if limit is not None and i >= limit:
                break
            images.append(image)
            labels.append(label)

        # Convert the collected samples to JAX numpy arrays
        return jnp.array(images), jnp.array(labels)

    # Process training and test datasets with the same limit
    train_images, train_labels = _process_tfds_dataset(train_data, limit)
    test_images, test_labels = _process_tfds_dataset(test_data, limit)

    return train_images, train_labels, test_images, test_labels

import tensorflow_datasets as tfds
import jax.numpy as jnp

def download_and_filter_mnist(limit=None):
    """
    Downloads the MNIST dataset using TensorFlow Datasets, filters data points with labels 0 and 1,
    and returns the filtered dataset as JAX numpy arrays.

    Args:
        limit (int, optional): Maximum number of label-0/1 samples to keep per split.
            If None, uses the full dataset. A limit of 0 (or less) keeps no samples.

    Returns:
        train_images (jax.numpy.array): Training images as a JAX numpy array (filtered for labels 0 and 1).
        train_labels (jax.numpy.array): Training labels as a JAX numpy array (filtered for labels 0 and 1).
        test_images (jax.numpy.array): Test images as a JAX numpy array (filtered for labels 0 and 1).
        test_labels (jax.numpy.array): Test labels as a JAX numpy array (filtered for labels 0 and 1).
    """
    # Load MNIST from TensorFlow Datasets. The dataset metadata is not needed here,
    # so `with_info` is left at its default and only the two splits are returned.
    train_data, test_data = tfds.load('mnist', split=['train', 'test'], as_supervised=True)

    def _filter_and_process_tfds_dataset(tf_dataset, limit=None):
        """
        Filters a TensorFlow Dataset to include only labels 0 and 1, and converts it to JAX numpy arrays.

        Args:
            tf_dataset (tf.data.Dataset): TensorFlow Dataset to filter and convert.
            limit (int, optional): Number of samples to keep. If None, processes all samples.

        Returns:
            images (jax.numpy.array): Filtered images as JAX numpy arrays.
            labels (jax.numpy.array): Filtered labels as JAX numpy arrays.
        """
        images, labels = [], []

        # Check the limit before reading anything: testing it only after an append
        # would still return one sample for limit=0.
        if limit is None or limit > 0:
            for image, label in tfds.as_numpy(tf_dataset):
                if label not in (0, 1):
                    continue
                images.append(image)
                labels.append(label)
                if limit is not None and len(labels) >= limit:
                    break

        # Convert the collected samples to JAX numpy arrays
        return jnp.array(images), jnp.array(labels)

    # Filter and process training and test datasets with the same limit
    train_images, train_labels = _filter_and_process_tfds_dataset(train_data, limit)
    test_images, test_labels = _filter_and_process_tfds_dataset(test_data, limit)

    return train_images, train_labels, test_images, test_labels

def _keep_two_classes(images, labels, negative, positive):
    """Keep only samples labelled `negative` or `positive` and relabel them as 0 and 1."""
    keep = np.isin(labels, (negative, positive))
    kept_labels = labels[keep]
    # np.where also works when nothing was selected, whereas np.vectorize without
    # `otypes` raises on size-0 input.
    return images[keep], np.where(kept_labels == positive, 1, 0)


def load_binary_cifar(classes_to_include=(0, 1)):
    """
    Loads and filters the CIFAR-10 dataset to include only the specified classes.

    Args:
        classes_to_include (tuple): Tuple of two class indices to include (e.g., (0, 1)).
            The first class is mapped to label 0 and the second to label 1.

    Returns:
        (numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray):
            Train images, train labels, test images, test labels, with binary labels (0 or 1).

    Raises:
        ValueError: If `classes_to_include` does not contain exactly two entries.
    """
    # Any other length cannot be mapped onto the two binary labels.
    if len(classes_to_include) != 2:
        raise ValueError(
            f"classes_to_include must contain exactly two class indices, got {classes_to_include!r}"
        )
    negative, positive = classes_to_include

    # Load CIFAR-10 dataset
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # Flatten labels to 1-D so they can serve as a boolean mask over the images
    y_train = y_train.flatten()
    y_test = y_test.flatten()

    # Filter both splits for the requested classes and convert labels to binary (0 or 1)
    x_train_binary, y_train_binary = _keep_two_classes(x_train, y_train, negative, positive)
    x_test_binary, y_test_binary = _keep_two_classes(x_test, y_test, negative, positive)

    return x_train_binary, y_train_binary, x_test_binary, y_test_binary

# Example usage
if __name__ == "__main__":
    # Fetch the default small MNIST sample and report the shape of each returned array.
    mnist_sample = download_and_prepare_mnist()
    train_images, train_labels, test_images, test_labels = mnist_sample

    print(f"Train Images Shape: {train_images.shape}")
    print(f"Train Labels Shape: {train_labels.shape}")

    print(f"Test Images Shape: {test_images.shape}")
    print(f"Test Labels Shape: {test_labels.shape}")