import numpy as np
from numpy._core.multiarray import array as array
from . import typeDef as td

class Simplex:
    """A simplex in 3-D space defined by its corner points.

    The corners are stored as the rows of ``self.corners``. One corner is a
    point, two corners are a line segment, three corners are a triangle.

    Attributes:
        corners: ``(N, 3)`` array whose rows are the corner points.
        num_corners: Number of corners ``N``.
        simplex_dim: Dimension of the simplex, ``N - 1``.
    """

    def __init__(self, _definingPoints: np.ndarray) -> None:
        """Store the corner points.

        Args:
            _definingPoints: ``(N, 3)`` array-like of corner points, one per
                row. A single point given as a 1-D length-3 array is accepted
                and promoted to shape ``(1, 3)``.

        Raises:
            AssertionError: If the points are not laid out as an ``Nx3``
                matrix.
        """
        # atleast_2d returns a 2-D ndarray unchanged (the very same object).
        # It promotes a lone 1-D point to a single row, which would otherwise
        # raise IndexError on ``shape[1]``.
        points = np.atleast_2d(_definingPoints)
        assert points.ndim == 2 and points.shape[1] == 3, "Points must be stored in an Nx3 matrix with each point occupying one row."
        self.corners = points
        self.num_corners = len(self.corners)
        self.simplex_dim = self.num_corners - 1

    def get_interior_point(self, weights: np.ndarray) -> np.ndarray:
        """Return the affine combination of the corners given by *weights*.

        Args:
            weights: Array-like with one of two lengths:
                ``num_corners`` weights, one per corner, used as given; or
                ``simplex_dim`` weights, in which case the last corner gets
                the remaining weight ``1 - sum(weights)``.

        Returns:
            The length-3 point ``weights @ corners``.

        Raises:
            AssertionError: If ``len(weights)`` matches neither count.
        """
        # asarray lets plain lists/tuples through; ndarrays pass unchanged.
        weights = np.asarray(weights)
        assert len(weights) == self.simplex_dim or len(weights) == self.num_corners
        if len(weights) == self.simplex_dim:
            # Complete the weights so that they sum to one.
            weights = np.concatenate((weights, [1.0 - weights.sum()]))
        return weights @ self.corners
    

class Point(Simplex):
    """A 0-simplex: a single point in 3-D space.

    Attributes:
        pos: The position exactly as passed to the constructor.
    """

    def __init__(self, _pos: np.ndarray) -> None:
        """Create a point at *_pos*.

        Args:
            _pos: Position as a length-3 vector or as a ``(1, 3)`` row.
        """
        # Simplex expects corners as the rows of an Nx3 matrix. A plain
        # length-3 vector has no ``shape[1]``, so it is promoted to one row
        # first. 2-D input is passed through unchanged by atleast_2d.
        super().__init__(np.atleast_2d(_pos))
        self.pos = _pos


class Line(Simplex):
    """A 1-simplex: the line segment between two 3-D points."""

    def __init__(self, p0: np.ndarray, p1: np.ndarray) -> None:
        """Create the segment with endpoints *p0* and *p1*.

        Args:
            p0: First endpoint with exactly three elements, e.g. shape
                ``(3,)`` or ``(1, 3)``.
            p1: Second endpoint, same requirements as *p0*.
        """
        assert np.size(p0) == 3
        assert np.size(p1) == 3
        # Flatten each endpoint to shape (3,) so that (3,) and (1, 3) inputs
        # both stack into a (2, 3) matrix with one point per row.
        pts = np.stack((np.ravel(p0), np.ravel(p1)), axis=0)
        super().__init__(pts)

    def point_from_param(self, t: float) -> np.ndarray:
        """Return the point ``t * p0 + (1 - t) * p1`` on the segment.

        Note that ``t = 0`` yields ``p1`` and ``t = 1`` yields ``p0``.
        This is because *t* is the weight given to the first corner.

        Args:
            t: Parameter in ``[0, 1]``.
        """
        assert 0 <= t <= 1
        return self.get_interior_point(np.array([t]))

class Triangle(Simplex):
    """A 2-simplex: the triangle spanned by three 3-D points."""

    def __init__(self, p0: np.ndarray, p1: np.ndarray, p2: np.ndarray) -> None:
        """Create the triangle with corners *p0*, *p1* and *p2*.

        Args:
            p0: First corner with exactly three elements, e.g. shape
                ``(3,)`` or ``(1, 3)``.
            p1: Second corner, same requirements as *p0*.
            p2: Third corner, same requirements as *p0*.
        """
        assert np.size(p0) == 3
        assert np.size(p1) == 3
        assert np.size(p2) == 3
        # Flatten each corner to shape (3,) so that (3,) and (1, 3) inputs
        # both stack into a (3, 3) matrix with one point per row.
        pts = np.stack((np.ravel(p0), np.ravel(p1), np.ravel(p2)), axis=0)
        super().__init__(pts)

    def point_from_param(self, u: float, v: float) -> np.ndarray:
        """Return the point with barycentric weights ``(u, v, 1 - u - v)``.

        The weights apply to ``p0``, ``p1`` and ``p2`` respectively. The
        result lies in the triangle (boundary included) exactly when
        ``u, v >= 0`` and ``u + v <= 1``.

        Args:
            u: Weight of ``p0``, in ``[0, 1]``.
            v: Weight of ``p1``, in ``[0, 1]``, with ``u + v <= 1``.
        """
        assert 0 <= u <= 1
        assert 0 <= v <= 1
        # Without this check, u + v > 1 gives p2 a negative weight and the
        # point lands outside the triangle. The tiny tolerance keeps edge
        # points (u + v == 1 up to rounding) from being rejected.
        assert u + v <= 1.0 + 1e-12
        return self.get_interior_point(np.array([u, v]))




# Manual smoke test. It runs only when this module is executed directly,
# never on import. Because the module uses a relative import, run it with
# ``python -m <package>.<module>``.
if __name__ == "__main__":
    line = Line(np.array([0, 0, 0]), np.array([1, 0, 0]))
    # 0.75 * p0 + 0.25 * p1 -> [0.25 0.   0.  ]
    print(line.point_from_param(0.75))

    tri = Triangle(np.array([0, 0, 0]), np.array([1, 0, 0]), np.array([0, 1, 0]))
    # 0.3 * p0 + 0.3 * p1 + 0.4 * p2 -> [0.3 0.4 0. ]
    print(tri.point_from_param(0.3, 0.3))
