import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental.sparse import BCOO
from jax.experimental import sparse
from typing import Sequence
from functools import partial
from jax import lax


@partial(jax.jit, static_argnames=["scale_factor"])
def sum_pool(x: jax.Array, scale_factor: int):
    """
    Sum ``x`` over non-overlapping ``scale_factor``-wide windows on every axis.

    Window extent and stride both equal ``scale_factor`` along each axis, and the
    padding mode is ``"valid"``: trailing elements that do not fill a complete
    window are discarded, so axis ``i`` shrinks to ``x.shape[i] // scale_factor``.
    """
    dims = tuple(scale_factor for _ in range(x.ndim))
    return lax.reduce_window(
        operand=x,
        init_value=0.0,
        computation=lax.add,
        window_dimensions=dims,
        window_strides=dims,
        padding="valid",
    )


@partial(jax.jit, static_argnames=["scale_factor"])
def min_pool(x: jax.Array, scale_factor: int):
    """
    Take the minimum of ``x`` over non-overlapping ``scale_factor``-wide windows.

    Window extent and stride both equal ``scale_factor`` along every axis, and the
    padding mode is ``"valid"``, so axis ``i`` shrinks to
    ``x.shape[i] // scale_factor`` (a trailing partial window is dropped).

    Raises:
        TypeError: if ``x`` is neither a floating-point nor an integer array.
    """
    import numpy as np  # builds a concrete (non-traced) init value in x's exact dtype

    window = (scale_factor,) * x.ndim
    # The init value must be the identity of `min` in x's own dtype. A mismatched
    # dtype is rejected by reduce_window. A finite sentinel such as float32 max
    # would turn an all-+inf window into that finite value.
    if jnp.issubdtype(x.dtype, jnp.floating):
        init = np.array(np.inf, dtype=x.dtype)
    elif jnp.issubdtype(x.dtype, jnp.integer):
        init = np.array(np.iinfo(x.dtype).max, dtype=x.dtype)
    else:
        raise TypeError(
            f"min_pool expects a floating-point or integer array, got {x.dtype}"
        )
    return lax.reduce_window(x, init, lax.min, window, window, "valid")


@jit
def cost_matrix(coords: jax.Array, p: float) -> jax.Array:
    """
    Return the `p`-ground cost matrix between the points in `coords`.

    Entry ``(i, j)`` is ``dist(coords[i], coords[j]) ** p``, where ``dist`` is the
    Euclidean distance computed by ``cdist`` (rows of ``coords`` are points).
    ``p`` is a traced (non-static) argument of this jitted function.
    """
    pairwise = cdist(coords, coords)
    return pairwise**p


@partial(jax.jit, static_argnames=["shape"])
def coordiantes(shape: Sequence[int]) -> jax.Array:
    """
    List the integer index coordinates of every element of a tensor of shape `shape`.

    Returns a float32 array of shape ``(prod(shape), len(shape))``. Rows are
    enumerated in row-major order (``indexing="ij"``, last axis fastest).
    """
    grids = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing="ij")
    columns = [grid.reshape(-1) for grid in grids]
    return jnp.stack(columns, axis=-1).astype(jnp.float32)


@partial(jax.jit, static_argnames=["shape", "scale_factor"])
def scaled_coordinates(shape: Sequence[int], scale_factor: int) -> jax.Array:
    """
    Compute the coordinates of a tensor of shape `shape` diluted by `scale_factor`.

    Each coarse cell covers `scale_factor` consecutive indices per axis. Its
    coordinate is the window centre ``(scale_factor - 1) / 2 + k * scale_factor``
    in the original index space. Only complete windows are kept
    (``shape[i] // scale_factor`` per axis), matching the ``"valid"`` padding of
    ``sum_pool`` / ``min_pool``, so the result lines up with a pooled tensor.

    Returns:
        float32 array of shape ``(prod(shape[i] // scale_factor), len(shape))``,
        rows in row-major (``indexing="ij"``) order.

    Raises:
        ValueError: if ``scale_factor < 1``.
    """
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
    offset = (scale_factor - 1) / 2
    # Use an integer count of complete windows rather than a float arange up to N,
    # which would also emit a centre for a trailing partial window.
    axes = [
        offset + scale_factor * jnp.arange(n // scale_factor, dtype=jnp.float32)
        for n in shape
    ]
    grids = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack([g.reshape(-1) for g in grids], axis=-1).astype(jnp.float32)


@partial(jax.jit, static_argnames=["metric"])
def cdist(X: jax.Array, Y: jax.Array, metric="euclidean") -> jax.Array:
    """
    Compute distance between each pair of the two collections of inputs.

    Args:
        X: 2-D array; each row is one point.
        Y: 2-D array with the same number of columns as ``X``; each row is one point.
        metric: ``"euclidean"`` (default) or ``"sqeuclidean"``. Declared static so
            it can be passed under ``jit`` (strings are not valid traced values).

    Returns:
        Array of shape ``(X.shape[0], Y.shape[0])`` whose ``(i, j)`` entry is the
        distance between ``X[i]`` and ``Y[j]``.

    Raises:
        ValueError: if ``metric`` is not supported.
    """
    if metric not in ("euclidean", "sqeuclidean"):
        raise ValueError(f"Unsupported metric: {metric}")

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, for all pairs via broadcasting.
    XX = jnp.sum(X * X, axis=1, keepdims=True)
    YY = jnp.sum(Y * Y, axis=1)
    XY = jnp.dot(X, Y.T)

    # Clamp tiny negatives caused by floating-point cancellation.
    squared_distances = jnp.maximum(XX + YY - 2 * XY, 0)

    if metric == "sqeuclidean":
        return squared_distances

    # sqrt has an infinite derivative at 0, which makes gradients NaN wherever two
    # points coincide (e.g. the diagonal of cdist(X, X)). The double `where` keeps
    # the forward values identical while routing a zero gradient through those entries.
    positive = squared_distances > 0
    safe = jnp.where(positive, squared_distances, 1.0)
    return jnp.where(positive, jnp.sqrt(safe), 0.0)


@jit
def diag(v: jax.Array) -> BCOO:
    """
    Build a sparse BCOO matrix with the entries of ``v`` on its main diagonal.

    A sparse identity of size ``v.size`` is multiplied by ``v``. Broadcasting scales
    column ``j`` by ``v[j]``, so for 1-D ``v`` the result holds ``v[j]`` at ``(j, j)``.
    """
    n = v.size
    identity = sparse.eye(n)
    return identity * v


def view_as_blocks(arr, block_shape):
    """
    Divide a JAX array into non-overlapping blocks of the specified shape.

    For an input of shape ``(s_0, ..., s_{n-1})`` and ``block_shape``
    ``(b_0, ..., b_{n-1})`` the result has shape
    ``(s_0 // b_0, ..., s_{n-1} // b_{n-1}, b_0, ..., b_{n-1})``, with
    ``out[i_0, ..., i_{n-1}, k_0, ..., k_{n-1}] == arr[i_0*b_0 + k_0, ..., i_{n-1}*b_{n-1} + k_{n-1}]``
    (the same layout as ``skimage.util.view_as_blocks``). JAX has no strided
    views, so this returns a new array built by reshape + transpose.

    Parameters:
    arr (array-like): Input array.
    block_shape (sequence of int): Shape of each block; one entry per axis of ``arr``.

    Returns:
    jax.Array: The blocks, indexed first by block position and then by offset
    within the block.

    Raises:
    ValueError: if ``block_shape`` has the wrong length or non-positive entries,
    or does not evenly divide ``arr.shape``.
    """
    arr = jnp.asarray(arr)
    block_shape = tuple(int(b) for b in block_shape)

    if len(block_shape) != arr.ndim:
        raise ValueError("block_shape must have one entry per array dimension.")
    if any(b <= 0 for b in block_shape):
        raise ValueError("block_shape entries must be positive.")
    # Ensure the input array shape is divisible by the block shape
    if any(s % b != 0 for s, b in zip(arr.shape, block_shape)):
        raise ValueError("Array shape must be divisible by block shape.")

    ndim = arr.ndim
    grid_shape = tuple(s // b for s, b in zip(arr.shape, block_shape))
    # Split each axis s_i into (s_i // b_i, b_i): shape (g_0, b_0, g_1, b_1, ...).
    split_shape = tuple(d for pair in zip(grid_shape, block_shape) for d in pair)
    # Move the grid axes (even positions) in front of the block axes (odd positions).
    order = tuple(range(0, 2 * ndim, 2)) + tuple(range(1, 2 * ndim, 2))
    return jnp.transpose(jnp.reshape(arr, split_shape), order)
