import jax
from jax.sharding import PartitionSpec as P, NamedSharding
from jax.tree_util import tree_map

import collections
import itertools
import tensorflow as tf


def prepare_tf_data(xs):
    """Convert every leaf of a batch pytree from a TF tensor to a NumPy value.

    Args:
        xs: A pytree (dict, tuple, list, ...) whose leaves expose ``_numpy()``,
            e.g. one element produced by iterating a ``tf.data.Dataset``.

    Returns:
        A pytree with the same structure as ``xs`` whose leaves are the
        results of calling ``_numpy()`` on the corresponding input leaves.
    """
    # ``_numpy()`` is a protected TensorFlow accessor.
    # NOTE(review): presumably preferred over the public ``.numpy()`` to
    # avoid an extra buffer copy; confirm before switching to ``.numpy()``.
    return jax.tree_util.tree_map(lambda tensor: tensor._numpy(), xs)


def prefetch_to_mesh(iterator, size: int):
  
    mesh=jax.make_mesh((jax.local_device_count(),),('batch',))
    
    queue = collections.deque()

    def _prefetch(xs):
        return jax.device_put(xs, NamedSharding(mesh, P('batch')))

    def enqueue(n):  # Enqueues *up to* `n` elements from the iterator.
        for data in itertools.islice(iterator, n):
            queue.append(tree_map(_prefetch, data))

    enqueue(size)  # Fill up the buffer.
    while queue:
        yield queue.popleft()
        enqueue(1)



def create_dataloaders(batch_size, test_batch_size, train_dataset, test_dataset,
                       shuffle_buffer_size=10_000, seed=42):
    """Build endless train/test iterators of batches sharded across local devices.

    Both inputs are wrapped with ``tf.data.Dataset.from_tensor_slices``,
    cached, repeated indefinitely and batched. The training stream is also
    shuffled, and each TF batch is converted to NumPy and prefetched onto
    the device mesh (2 items in flight).

    Args:
        batch_size: Training batch size; must be divisible by
            ``jax.local_device_count()`` because batches are split along
            their leading axis across local devices.
        test_batch_size: Evaluation batch size; same divisibility rule.
        train_dataset: Anything accepted by ``from_tensor_slices``
            (presumably a tuple/dict of arrays; confirm with callers).
        test_dataset: Same as ``train_dataset``, for evaluation.
        shuffle_buffer_size: Size of the training shuffle buffer. The
            default 10000 equals the previously hard-coded value.
        seed: Shuffle seed for the training stream.

    Returns:
        ``(train_iter, test_iter)``: two infinite generators of pytrees whose
        leaves have been placed on the device mesh.

    Raises:
        ValueError: If a batch size is not divisible by the local device count.
    """
    n_devices = jax.local_device_count()
    for arg_name, size in (("batch_size", batch_size), ("test_batch_size", test_batch_size)):
        # Fail early with a clear message instead of a sharding error on
        # the first device transfer.
        if size % n_devices:
            raise ValueError(
                f"{arg_name}={size} must be divisible by the local device count ({n_devices})"
            )

    train_ds = tf.data.Dataset.from_tensor_slices(train_dataset)
    test_ds = tf.data.Dataset.from_tensor_slices(test_dataset)

    # ``repeat()`` precedes ``batch()``, so every batch is full-sized and the
    # streams never end; consumers must decide how many steps to draw.
    train_ds = (
        train_ds.cache()
        .shuffle(shuffle_buffer_size, reshuffle_each_iteration=True, seed=seed)
        .repeat()
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_ds = (
        test_ds.cache()
        .repeat()
        .batch(test_batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    train_iter = prefetch_to_mesh(map(prepare_tf_data, train_ds), 2)
    test_iter = prefetch_to_mesh(map(prepare_tf_data, test_ds), 2)

    return train_iter, test_iter