from typing import Union

import numpy as np


def ellipsoid_volume_element(principal_radii: np.ndarray, coordinates: np.ndarray):
    """Formula to compute the volume element of a 2 or 3-ellipsoid on a point. The formula has been computed
    programmatically using the tools in 'symbolic_computations.py'.

    Args:
        principal_radii: The principal semi-axes of one or more ellipsoids in an array of shape (n_points dim + 1).
        coordinates: The polar coordinates of the points to compute the volume element on in array of
            shape (n_points dim).

    Returns:
        A vector of shape (n_points) with the volume element per point.
    """
    n = coordinates.shape[-1]

    if n == 2:
        phi_1, phi_2 = coordinates[..., 0], coordinates[..., 1]
        a_1, a_2, a_3 = principal_radii[..., 0], principal_radii[..., 1], principal_radii[..., 2]
        element = np.sqrt(
            a_1**2*a_2**2*np.sin(phi_1)**4*np.sin(phi_2)**2
            - a_1**2*a_3**2*np.sin(phi_1)**4*np.sin(phi_2)**2
            + a_1**2*a_3**2*np.sin(phi_1)**4
            - a_2**2*a_3**2*np.sin(phi_1)**4
            + a_2**2*a_3**2*np.sin(phi_1)**2
        )
    elif n == 3:
        phi_1, phi_2, phi_3 = coordinates[..., 0], coordinates[..., 1], coordinates[..., 2]
        a_1, a_2, a_3, a_4 = (
            principal_radii[..., 0], principal_radii[..., 1], principal_radii[..., 2], principal_radii[..., 3])
        element = np.sqrt(
            (
                a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2
                - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2
                + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2
                - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2
                + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2
                - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2
                + a_2**2*a_3**2*a_4**2
            )
            * np.sin(phi_1)**4*np.sin(phi_2)**2
        )
    else:
        raise ValueError(f"Only volumes of ellipsoids of dimensions 2 and 3 implemented.")

    return np.array(element).astype(np.float64)


def ellipsoid_scalar_curvature(principal_radii: np.ndarray, coordinates: np.ndarray):
    """Formula to compute the scalar curvature of a 2 or 3-ellipsoid on a point. The formula has been computed
    programmatically using the tools in 'symbolic_computations.py'.

    Args:
        principal_radii: The principal semi-axes of the ellipsoid.
        coordinates: The polar coordinates of the point to compute the volume element on.

    Returns:
        The scalar curvature as a float.
    """
    n = coordinates.shape[-1]

    if n == 2:
        phi_1, phi_2 = coordinates[..., 0], coordinates[..., 1]
        a_1, a_2, a_3 = principal_radii[..., 0], principal_radii[..., 1], principal_radii[..., 2]
        curvature = -0.015625*a_1**2*a_2**2*a_3**2*(a_2**2 - a_3**2)**2*(np.cos(2*phi_1 - 2*phi_2) - np.cos(2*phi_1 + 2*phi_2))**2/((a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)**3*np.sin(phi_1)**2) + 1.0*a_1**2*a_2**2*a_3**2*(a_1**2*np.sin(phi_1)**2 + a_2**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_3**2*np.sin(phi_2)**2*np.cos(phi_1)**2)*(a_2**2*np.sin(phi_2)**2 - a_3**2*np.sin(phi_2)**2 + a_3**2)*np.sin(phi_1)**2/((a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)**2*(a_1**2*a_2**2*np.sin(phi_1)**4*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**4*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**4 - a_2**2*a_3**2*np.sin(phi_1)**4 + a_2**2*a_3**2*np.sin(phi_1)**2)) + (a_2**2 - a_3**2)**2*(np.cos(2*phi_1 - 2*phi_2) - np.cos(2*phi_1 + 2*phi_2))**2*(-1.0*(a_2**2*np.sin(phi_2)**2 + a_3**2*np.cos(phi_2)**2)*(a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)*np.sin(phi_1)*np.cos(phi_1) + (a_1**2*(a_2**2 - a_3**2)**2*(0.75*np.sin(phi_1) - 0.25*np.sin(3*phi_1))*np.sin(phi_2)**2*np.cos(phi_2)**2 + 0.5*a_2**2*a_3**2*(-a_1**2 + a_2**2*np.cos(phi_2)**2 + a_3**2*np.sin(phi_2)**2)*np.sin(2*phi_1)*np.cos(phi_1))*np.sin(phi_1)*np.tan(phi_1) + 1.0*(-a_2**2*np.sin(phi_1)**2 + a_2**2*np.sin(phi_2)**2 - a_3**2*np.sin(phi_1)**2 - a_3**2*np.sin(phi_2)**2 + a_3**2)*(a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)*np.tan(phi_1))/(64*(a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)**3*np.sin(phi_1)**4*np.tan(phi_1)) + (a_1**2*np.sin(phi_1)**2 + a_2**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_3**2*np.sin(phi_2)**2*np.cos(phi_1)**2)*(a_2**2*np.sin(phi_2)**2 - a_3**2*np.sin(phi_2)**2 + a_3**2)*(1.0*(a_2**2*np.sin(phi_2)**2 + a_3**2*np.cos(phi_2)**2)*(a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)*np.sin(phi_1)*np.cos(phi_1) + (a_1**2*(a_2**2 - a_3**2)**2*(-0.75*np.sin(phi_1) + 0.25*np.sin(3*phi_1))*np.sin(phi_2)**2*np.cos(phi_2)**2 - 0.5*a_2**2*a_3**2*(-a_1**2 + a_2**2*np.cos(phi_2)**2 + a_3**2*np.sin(phi_2)**2)*np.sin(2*phi_1)*np.cos(phi_1))*np.sin(phi_1)*np.tan(phi_1) + 1.0*(a_2**2*np.sin(phi_1)**2 - a_2**2*np.sin(phi_2)**2 + a_3**2*np.sin(phi_1)**2 + a_3**2*np.sin(phi_2)**2 - a_3**2)*(a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)*np.tan(phi_1))/((a_1**2*a_2**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2)**2*(a_1**2*a_2**2*np.sin(phi_1)**4*np.sin(phi_2)**2 - a_1**2*a_3**2*np.sin(phi_1)**4*np.sin(phi_2)**2 + a_1**2*a_3**2*np.sin(phi_1)**4 - a_2**2*a_3**2*np.sin(phi_1)**4 + a_2**2*a_3**2*np.sin(phi_1)**2)*np.tan(phi_1))
    elif n == 3:
        phi_1, phi_2, phi_3 = coordinates[..., 0], coordinates[..., 1], coordinates[..., 2]
        a_1, a_2, a_3, a_4 = (
            principal_radii[..., 0], principal_radii[..., 1], principal_radii[..., 2], principal_radii[..., 3])
        curvature = -2.0*a_1**2*a_2**6*a_3**2*a_4**2*(a_3**2 - a_4**2)**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2*np.cos(phi_1)**2*np.cos(phi_3)**2/((a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)*(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**3*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**3*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**3*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**3*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**3 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**3 + a_2**2*a_3**2*a_4**2*np.sin(phi_1))**2) - 2.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_2**2*a_3**2*np.sin(phi_3)**2 - a_2**2*a_4**2*np.sin(phi_3)**2 + a_2**2*a_4**2 - a_3**2*a_4**2)**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.cos(phi_2)**2/((a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**3*np.tan(phi_1)**2) - 2.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_1**2*a_3**2*np.sin(phi_1)**2 - a_1**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*np.sin(phi_1)**2 + a_2**2*a_3**2 + a_2**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_4**2)**2*np.sin(phi_2)**2*np.sin(phi_3)**2*np.cos(phi_3)**2/((a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**3*np.tan(phi_2)**2) + (1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_1**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_3)**2 - a_1**2*a_3**2*np.cos(phi_1)**2 - a_1**2*a_3**2*np.cos(phi_3)**2 + a_1**2*a_3**2 - a_1**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_3)**2 + a_1**2*a_4**2*np.cos(phi_3)**2 - a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 + a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_2**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 - a_3**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_3**2*a_4**2*np.cos(phi_1)**2)/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2 + 1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_1**2*a_2**2 + a_1**2*a_3**2*np.sin(phi_3)**2 - a_1**2*a_3**2 - a_1**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_1**2*a_3**2/np.sin(phi_2)**2 - a_1**2*a_4**2*np.sin(phi_3)**2 + a_1**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 - a_2**2*a_3**2/np.sin(phi_2)**2 - a_2**2*a_3**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) + a_2**2*a_3**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) - a_2**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_4**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2))*np.sin(phi_1)**2*np.sin(phi_2)**2/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2)*(a_2**2*a_3**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_2**2*a_4**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_2**2*a_4**2*np.sin(phi_2)**2 - a_3**2*a_4**2*np.sin(phi_2)**2 + a_3**2*a_4**2)/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2) + (1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_2**2*a_3**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_2**2*a_4**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_2**2*a_4**2*np.sin(phi_2)**2 - a_3**2*a_4**2*np.sin(phi_2)**2 + a_3**2*a_4**2)*np.sin(phi_1)**2/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2 + 1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_1**2*a_2**2 + a_1**2*a_3**2*np.sin(phi_3)**2 - a_1**2*a_3**2 - a_1**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_1**2*a_3**2/np.sin(phi_2)**2 - a_1**2*a_4**2*np.sin(phi_3)**2 + a_1**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 - a_2**2*a_3**2/np.sin(phi_2)**2 - a_2**2*a_3**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) + a_2**2*a_3**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) - a_2**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_4**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2))*np.sin(phi_1)**4*np.sin(phi_2)**2/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2)*(a_1**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_3)**2 - a_1**2*a_3**2*np.cos(phi_1)**2 - a_1**2*a_3**2*np.cos(phi_3)**2 + a_1**2*a_3**2 - a_1**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_3)**2 + a_1**2*a_4**2*np.cos(phi_3)**2 - a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 + a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_2**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 - a_3**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_3**2*a_4**2*np.cos(phi_1)**2)/((a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)*np.sin(phi_1)**2) + (1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_2**2*a_3**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_2**2*a_4**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_2**2*a_4**2*np.sin(phi_2)**2 - a_3**2*a_4**2*np.sin(phi_2)**2 + a_3**2*a_4**2)*np.sin(phi_1)**2*np.sin(phi_2)**2/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2 + 1.0*a_1**2*a_2**2*a_3**2*a_4**2*(a_1**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_3)**2 - a_1**2*a_3**2*np.cos(phi_1)**2 - a_1**2*a_3**2*np.cos(phi_3)**2 + a_1**2*a_3**2 - a_1**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_3)**2 + a_1**2*a_4**2*np.cos(phi_3)**2 - a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 + a_2**2*a_3**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_2**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2*np.cos(phi_3)**2 - a_3**2*a_4**2*np.cos(phi_1)**2*np.cos(phi_2)**2 + a_3**2*a_4**2*np.cos(phi_1)**2)*np.sin(phi_1)**2*np.sin(phi_2)**2/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)**2)*(a_1**2*a_2**2 + a_1**2*a_3**2*np.sin(phi_3)**2 - a_1**2*a_3**2 - a_1**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_1**2*a_3**2/np.sin(phi_2)**2 - a_1**2*a_4**2*np.sin(phi_3)**2 + a_1**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_3**2*np.sin(phi_3)**2/np.sin(phi_2)**2 - a_2**2*a_3**2/np.sin(phi_2)**2 - a_2**2*a_3**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) + a_2**2*a_3**2/(np.sin(phi_1)**2*np.sin(phi_2)**2) - a_2**2*a_4**2*np.sin(phi_3)**2/np.sin(phi_2)**2 + a_2**2*a_4**2*np.sin(phi_3)**2/(np.sin(phi_1)**2*np.sin(phi_2)**2))/(a_1**2*a_2**2*a_3**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 - a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2*np.sin(phi_3)**2 + a_1**2*a_2**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 - a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2*np.sin(phi_2)**2 + a_1**2*a_3**2*a_4**2*np.sin(phi_1)**2 - a_2**2*a_3**2*a_4**2*np.sin(phi_1)**2 + a_2**2*a_3**2*a_4**2)
    else:
        raise ValueError(f"Only scalar curvatures of ellipsoids of dimensions 2 and 3 implemented.")

    return np.array(curvature).astype(np.float64)
