from functools import partial
from typing import List
import einops
import flax

import jax
import numpy as np
import tensorflow as tf


def get_devices(device_list: List[jax.Device], devices: List[int]) -> List[jax.Device]:
    """Select the devices from ``device_list`` whose ``id`` is in ``devices``.

    Args:
        device_list: Candidate devices to pick from.
        devices: Ids of the devices to keep.

    Returns:
        The matching devices, in the order they appear in ``device_list``
        (not in the order of ``devices``).

    Raises:
        ValueError: If any requested id does not match a device in
            ``device_list``.
    """
    requested_ids = set(devices)
    available_ids = {device.id for device in device_list}

    # Validate the *requested* ids: checking the filtered devices against the
    # list they were filtered from can never fail, which would let unknown
    # ids be dropped silently.
    missing_ids = sorted(requested_ids - available_ids)
    if missing_ids:
        raise ValueError(f"Device ids {missing_ids} not in original device list")

    return [device for device in device_list if device.id in requested_ids]


def split_to_devices(x, num_devices: int):
    """Split the leading axis of every leaf in ``x`` across ``num_devices``.

    Each leaf is rearranged from ``(d * b, ...)`` to ``(d, b, ...)`` with
    ``d == num_devices``.

    Args:
        x: A pytree whose leaves are arrays or ``tf.Tensor`` objects.
        num_devices: Number of chunks to split the leading axis into.

    Returns:
        A pytree with the same structure as ``x`` and reshaped leaves.
    """

    def split_leaf(leaf):
        if isinstance(leaf, tf.Tensor):
            # NOTE(review): `_numpy` is a private TensorFlow API, presumably
            # used to skip the copy made by the public `.numpy()` — confirm.
            leaf = leaf._numpy()
        return einops.rearrange(leaf, "(d b) ... -> d b ...", d=num_devices)

    # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is available in every JAX version.
    return jax.tree_util.tree_map(split_leaf, x)


def split_and_prefetch(train_dataset: tf.data.Dataset, device_list: List[jax.Device]):
    """Split each dataset element per device and prefetch it onto the devices.

    Args:
        train_dataset: Dataset to iterate; each element is passed through
            ``split_to_devices`` with one chunk per entry of ``device_list``.
        device_list: Devices handed to ``flax.jax_utils.prefetch_to_device``.

    Returns:
        The result of ``flax.jax_utils.prefetch_to_device`` called with a
        prefetch size of 3.
    """
    device_count = len(device_list)
    shard_batch = partial(split_to_devices, num_devices=device_count)
    # `map` stays lazy, so elements are only split as they are consumed.
    sharded_batches = map(shard_batch, train_dataset)
    return flax.jax_utils.prefetch_to_device(
        sharded_batches,
        size=3,
        devices=device_list,
    )
