import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.signal import convolve2d
from time import time
from skimage.filters import gaussian


from utils import flatten, unflatten

def gaussian_2d_kernel(size: int, sigma: float):
    r"""Build a normalized, square 2D Gaussian kernel centred on its middle cell.

    Args:
        size (int): Side length of the kernel; must be odd so that a single
            centre cell exists.
        sigma (float): Standard deviation of the Gaussian, measured in
            kernel cells.

    Returns:
        jnp.ndarray: Array of shape (size, size) whose entries sum to 1.

    Raises:
        AssertionError: If ``size`` is even.
    """
    assert size % 2 == 1, "Kernel size must be odd"

    half = size // 2
    offsets = jnp.arange(-half, half + 1)
    # Squared distance of every cell from the centre, obtained by
    # broadcasting a row of offsets against a column of offsets.
    sq_dist = offsets[None, :] ** 2 + offsets[:, None] ** 2
    weights = jnp.exp(-sq_dist / (2.0 * sigma**2))
    # Normalize so that blurring preserves the overall image intensity.
    return weights / weights.sum()


def gaussian_blur(image: jnp.ndarray, size: int, sigma: float):
    r"""Blur a single 2D image plane with a (size x size) Gaussian kernel.

    The kernel comes from `gaussian_2d_kernel(size, sigma)`. The result has
    the same shape as ``image``; pixels outside the image are treated as
    zeros during the convolution.
    """
    return convolve2d(
        image,
        gaussian_2d_kernel(size, sigma),
        mode='same',
        boundary='fill',
        fillvalue=0,
    )

def gaussian_jax(
    image,
    sigma=1.0,
    size=33,
    ):
    r"""
    Blur the first three channels of an image with a Gaussian filter in JAX.

    Each of ``image[:, :, 0]``, ``image[:, :, 1]`` and ``image[:, :, 2]`` is
    passed independently through `gaussian_blur`, and the three results are
    stacked along a new last axis.

    Note: the blur uses a finite (size x size) kernel and treats pixels
    outside the image as zeros. It is therefore only an approximation of
    filters that use an untruncated kernel or replicate edge pixels
    (e.g. ``mode='nearest'``); the difference is largest near the borders.

    Args:
        image: Array indexed as ``image[row, col, channel]`` with at least
            three channels.
        sigma: Standard deviation of the Gaussian.
        size: Odd side length of the square kernel. Defaults to 33, the
            value that was previously hard-coded.

    Returns:
        jnp.ndarray: The blurred channels stacked along the last axis.
    """
    blurred_channels = [
        gaussian_blur(image[:, :, channel], size=size, sigma=sigma)
        for channel in range(3)
    ]

    return jnp.stack(blurred_channels, axis=-1)



def find_kernel():
    r"""Return the impulse response of skimage's ``gaussian`` for ``sigma=1``.

    A single unit pixel is placed at (10, 10) in an otherwise zero 32x32
    one-channel image, the image is filtered with
    ``skimage.filters.gaussian``, and the 11x11 window centred on that pixel
    is cut out of the result.

    Returns:
        The 11x11 window of the filtered image around the impulse.
    """
    from skimage.filters import gaussian

    row, col = 10, 10
    kernel_size = 11
    half = kernel_size // 2

    impulse = np.zeros(shape=(32, 32, 1))
    impulse[row, col, 0] = 1
    response = gaussian(impulse, sigma=1, channel_axis=-1, preserve_range=True)

    # Inclusive window of `half` pixels on either side of the impulse.
    rows = slice(row - half, row + half + 1)
    cols = slice(col - half, col + half + 1)
    return response[rows, cols, 0]


# 9x9, sigma=1 blur kernel used as the default `kernel` of `corrupt_func`;
# built and transferred to the device once, at import time.
kernel = jax.device_put(gaussian_2d_kernel(size=9, sigma=1))

@jax.vmap
def corrupt_func(x, kernel = kernel):
    r"""Blur one flattened image with ``kernel``, channel by channel.

    Because of the ``jax.vmap`` decorator, callers pass a batch, e.g.
    ``x.shape = (batch_size, 32 * 32 * 3)``, and this body sees one row of
    that batch at a time.

    NOTE(review): `unflatten` / `flatten` come from `utils`; presumably they
    convert between the flat vector and a (32, 32, 3) image -- confirm.
    """
    image = unflatten(x, 32, 32)

    # Convolve each of the three channels independently. 'same' mode with
    # zero fill keeps the spatial size and treats out-of-image pixels as 0.
    blurred = jnp.stack(
        [
            convolve2d(image[:, :, channel], kernel, mode='same', boundary='fill', fillvalue=0)
            for channel in range(3)
        ],
        axis=-1,
    )

    return flatten(blurred)

if __name__ == "__main__":
    # Micro-benchmark: time 100 batched blurs on the CPU and on the GPU.
    key = jax.random.PRNGKey(0)
    # 2 << 10 == 2048 flattened 32x32x3 images.
    x = jax.random.normal(key, shape=(2 << 10, 32 * 32 * 3))
    x = jax.device_put(x)

    for label, backend in (("CPU", "cpu"), ("GPU", "gpu")):
        # Wrap with jit once per backend instead of on every iteration.
        corrupt_jit = jax.jit(corrupt_func, backend=backend)

        # Warm-up call so that tracing/compilation is excluded from the timing.
        corrupt_jit(x).block_until_ready()

        begin = time()
        for _ in range(100):
            # JAX dispatches work asynchronously; block so that the measured
            # interval covers the computation itself, not just the dispatch.
            x_corrupted = corrupt_jit(x).block_until_ready()
        end = time()
        print(f"{label}: {end - begin:.4f} seconds")

    # a = jnp.zeros(shape = (4, 4))
    # a[0, 0] = 1
    # print(a)



# if __name__ == "__main__":
#     import jax
#     import numpy as np
#     import jax.numpy as jnp

#     pulse = jax.device_put(np.random.rand(8000000))

#     def f(pulse):
#         sigTx = jnp.fft.fft(pulse)
#         return sigTx

#     t0 = time.time()
#     a = jax.jit(f, backend='gpu')(pulse)
#     t1 = time.time()
#     print('time cost on gpu:',t1-t0)


#     t0 = time.time()
#     a = jax.jit(f, backend='cpu')(pulse)
#     t1 = time.time()
#     print('time cost on cpu:',t1-t0)