import numpy as np
import pyvista as pv


def sample_surface(mesh, n_samples, rng=None):
    """
    Draw points at random from the surface of a mesh, uniformly by area.

    The mesh is triangulated. Each sample picks a triangle with probability
    proportional to its area, then places a point uniformly inside that
    triangle using the square-root barycentric mapping.

    Args:
        mesh: Object exposing ``triangulate()``. The triangulated result must
            provide ``faces`` (VTK-style flat connectivity ``[3, i, j, k, ...]``)
            and ``points`` (vertex coordinates). Presumably a
            ``pyvista.PolyData``; other dataset types are not handled here.
        n_samples: Number of points to draw.
        rng: Optional random source (``numpy.random.Generator`` or
            ``numpy.random.RandomState``) for reproducible sampling. When
            ``None`` (default), the global ``numpy.random`` state is used.

    Returns:
        float64 array with one row per sample. It has the same number of
        columns as ``mesh.points``.

    Raises:
        ValueError: If the mesh has no faces, its faces are not all triangles,
            or its total surface area is zero or not finite.
    """
    # The module-level ``np.random`` functions share the legacy global state,
    # so the default path draws the same stream as the original implementation.
    sampler = np.random if rng is None else rng

    # Ensure it's a triangulated surface
    surf = mesh.triangulate()

    # VTK packs faces as [n_pts, id0, id1, ..., n_pts, ...]; after
    # triangulation every record should be exactly [3, i, j, k].
    raw_faces = np.asarray(surf.faces)
    if raw_faces.size == 0:
        raise ValueError("Cannot sample a surface from a mesh with no faces.")
    if raw_faces.size % 4 or np.any(raw_faces.reshape(-1, 4)[:, 0] != 3):
        raise ValueError("Expected an all-triangle mesh after triangulate().")
    faces = raw_faces.reshape(-1, 4)[:, 1:]  # triangle vertex indices
    # float64 so areas/probabilities are not limited by float32 point storage.
    vertices = np.asarray(surf.points, dtype=np.float64)
    n_faces = faces.shape[0]

    # Area of each triangle = half the norm of the edge cross product.
    v0 = vertices[faces[:, 0]]
    v1 = vertices[faces[:, 1]]
    v2 = vertices[faces[:, 2]]
    tri_areas = 0.5 * np.linalg.norm(np.cross(v1 - v0, v2 - v0), axis=1)

    total_area = tri_areas.sum()
    # Guard explicitly: a zero/NaN total would otherwise surface as an
    # obscure error from choice() about invalid probabilities.
    if not np.isfinite(total_area) or total_area <= 0:
        raise ValueError("Mesh has zero or non-finite surface area; cannot sample.")

    # Sample triangles proportional to their area
    tri_probs = tri_areas / total_area
    sampled_face_indices = sampler.choice(n_faces, size=n_samples, p=tri_probs)

    # Barycentric coordinates; sqrt() on r1 makes the density uniform in area.
    r1 = np.sqrt(sampler.random(n_samples))
    r2 = sampler.random(n_samples)
    a = 1 - r1
    b = r1 * (1 - r2)
    c = r1 * r2

    # Combine the chosen triangles' vertices with the barycentric weights.
    f = faces[sampled_face_indices]
    pts = (
        a[:, None] * vertices[f[:, 0]]
        + b[:, None] * vertices[f[:, 1]]
        + c[:, None] * vertices[f[:, 2]]
    )

    return pts


class Object:
    def __init__(self, name: str, n_samples=200):
        """
        Build the point data for one of a fixed set of named test shapes.

        The result is stored on ``self.data``:

        * ``"arrow"``: hand-written 2-D outline (float32). It is closed:
          the last vertex repeats the first.
        * ``"half_arrow"``: the first seven outline vertices of ``"arrow"``,
          running from ``(0, 0)`` to the tip at ``(0, 1)`` (float32).
        * ``"pv_arrow"``: ``n_samples`` points drawn from the surface of a
          ``pv.Arrow`` mesh via ``sample_surface``.
        * ``"pv_half_arrow"``: the same arrow clipped with ``clip("z",
          invert=False)``, holes filled, then surface-sampled.
        * ``"pv_half_arrow_4pts"`` and ``"irreg_tet"``: fixed sets of four
          3-D points (float32).

        Args:
            name: Shape identifier (one of the keys listed above).
            n_samples: Number of surface samples. Only the ``"pv_arrow"`` and
                ``"pv_half_arrow"`` shapes use it.

        Raises:
            ValueError: If ``name`` is not a supported shape.
        """
        self.name = name
        self.n_samples = n_samples

        def build_pv_arrow_mesh():
            # Geometry shared by both mesh-sampled shapes.
            return pv.Arrow(tip_length=0.25, tip_radius=0.25, shaft_radius=0.1)

        def build_pv_arrow():
            return sample_surface(build_pv_arrow_mesh(), self.n_samples)

        def build_pv_half_arrow():
            clipped = build_pv_arrow_mesh().clip("z", invert=False)
            # hole_size=1000 is a generous limit, presumably meant to cap the
            # cut face left by the clip; confirm against the intended geometry.
            clipped.fill_holes(1000, inplace=True)
            return sample_surface(clipped, self.n_samples)

        # Each builder is a zero-argument callable, so arrays and meshes are
        # only created for the shape that was actually requested.
        builders = {
            "arrow": lambda: np.array(
                [
                    [0, 0], [0.25, 0], [0.25, 0.25], [0.25, 0.5],
                    [0.5, 0.5], [0.25, 0.75], [0, 1],
                    [-0.25, 0.75], [-0.5, 0.5], [-0.25, 0.5],
                    [-0.25, 0.25], [-0.25, 0], [0, 0],
                ],
                dtype=np.float32,
            ),
            "half_arrow": lambda: np.array(
                [
                    [0, 0], [0.25, 0], [0.25, 0.25], [0.25, 0.5],
                    [0.5, 0.5], [0.25, 0.75], [0, 1],
                ],
                dtype=np.float32,
            ),
            "pv_arrow": build_pv_arrow,
            "pv_half_arrow": build_pv_half_arrow,
            "pv_half_arrow_4pts": lambda: np.array(
                [[0, 0, 0], [1.0, 0, 0], [0.8, 0.15, 0.15], [0.8, -0.15, 0.15]],
                dtype=np.float32,
            ),
            "irreg_tet": lambda: np.array(
                [[0, 0, 0], [1.0, 0, 0], [0.5, 1, 0.2], [0.3, 0.4, 1.1]],
                dtype=np.float32,
            ),
        }

        # Test ``self.name == key`` in declaration order. Any name that
        # matches no key falls through to the error.
        for key, build in builders.items():
            if self.name == key:
                self.data = build()
                break
        else:
            raise ValueError(f"Object '{self.name}' is not supported.")
