import jax
import jax.numpy as jnp
import jax.random as jrnd

# A pair of arrays, e.g. an (inputs, targets) sample pair such as (x, y).
ArrayPair = tuple[jax.Array, jax.Array]


def square_boundary(key, num_samples: int, dim: int):
    """
    Draws points uniformly from the surface of the unit hypercube [0, 1]^dim.

    Every row starts as a uniform point in the cube. One coordinate per row
    (the "face axis", chosen uniformly among the `dim` axes) is then pinned to
    0 or 1 (chosen uniformly). All 2 * dim faces have equal area, so the
    result is uniform over the whole boundary.

    Args:
        key: JAX PRNGKey for random number generation.
        num_samples: Number of boundary points to sample.
        dim: Dimension of the hypercube.

    Returns:
        inputs: Array of shape (num_samples, dim) containing boundary points.
    """
    k_points, k_face, k_side = jrnd.split(key, 3)
    points = jrnd.uniform(k_points, (num_samples, dim))
    face_axis = jrnd.randint(k_face, (num_samples,), 0, dim)
    face_value = jrnd.randint(k_side, (num_samples,), 0, 2).astype(points.dtype)

    # Boolean (num_samples, dim) mask: True exactly at each row's face axis.
    on_face = jnp.arange(dim)[None, :] == face_axis[:, None]
    return jnp.where(on_face, face_value[:, None], points)


def square_interior(key, num_samples: int, dim: int):
    """
    Draws points uniformly from the unit hypercube in `dim` dimensions.

    Coordinates come from `jax.random.uniform` with its default bounds, i.e.
    the half-open interval [0, 1).

    Args:
        key: JAX PRNGKey for random number generation.
        num_samples: Number of interior points to sample.
        dim: Dimension of the hypercube.

    Returns:
        inputs: Array of shape (num_samples, dim) containing interior points.
    """
    sample_shape = (num_samples, dim)
    return jrnd.uniform(key, shape=sample_shape)


def poisson_boundary(key, num_samples: int, dim: int):
    """
    Samples boundary points and computes boundary values for the Poisson problem.

    The boundary values are the solution u(x) = prod_i sin(pi * x_i) evaluated
    on the boundary of the unit hypercube. On the boundary one coordinate is 0
    or 1, so u is analytically zero there. Numerically the entries are only
    approximately zero, because sin(pi * 1.0) is not exactly 0 in floating
    point.

    Args:
        key: JAX PRNGKey for random number generation.
        num_samples: Number of boundary points to sample.
        dim: Dimension of the hypercube.

    Returns:
        inputs: Array of shape (num_samples, dim) containing boundary points.
        outputs: Array of shape (num_samples, 1) containing boundary values.
    """
    inputs = square_boundary(key, num_samples, dim)
    # Product over the coordinate axis; keepdims=True keeps a trailing
    # singleton axis so targets are column vectors of shape (num_samples, 1).
    outputs = jnp.prod(jnp.sin(inputs * jnp.pi), axis=-1, keepdims=True)
    return inputs, outputs


def poisson_interior(key, num_samples: int, dim: int):
    """
    Samples interior points and computes the right-hand side of the Poisson problem.

    For the solution u(x) = prod_i sin(pi * x_i), each second partial
    derivative d^2u/dx_i^2 equals -pi**2 * u. The Laplacian is therefore
    -dim * pi**2 * u. The returned values are the *negative* Laplacian,
    f = -laplacian(u) = dim * pi**2 * u, i.e. the source term f in
    -laplacian(u) = f.

    Args:
        key: JAX PRNGKey for random number generation.
        num_samples: Number of interior points to sample.
        dim: Dimension of the hypercube.

    Returns:
        inputs: Array of shape (num_samples, dim) containing interior points.
        laplace: Array of shape (num_samples, 1) containing -laplacian(u) at
            each point (positive where u is positive).
    """
    inputs = square_interior(key, num_samples, dim)
    # Solution values at the sampled points, as a (num_samples, 1) column.
    u = jnp.prod(jnp.sin(inputs * jnp.pi), axis=-1, keepdims=True)
    # Sum of `dim` second partials, each -pi**2 * u, with the sign flipped.
    # NOTE(review): this is -laplacian(u), not +laplacian(u); confirm the
    # training residual compares it against -laplacian of the model output.
    laplace = dim * (jnp.pi**2) * u
    return inputs, laplace


def poisson_test_data(key, num_samples: int, dim: int):
    """
    Samples test points and computes the exact solution for the Poisson problem.

    Test points are drawn from the interior sampler, and the targets are the
    solution u(x) = prod_i sin(pi * x_i) at those points.

    Args:
        key: JAX PRNGKey for random number generation.
        num_samples: Number of test points to sample.
        dim: Dimension of the hypercube.

    Returns:
        inputs: Array of shape (num_samples, dim) containing test points.
        outputs: Array of shape (num_samples, 1) containing exact solution values.
    """
    inputs = square_interior(key, num_samples, dim)
    # keepdims=True keeps targets as a (num_samples, 1) column.
    outputs = jnp.prod(jnp.sin(inputs * jnp.pi), axis=-1, keepdims=True)
    return inputs, outputs


def create_poisson_data(
    key, dim: int, num_in: int, num_bound: int, num_test: int
) -> tuple[jax.Array, tuple[ArrayPair, ArrayPair], ArrayPair]:
    """
    Builds the training and test sets for the Poisson problem.

    Args:
        key: JAX PRNGKey, split into four subkeys (boundary, interior, test,
            and one left over for the caller).
        dim: Dimension of the hypercube.
        num_in: Number of interior points.
        num_bound: Number of boundary points.
        num_test: Number of test points.

    Returns:
        key_rest: The unused fourth subkey, so the caller can keep drawing
            randomness without reusing `key`.
        data: ((x_bound, y_bound), (x_in, y_in)) training pairs.
        test_data: (x_test, y_test) test pair.
    """
    k_bound, k_in, k_test, k_next = jrnd.split(key, 4)
    boundary_pair = poisson_boundary(k_bound, num_bound, dim)
    interior_pair = poisson_interior(k_in, num_in, dim)
    test_pair = poisson_test_data(k_test, num_test, dim)
    return k_next, (boundary_pair, interior_pair), test_pair
