import time
from functools import partial
import jax
import jax.numpy as jnp
import cupy
import numpy as np

def make_io_stream():
    """Create a dedicated CuPy stream for host/device copies.

    The stream is created with ``non_blocking=True``, so it does not
    implicitly synchronize with the CUDA null (default) stream.
    """
    io_stream = cupy.cuda.Stream(non_blocking=True)
    return io_stream

def is_gpu(x):
    """Return True when the first device reported by ``x.devices()`` is a GPU."""
    first_device = tuple(x.devices())[0]
    return first_device.platform == 'gpu'

def is_cpu(x):
    """Return True when the first device reported by ``x.devices()`` is a CPU."""
    first_device = tuple(x.devices())[0]
    return first_device.platform == 'cpu'

def _dlpack_gpu2cpu(x, out=None, stream=None):
    if isinstance(x, int) or x.dtype == jax.float0:
        return x

    if isinstance(x, np.ndarray) or is_cpu(x):
        return x

    assert stream is not None

    x = jax.dlpack.to_dlpack(x, copy=False)
    x = cupy.from_dlpack(x)

    if out is None:
        out = cupy.cuda.alloc_pinned_memory(x.nbytes)
        out = np.frombuffer(out, x.dtype, x.size).reshape(x.shape)

    assert out.shape == x.shape
    assert out.dtype == x.dtype

    x = cupy.asnumpy(x, order='A', blocking=False, out=out, stream=stream)
    return x

def _dlpack_cpu2gpu(x, stream):
    raise NotImplementedError
    if isinstance(x, int) or x.dtype == jax.float0:
        return x

    if not isinstance(x, np.ndarray) and is_gpu(x):
        return x

    with stream:
        x = cupy.asarray(x, order='K', blocking=False)
        x = x.toDlpack()
        x = jax.dlpack.from_dlpack(x, copy=False)

    return x

def _to_cupy(x, buffer=None):
    if isinstance(x, int) or x.dtype == jax.float0:
        return x

    if not isinstance(x, np.ndarray) and (isinstance(x, cupy.ndarray) or is_gpu(x)):
        return x

    if buffer is None:
        x = cupy.asarray(x, order='K', blocking=False)
    else:
        buffer.set(x)
        x = buffer

    return x

def _to_dlpack(x):
    """Export a CuPy array as a DLPack capsule via ``toDlpack()``.

    Any leaf that is not a ``cupy.ndarray`` is returned unchanged.
    """
    if isinstance(x, cupy.ndarray):
        return x.toDlpack()
    return x

def _to_jax(x):
    if isinstance(x, int) or (hasattr(x, 'dtype') and x.dtype == jax.float0):
        return x

    if hasattr(x, 'devices') and is_gpu(x):
        return x

    return jax.dlpack.from_dlpack(x, copy=False)

def dlpack_cpu2gpu(x_numpy, stream=None, replace_buffers=None):
    """Pytree-level host-to-device transfer through DLPack; currently disabled.

    Always raises ``NotImplementedError``. An earlier implementation:

    * mapped ``_to_cupy`` over the NumPy leaves (optionally into
      ``replace_buffers``) under ``stream``;
    * exported the CuPy arrays with ``_to_dlpack``;
    * wrapped them with ``_to_jax``.

    That body sat behind an unconditional ``raise`` and was unreachable, so
    it has been removed. It carried a "38seconds" timing note, which suggests
    it was disabled for being slow (TODO: confirm with the author).

    Args:
        x_numpy: Pytree of host leaves (unused).
        stream: CuPy stream that would be used (unused).
        replace_buffers: Optional destination buffers (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        'dlpack_cpu2gpu is disabled; use jax.device_put for host-to-device transfers'
    )

def dlpack_gpu2cpu(x, stream, replace_buffers=None):
    """Asynchronously copy every GPU leaf of pytree ``x`` to host memory.

    Each leaf goes through ``_dlpack_gpu2cpu`` with the given ``stream``. When
    ``replace_buffers`` is supplied, it is mapped alongside ``x`` (same tree
    structure) and its leaves are used as the copy destinations.

    Args:
        x: Pytree of arrays, typically GPU-resident JAX arrays.
        stream: CuPy stream on which all copies are enqueued.
        replace_buffers: Optional pytree of pre-allocated host buffers.

    Returns:
        A ``(cpu_x, block_fn)`` pair. ``cpu_x`` mirrors ``x`` but its host
        buffers are only guaranteed to be filled after ``block_fn()`` runs.
        ``block_fn`` synchronizes ``stream`` and returns ``(x, cpu_x)``. The
        closure keeps ``x`` referenced, presumably so the device buffers
        outlive the in-flight copies.
    """
    copy_leaf = partial(_dlpack_gpu2cpu, stream=stream)
    extra_trees = () if replace_buffers is None else (replace_buffers,)
    cpu_x = jax.tree_util.tree_map(copy_leaf, x, *extra_trees)

    def block_fn():
        stream.synchronize()
        return x, cpu_x

    return cpu_x, block_fn
