import jax
from jax import numpy as jnp


# Floating-point dtype that every quadrature rule returned by this module is cast to
# (via ``.astype(DTYPE)``) before being handed back to the caller.
DTYPE = jnp.float32


def triangle_mesh_quad_rule(n, quadrature_fn, mesh):
    """
    Build a composite quadrature rule over a mesh of triangles.

    A rule for the equilateral reference triangle ([0,0],[1,0],[0.5,sqrt(3)/2]) is
    built with ``triangle_quad_rule`` and then mapped onto every triangle of the mesh
    (nodes via barycentric coordinates, weights scaled by the area ratio).

    Inputs:
        n: int
        --> number of 1d quadrature pts, which will first be mapped to the reference
                triangle (3n^2 pts), then to every mesh triangle, totaling
                len(mesh)*3n^2 pts.
        quadrature_fn: callable
        --> function that returns (t, w) given n (the number of points) for the
                domain [-1,1], like numpy.polynomial.legendre.leggauss
        mesh: array-like
        --> triangles defined by their vertex coordinates,
                i.e. shape (10,3,2) (10 triangles, 3 vertices per triangle,
                2 coordinates per vertex). Nested Python lists are accepted and
                converted with jnp.asarray.
    Returns:
        a tuple (nodes, weights): nodes of shape (len(mesh)*3n^2, 2) and weights of
        shape (len(mesh)*3n^2, 1), both cast to DTYPE.
    """
    ### jax.vmap below maps over the leading (triangle) axis, which requires an array;
    ### a nested Python list would be treated as a pytree of scalars and fail.
    mesh = jnp.asarray(mesh)

    ref_triangle = jnp.array([[0, 0], [1, 0], [0.5, jnp.sqrt(3) / 2]])
    ### get the reference triangle quad rule
    tri_quad_rule = triangle_quad_rule(n, quadrature_fn, triangle=ref_triangle)

    def triangle_area(triangle):
        x1, x2, x3 = triangle
        return 0.5 * jnp.abs(
            x1[0] * x2[1]
            + x2[0] * x3[1]
            + x3[0] * x1[1]
            - x1[1] * x2[0]
            - x2[1] * x3[0]
            - x3[1] * x1[0]
        )

    ref_area = triangle_area(ref_triangle)

    ### map a single coordinate from the ref triangle to an arbitrary triangle.
    ### Barycentric weights come from (unsigned) sub-triangle areas, which is valid
    ### because every reference-rule node lies inside the reference triangle.
    def coord_ref_triangle_to_triangle(x, triangle):
        A, B, C = ref_triangle
        alpha = triangle_area((x, B, C)) / ref_area
        beta = triangle_area((A, x, C)) / ref_area
        x1, x2, x3 = triangle
        return alpha * x1 + beta * x2 + (1 - alpha - beta) * x3

    ### jacobian determinant of the (affine) transformation = the area ratio
    def detj_ref_triangle_to_triangle(triangle):
        return triangle_area(triangle) / ref_area

    ### map the entire quad rule to a single arbitrary triangle
    def quad_rule_ref_triangle_to_triangle(quad_rule, triangle):
        t, w = quad_rule
        updated_w = w * detj_ref_triangle_to_triangle(triangle)
        updated_t = jax.vmap(coord_ref_triangle_to_triangle, in_axes=[0, None])(
            t, triangle
        )
        return (updated_t, updated_w)

    ### map the quad rule to all triangles in the mesh
    mesh_quad_rule_t, mesh_quad_rule_w = jax.vmap(
        quad_rule_ref_triangle_to_triangle, in_axes=[None, 0]
    )(tri_quad_rule, mesh)

    return (
        mesh_quad_rule_t.reshape(-1, 2).astype(DTYPE),
        mesh_quad_rule_w.flatten()[:, None].astype(DTYPE),
    )


def triangle_quad_rule(
    n, quadrature_fn, triangle=jnp.array([[0, 0], [1, 0], [0.5, jnp.sqrt(3) / 2]])
):
    """
    Defines a quadrature rule for a triangle (by default the equilateral reference
    triangle ([0,0],[1,0],[0.5,sqrt(3)/2])) by splitting it into 3 quadrilaterals
    (vertex, edge midpoint, centroid, edge midpoint) and mapping a tensor-product
    unit-square rule onto each of them.

    Inputs:
        n: int
        --> passed to quadrature_fn; for e.g. numpy.polynomial.legendre.leggauss this
                is the number of 1d points, giving 3*n^2 points in total.
        quadrature_fn: callable
        --> returns (t, w), 1d nodes and weights on [-1, 1]. The returned arrays are
                not modified.
        triangle: array of shape (3, 2)
        --> vertex coordinates. Either orientation (clockwise or counter-clockwise)
                yields positive weights.
    Returns:
        a tuple (t, w): nodes of shape (3*m^2, 2) and weights of shape (3*m^2, 1),
        both cast to DTYPE, where m is the number of 1d nodes from quadrature_fn.
    """

    ### tensor-product approach to defining the rule for the unit-square
    def quad_rule_2d(n, quadrature_fn):
        ndims = 2
        quad_nodes, quad_weights = quadrature_fn(n)
        a, b = -1, 1  ### old domain
        c, d = 0, 1  ### new domain
        t, w = quad_nodes, quad_weights
        t = (((t - a) * (d - c)) / (b - a)) + c
        det_j = (d - c) / (b - a)
        ### out-of-place on purpose: ``w *= det_j`` would modify the (possibly cached)
        ### array returned by quadrature_fn in place
        w = w * det_j
        t = jnp.array(jnp.meshgrid(*[t] * ndims))
        t = t.reshape(len(t), -1).T
        w = jnp.outer(w, w).flatten()[:, None]
        return (t, w)

    quad_rule = quad_rule_2d(n, quadrature_fn)

    ### map a single coordinate in the unit square to an arbitrary quadrilateral (defined by 4 vertices)
    def coord_square_to_quadrilateral(x, quadrilateral):
        x1, x2, x3, x4 = quadrilateral
        xi, eta = x
        psi_1 = lambda xi, eta: (1 - xi) * (1 - eta)
        psi_2 = lambda xi, eta: xi * (1 - eta)
        psi_3 = lambda xi, eta: xi * eta
        psi_4 = lambda xi, eta: (1 - xi) * eta
        return (
            x1 * psi_1(xi, eta)
            + x2 * psi_2(xi, eta)
            + x3 * psi_3(xi, eta)
            + x4 * psi_4(xi, eta)
        )

    ### |jacobian determinant| of the bilinear map from the unit square to a quadrilateral.
    ### The absolute value keeps the weights positive when the quadrilateral (i.e. the
    ### input triangle) is oriented clockwise, where the signed determinant is negative.
    def detj_square_to_quadrilateral(x, quadrilateral):
        J = jax.jacfwd(coord_square_to_quadrilateral, argnums=0)(x, quadrilateral)
        return jnp.abs(jnp.linalg.det(J))

    ### using the two functions above to map the unit square quad rule to a quad rule for an arbitrary quadrilateral
    def quad_rule_square_to_quadrilateral(quad_rule, quadrilateral):
        t, w = quad_rule
        updated_w = jax.vmap(
            lambda w, t, quadrilateral: w
            * detj_square_to_quadrilateral(t, quadrilateral),
            in_axes=[0, 0, None],
        )(w, t, quadrilateral)
        updated_t = jax.vmap(coord_square_to_quadrilateral, in_axes=[0, None])(
            t, quadrilateral
        )
        return (updated_t, updated_w)

    ### vertices, centroid and edge midpoints of the (arbitrary) input triangle
    A, B, C = triangle
    O = (A + B + C) / 3
    D = (A + B) / 2
    E = (B + C) / 2
    F = (A + C) / 2
    ### 3 quadrilaterals, each spanning one vertex, two edge midpoints and the centroid
    quadrilaterals = jnp.array([[A, D, O, F], [B, E, O, D], [C, F, O, E]])

    ### map the unit square quad rule to each of these quadrilaterals to make a rule for the triangle
    triangle_quad_t, triangle_quad_w = jax.vmap(
        quad_rule_square_to_quadrilateral, in_axes=[None, 0]
    )(quad_rule, quadrilaterals)

    return (
        triangle_quad_t.reshape(-1, 2).astype(DTYPE),
        triangle_quad_w.flatten()[:, None].astype(DTYPE),
    )


def quadrature_unit_hypercube(ndims, n, quadrature_fn):
    """Tensor-product quadrature rule on the unit hypercube [0, 1]^ndims.

    Inputs:
        ndims: int
        --> spatial dimension (any ndims >= 1).
        n: int
        --> passed to quadrature_fn (number of 1d points for e.g.
                numpy.polynomial.legendre.leggauss).
        quadrature_fn: callable
        --> returns (t, w), 1d nodes and weights on [-1, 1]. The returned arrays
                are not modified.
    Returns:
        (t, w): nodes of shape (m**ndims, ndims) and weights of shape (m**ndims, 1),
        cast to DTYPE, where m is the number of 1d nodes from quadrature_fn.
    """
    a, b = -1, 1  ### old domain
    c, d = 0, 1  ### new domain
    t, w = quadrature_fn(n)
    t = (((t - a) * (d - c)) / (b - a)) + c
    det_j = (d - c) / (b - a)
    ### out-of-place on purpose: ``w *= det_j`` would modify the (possibly cached)
    ### array returned by quadrature_fn in place
    w = w * det_j
    t = jnp.array(jnp.meshgrid(*[t] * ndims))
    t = t.reshape(len(t), -1).T  ### i.e. (100, 2) for 2d and 10 quad pts
    if ndims > 1:
        ### Use the same meshgrid layout as the nodes so each node's weight is the
        ### product of its 1d weights. jnp.outer accepts only two arrays, so it
        ### cannot be used for ndims >= 3.
        w = jnp.prod(jnp.array(jnp.meshgrid(*[w] * ndims)), axis=0).flatten()
    return t.astype(DTYPE), w[:, None].astype(DTYPE)


def map_quadrature_to_nd(ndims, n, quadrature_fn):
    """Tensor-product quadrature rule in ndims dimensions on quadrature_fn's own domain.

    Unlike ``quadrature_unit_hypercube`` no domain mapping is applied: the nodes stay
    on whatever 1d interval quadrature_fn uses (e.g. [-1, 1]^ndims for leggauss).

    Inputs:
        ndims: int
        --> spatial dimension (any ndims >= 1).
        n: int
        --> passed to quadrature_fn (number of 1d points for e.g. leggauss).
        quadrature_fn: callable
        --> returns (t, w), 1d nodes and weights.
    Returns:
        (t, w): nodes of shape (m**ndims, ndims) and weights of shape (m**ndims, 1),
        cast to DTYPE, where m is the number of 1d nodes from quadrature_fn.
    """
    t, w = quadrature_fn(n)
    t = jnp.array(jnp.meshgrid(*[t] * ndims))
    t = t.reshape(len(t), -1).T  ### i.e. (100, 2) for 2d and 10 quad pts
    if ndims > 1:
        ### Use the same meshgrid layout as the nodes so each node's weight is the
        ### product of its 1d weights. jnp.outer accepts only two arrays, so it
        ### cannot be used for ndims >= 3.
        w = jnp.prod(jnp.array(jnp.meshgrid(*[w] * ndims)), axis=0).flatten()
    return t.astype(DTYPE), w[:, None].astype(DTYPE)
