import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import norm
from numba import jit, prange, int32

        
@jit(nopython=True, parallel=False, fastmath=False)
def UniBsplineBasic(k, n, t):
    """
    Evaluate all n B-spline basis functions of order k (degree k - 1) at t,
    on a clamped knot vector whose interior knots are evenly spaced in [0, 1].

    k = order (degree + 1)
    n = number of control points, i.e. number of basis values returned
    t = evaluation position; a t outside [0, 1] matches no knot span and
        yields an all-zero vector

    Returns a length-n float array whose entry i is N_{i,k}(t).

    NOTE(review): the knot layout below assumes n >= k; for smaller n the
    knot vector is degenerate and the result is not a meaningful basis.
    """
    # n + k knots in total; the layout reproduces the original MATLAB code.
    knots = np.zeros(n + k)

    # Count of evenly spaced knot values from 0 to 1, both endpoints included.
    num_uniform = n - k + 2  # same count as the MATLAB original

    # Indices k-1 .. n receive the evenly spaced values. Indices 0 .. k-2 stay
    # 0 from np.zeros, so the vector starts with k zeros in total.
    knots[k-1 : k-1 + num_uniform] = np.linspace(0, 1, num_uniform)

    # Indices n .. n+k-1 are set to 1 (index n already held 1 from linspace),
    # so the vector ends with k ones: a clamped knot vector.
    knots[n : n + k] = 1.0

    # Working buffer: the order-1 basis has one entry per knot span
    # (n + k - 1); each higher order has one fewer valid entry.
    Np = np.zeros(n + k - 1)

    # The endpoints are set directly. At a clamped end only the first/last
    # basis function is non-zero. For t == 1 this is required: the half-open
    # span test below never matches t == 1 because the last knot is 1.
    if t == 0:
        Np[0] = 1
    elif t == 1:
        Np[n-1] = 1
    else:
        # Order 1: indicator of the half-open span [knots[p], knots[p+1]).
        # Zero-width spans (repeated knots) never match.
        for p in range(n + k - 1):
            if knots[p] <= t < knots[p+1]:
                Np[p] = 1

        # Cox-de Boor recursion, raising the order from 2 up to k.
        # Np is overwritten in place for increasing p. When Np[p] is
        # replaced, Np[p+1] still holds the order q-1 value the formula
        # needs, so no second buffer is required.
        for q in range(2, k+1):  # order q (degree q - 1), q = 2..k
            for p in range(n + k - q):
                # A term over a zero-width knot interval is taken as 0 (the
                # usual 0/0 := 0 convention). The Np != 0 test only skips
                # work whose result would be 0 anyway.
                left = 0
                if Np[p] != 0 and knots[p+q-1] != knots[p]:
                    left = (t - knots[p])/(knots[p+q-1] - knots[p]) * Np[p]

                right = 0
                if Np[p+1] != 0 and knots[p+q] != knots[p+1]:
                    right = (knots[p+q] - t)/(knots[p+q] - knots[p+1]) * Np[p+1]

                Np[p] = left + right

    # After raising to order k, exactly the first n entries are valid.
    return Np[:n]



@jit(nopython=True, parallel=False, fastmath=False)
def BSplineApproximation(xr, yr, zr, h, iteration, order=4):
    """
    Fit a tensor-product B-spline surface z(x, y) to scattered samples by
    iteratively refining its control net against the sample residuals.

    xr, yr    = sample coordinates; a coordinate outside [0, 1] gets an
                all-zero basis row from UniBsplineBasic, so that sample
                has no influence on the fit
    zr        = sample values to approximate, one per (xr, yr) pair
                (any numeric dtype; residuals are kept in float64)
    h         = number of control points along each axis
    iteration = number of refinement sweeps (0 leaves the net at zero)
    order     = B-spline order along each axis; the default 4 (cubic) is the
                value that was previously hard-coded

    Returns (xc, yc, zc): xc and yc are h evenly spaced values in [0, 1] and
    zc is the (h, h) array of control-point heights. If no sample touches
    the basis (e.g. empty input), zc is returned as all zeros rather than
    dividing by zero.
    """
    n = len(xr)

    # Collocation matrices: row p holds the h basis values at xr[p] / yr[p].
    # Rows are independent, so prange stays safe if parallel=True is enabled.
    Bx = np.zeros((n, h))
    By = np.zeros((n, h))
    for p in prange(n):
        Bx[p, :] = UniBsplineBasic(order, h, xr[p])
        By[p, :] = UniBsplineBasic(order, h, yr[p])

    xc = np.linspace(0, 1, h)
    yc = np.linspace(0, 1, h)
    zc = np.zeros((h, h))

    # Largest column sum of the tensor-product collocation matrix (one column
    # per control point). Assuming every basis row sums to 1 or 0, this
    # bounds the largest eigenvalue of B^T B, so its reciprocal is a step
    # size that keeps the refinement stable.
    lambda_max = 0.0
    for i in range(h):
        for j in range(h):
            col_sum = np.sum(Bx[:, i] * By[:, j])
            if col_sum > lambda_max:
                lambda_max = col_sum

    if lambda_max <= 0.0:
        # No sample touches the basis: there is nothing to fit, and the
        # step size 1 / lambda_max would divide by zero.
        return xc, yc, zc

    mu = 1.0 / lambda_max

    # Residuals zr - S(xr, yr) of the current surface S. The surface starts
    # at zero, so the first residual is zr itself. Force float64 so that an
    # integer zr does not truncate the residuals written back below.
    delta_lplus1 = zr.astype(np.float64)

    for _ in range(iteration):
        delta_l = np.copy(delta_lplus1)

        # Move every control point by mu times the basis-weighted sum of the
        # residuals it influences.
        Delta = np.zeros((h, h))
        for i in range(h):
            for j in range(h):
                Delta[i, j] = mu * np.sum(Bx[:, i] * By[:, j] * delta_l)

        zc += Delta

        # Recompute the residuals against the updated surface.
        for p in range(n):
            delta_lplus1[p] = zr[p] - Bx[p] @ zc @ By[p]

    return xc, yc, zc


@jit(nopython=True, parallel=False, fastmath=False)
def eminencef_fast(x, y, persistence, epsilon):
    """
    Neighbour-count ("eminence") weights computed by broadcasting.

    For each point i, counts the points j (i itself included, since its own
    squared distance is 0) with (x[i]-x[j])**2 + (y[i]-y[j])**2 <= epsilon**2.
    Builds n-by-n temporaries, so memory grows quadratically with len(x).

    Returns (persistence * counts, counts), where counts is the per-point
    count from np.sum over a boolean matrix.
    """
    sq_dist = (x[:, None] - x[None, :])**2 + (y[:, None] - y[None, :])**2
    close = sq_dist <= epsilon**2
    counts = np.sum(close, axis=1)
    return persistence * counts, counts


@jit(nopython=True, parallel=False, fastmath=False)
def eminencef(x, y, persistence, epsilon):
    """
    Loop-based neighbour-count ("eminence") weights.

    Gives the same counts as eminencef_fast but stores them as floats and
    needs only O(len(x)) extra memory. Each unordered pair is visited once,
    and both points are credited when their squared distance is
    <= epsilon**2. Every point starts at 1 to count itself.

    Returns (persistence * counts, counts).
    """
    size = len(x)
    radius_sq = epsilon**2
    counts = np.ones(size)

    for a in range(size):
        xa = x[a]
        ya = y[a]
        for b in range(a + 1, size):
            if (xa - x[b])**2 + (ya - y[b])**2 <= radius_sq:
                counts[a] += 1
                counts[b] += 1

    return persistence * counts, counts


@jit(nopython=True, parallel=False, fastmath=False)
def make_PersSplines_vec(dgm,m_b,M_b,M_p,h,sig=1e-10,iteration=100):
    """
    Vectorise a diagram of (x, y) pairs as the control net of a B-spline
    surface.

    dgm       = 2-column array; column 0 is taken as x and column 1 - column 0
                as the persistence (presumably (birth, death) pairs)
    m_b, M_b  = x range mapped linearly onto [0, 1]
    M_p       = persistence value mapped to 1
    h         = number of control points per axis
    sig       = neighbourhood radius, in normalised coordinates, for the
                eminence weights; with the default 1e-10 only (near-)
                coincident points count as neighbours
    iteration = number of refinement sweeps passed to BSplineApproximation

    Returns zc, the (h, h) control net fitted by BSplineApproximation.
    Raises ValueError if M_b == m_b or M_p == 0, because the normalisation
    would divide by zero and push every point out of the unit square.
    """
    if M_b == m_b:
        raise ValueError("M_b must differ from m_b")
    if M_p == 0:
        raise ValueError("M_p must be non-zero")

    births = dgm[:, 0]
    pers = dgm[:, 1] - births

    # Normalise to the unit square: x by the given range, y by M_p.
    xr = (births - m_b) / (M_b - m_b)
    yr = pers / M_p

    # Heights to fit: persistence times neighbour count (normalised coords).
    zr, _ = eminencef_fast(xr, yr, pers, sig)

    _, _, zc = BSplineApproximation(xr, yr, zr, h, iteration)
    return zc








    