import numpy as np
from scipy.stats import multivariate_normal as gauss
from src.utils.measure import Measure
from typing import Union


def generate_nd_gaussian_measure(
    mu: np.ndarray,
    cov: np.ndarray,
    shape: Union[int, tuple],
    start: tuple = (0, 0),
    end: tuple = (1, 1),
    measure_name: str = "mu",
    space_name: str = "x",
    ndim: int = 2,
) -> Measure:
    """
    Generates a discretized, normalized Gaussian measure on a regular n-dimensional grid.

    The box ``[start[k], end[k]]`` in each dimension ``k`` is split into
    ``shape[k]`` equal cells, and the Gaussian density is evaluated at every
    cell center. The resulting array is normalized so that it sums to one.
    Axis ``k`` of the returned data corresponds to dimension ``k``, i.e. to
    ``start[k]``, ``end[k]`` and ``shape[k]``.

    Args:
        mu (np.ndarray): The mean vector of the Gaussian distribution.
        cov (np.ndarray): The covariance matrix of the Gaussian distribution.
        shape (Union[int, tuple]): The number of grid cells per dimension. If an
            integer (Python or numpy) is provided, it is repeated for each dimension.
        start (tuple, optional): The lower bound of the domain in each dimension.
            Defaults to (0, 0).
        end (tuple, optional): The upper bound of the domain in each dimension.
            Defaults to (1, 1).
        measure_name (str, optional): The name of the measure. Defaults to "mu".
        space_name (str, optional): The name of the space. Defaults to "x".
        ndim (int, optional): The number of dimensions. Defaults to 2.

    Returns:
        Measure: The generated Gaussian measure. Its data has shape ``shape``
        and sums to one.

    Raises:
        ValueError: If ``shape``, ``start`` or ``end`` does not have ``ndim``
            entries, or if the density is zero (or not finite) on the whole grid,
            so that it cannot be normalized.

    """
    if isinstance(shape, (int, np.integer)):
        shape = (int(shape),) * ndim
    shape = tuple(shape)
    if len(shape) != ndim:
        raise ValueError(
            f"dimension mismatch: len(shape)={len(shape)} but ndim={ndim}"
        )
    # zip() below silently truncates to its shortest argument. Without this
    # check, e.g. ndim=3 with the 2-D default start/end would build a wrong grid.
    if len(start) != ndim or len(end) != ndim:
        raise ValueError(
            f"dimension mismatch: len(start)={len(start)}, len(end)={len(end)} "
            f"but ndim={ndim}"
        )

    # Cell centers: start + (i + 0.5) * step, with step = (end - start) / size.
    coords = [
        s + (e - s) * (np.arange(size) + 0.5) / size
        for s, e, size in zip(start, end, shape)
    ]
    # indexing="ij" makes mesh axis k follow coords[k]. The default "xy" swaps
    # the first two axes, which transposes square grids and scrambles
    # non-square ones after the reshape below.
    mesh = np.meshgrid(*coords, indexing="ij")
    grid = np.stack(mesh, axis=-1).reshape(-1, ndim)
    density = np.asarray(gauss(mu, cov).pdf(grid)).reshape(shape)

    total = np.sum(density)
    # `not total > 0` also catches NaN. A zero total happens when the Gaussian
    # is far from the grid and the density underflows everywhere.
    if not total > 0:
        raise ValueError(
            "Gaussian density vanishes on the whole grid; cannot normalize"
        )
    return Measure(
        data=density / total,
        start=start,
        end=end,
        measure_name=measure_name,
        space_name=space_name,
    )
