import numpy as np
from numba import jit, prange, int32
from scipy.interpolate import LSQSphereBivariateSpline as SBS
from scipy.interpolate import bisplev

import itertools

import pyshtools

def decompose_spherical_function(theta_grid, phi_grid, f_vals, lmax=None):
    """
    Decompose f(theta, phi) sampled on a regular grid into spherical harmonics
    using pyshtools.

    Parameters
    ----------
    theta_grid : array (Nθ,)
        Colatitudes in [0, π] (0 = north pole) indexing the rows of ``f_vals``.
    phi_grid : array (Nφ,)
        Longitudes in [0, 2π) indexing the columns of ``f_vals``. Not used in
        the computation: the pyshtools 'DH' grid format fixes the longitudes
        from the array shape. It is kept for interface compatibility.
    f_vals : array (Nθ, Nφ)
        Samples f(θ_i, φ_j).
    lmax : int or None, optional
        Maximum spherical-harmonic degree to compute. Forwarded to
        ``SHGrid.expand`` as ``lmax_calc``. None (default) keeps pyshtools'
        own default, which is the previous behaviour.

    Returns
    -------
    coeffs : pyshtools SHCoeffs object (contains C_{ℓm}, S_{ℓm}), computed
        with orthonormalised ('ortho') harmonics.

    Notes
    -----
    NOTE(review): 'DH' grids expect specific shapes and latitudes. The shape
    must be (n, n) or (n, 2n) with n even, and the south pole must be
    excluded. A ``linspace(0, π, Nθ)`` grid that includes both poles only
    approximately matches this. Confirm the sampling is acceptable.
    """
    f_vals = np.asarray(f_vals)
    theta_grid = np.asarray(theta_grid)

    # Convert colatitude to latitude (pyshtools convention):
    # theta 0 -> +90° (north pole), theta pi -> -90° (south pole).
    lat = (np.pi / 2 - theta_grid) * 180 / np.pi

    # pyshtools expects rows ordered by decreasing latitude (+90 -> -90).
    # If the input rows run south -> north, flip them.
    if lat[0] < lat[-1]:
        f_vals = f_vals[::-1, :]

    grid = pyshtools.SHGrid.from_array(f_vals, grid='DH')

    return grid.expand(normalization='ortho', lmax_calc=lmax)
    
    
@jit(nopython=True, parallel=False, fastmath=False)
def persistence(D):
    """
    Default weighting: normalised persistence of each diagram point.

    Parameters
    ----------
    D : array (N, 2)
        Diagram points; column 0 is x (birth), column 1 is y (death).

    Returns
    -------
    array (N,)
        |y - x| / (2 * sqrt(1 + x^2 + y^2)) for every row (x, y) of D.
    """
    X = D[:, 0]
    # Column 1 is the death coordinate. The previous code read column 0
    # twice, so Y - X was identically zero and every weight vanished.
    Y = D[:, 1]

    pers = np.abs(Y - X)
    norm = np.sqrt(1 + X*X + Y*Y)
    return pers / (2*norm)



def make_weighting(K=0.3,alpha=1):
    """
    Return a numba-jitted weighting function ``f(dgm) -> array (N,)``.

    With (x, y) = (dgm[:, 0], dgm[:, 1]) for each diagram row:

    * ``K == 0``: constant weight 1 for every point.
    * ``K > 0``: ``arctan((r / (2K)) ** alpha) / (π/2)``, where
      ``r = |y - x| / sqrt(1 + x^2 + y^2)``.
    * any other K (negative, or NaN, which fails both tests): the
      module-level ``persistence`` function.

    K and alpha are captured by the jitted closure when it is created.
    """
    if K == 0:
        @jit(nopython=True, parallel=False, fastmath=False)
        def unit_weight(dgm):
            return np.ones_like(dgm[:, 0])
        return unit_weight

    if K > 0:
        @jit(nopython=True, parallel=False, fastmath=False)
        def arctan_weight(dgm):
            births = dgm[:, 0]
            deaths = dgm[:, 1]
            # Persistence normalised by the Euclidean length of (1, x, y).
            scaled = np.abs(deaths - births) / np.sqrt(1 + births*births + deaths*deaths)
            # Map to [0, 1] via a rescaled arctan.
            return np.arctan((scaled / (2*K))**alpha) / (np.pi/2)
        return arctan_weight

    return persistence
    

@jit(nopython=True, fastmath=False)
def preprocess_dgm(dgm, weighting = persistence):
    """
    Compute per-point weights and lifted 3D vectors for a diagram.

    Parameters
    ----------
    dgm : array (N, 2)
        Diagram points (x, y).
    weighting : jitted callable
        Maps dgm to a weight array of shape (N,).

    Returns
    -------
    (weights, lifted) : weights = weighting(dgm); lifted is a float array
        (N, 3) whose rows are [1, x, y].
    """
    weights = weighting(dgm)

    n_points = dgm.shape[0]
    lifted = np.ones((n_points, 3))
    lifted[:, 1] = dgm[:, 0]
    lifted[:, 2] = dgm[:, 1]

    return weights, lifted
    

@jit(nopython=True, parallel=False, fastmath=False)
def from_DGMS_to_H(DGMS,pts=None,weighting = persistence,n_theta = 100, n_phi = 200):
    """
    Evaluate make_H_wrap for every diagram in DGMS on a common set of points.

    Parameters
    ----------
    DGMS : sequence of arrays (N_i, 2)
        Persistence diagrams.
    pts : array (n_rows, n_cols, 3) or None
        Cartesian points to evaluate on. If None, a (n_theta, n_phi) grid is
        built. Its rows are indexed by colatitude theta in [0, π] and its
        columns by longitude phi in [0, 2π], both endpoints included.
    weighting : jitted callable
        Per-point weight function passed through to make_H_wrap.
    n_theta, n_phi : int
        Grid resolution used when ``pts`` is None.

    Returns
    -------
    OUT : array (len(DGMS), n_rows, n_cols)
        The output shape follows ``pts``, so caller-supplied grids of any
        size are supported.
    """
    if pts is None:
        theta_grid = np.linspace(0, np.pi, n_theta)
        phi_grid = np.linspace(0, 2*np.pi, n_phi)

        phiv, thetav = meshgrid(phi_grid, theta_grid)
        grid = np.zeros((n_theta, n_phi, 2))

        # spherical_to_card_vec reads channel 0 as theta and channel 1 as
        # phi. Previously they were swapped, so rows were not colatitudes.
        grid[:, :, 0] = thetav
        grid[:, :, 1] = phiv

        pts = spherical_to_card_vec(grid)

    # Size the output from pts, not from the n_theta/n_phi defaults, so a
    # user-supplied grid of a different shape does not break the assignment.
    n_rows = pts.shape[0]
    n_cols = pts.shape[1]
    OUT = np.zeros((len(DGMS), n_rows, n_cols))

    for i in prange(len(DGMS)):
        OUT[i, :, :] = make_H_wrap(DGMS[i], pts, weighting, n_theta=n_theta, n_phi=n_phi)

    return OUT
    

@jit(nopython=True, parallel=False, fastmath=False)
def make_H_wrap(dgm, pts=None, weighting = persistence, n_theta = 100, n_phi = 200):
    """
    Evaluate H for a single persistence diagram.

    For each point u in ``pts`` this computes
    H(u) = sum_q w(q) * max(u · (1, q_x, q_y), 0),
    with w = weighting(dgm) (see preprocess_dgm / make_H_numba).

    Parameters
    ----------
    dgm : array (N, 2)
        Diagram points.
    pts : array (n_rows, n_cols, 3) or None
        Cartesian evaluation points. If None, a (n_theta, n_phi) grid is
        built. Its rows are indexed by colatitude theta in [0, π] and its
        columns by longitude phi in [0, 2π], both endpoints included.
    weighting : jitted callable
        Per-point weight function.
    n_theta, n_phi : int
        Grid resolution used only when ``pts`` is None.

    Returns
    -------
    H : array (n_rows, n_cols)
    """
    if pts is None:
        theta_grid = np.linspace(0, np.pi, n_theta)
        phi_grid = np.linspace(0, 2*np.pi, n_phi)

        phiv, thetav = meshgrid(phi_grid, theta_grid)
        grid = np.zeros((n_theta, n_phi, 2))

        # spherical_to_card_vec reads channel 0 as theta and channel 1 as
        # phi. Previously they were swapped, so rows were not colatitudes.
        grid[:, :, 0] = thetav
        grid[:, :, 1] = phiv

        pts = spherical_to_card_vec(grid)

    w, V = preprocess_dgm(dgm, weighting)
    H = make_H_numba(pts, V, w)

    return H


@jit(nopython=True, parallel=True, fastmath=False)
def make_H_numba(pts, V, w):
    """
    Weighted sum of positive-part projections at every grid point.

    Parameters
    ----------
    pts : array (n_rows, n_cols, 3)
        Evaluation points.
    V : array (N, 3)
        Lifted diagram vectors.
    w : array (N,)
        Weights.

    Returns
    -------
    array (n_rows, n_cols)
        Entry (r, c) is sum_k w[k] * max(pts[r, c] · V[k], 0). NaN
        projections are mapped to 0.
    """
    n_rows, n_cols, _ = pts.shape

    directions = pts.reshape(n_rows * n_cols, 3)
    projections = directions @ V.T                              # (n_rows*n_cols, N)
    positive = np.where(projections > 0, projections, 0.0)     # keep only the positive part
    totals = positive @ w                                       # (n_rows*n_cols,)

    return totals.reshape(n_rows, n_cols)



@jit(nopython=True, parallel=True, fastmath=False)
def decompose_dgm(dgm, weighting = persistence, n_theta = 100, n_phi = 200):
    """
    Evaluate make_H_wrap separately on each single point of a diagram.

    Parameters
    ----------
    dgm : array (N, 2)
        Diagram points.
    weighting : jitted callable
        Per-point weight function.
    n_theta, n_phi : int
        Resolution of the default grid built by make_H_wrap.

    Returns
    -------
    data_dgm : array (N, n_theta, n_phi)
        Slice i is the H map of the one-point diagram ``dgm[i]``.
    """
    data_dgm = np.zeros((len(dgm), n_theta, n_phi))

    for i in prange(len(dgm)):
        # One-point diagram of shape (1, dgm.shape[1]). A slice works for any
        # memory layout, unlike reshape on a row view.
        p = dgm[i:i + 1, :]
        # The second positional parameter of make_H_wrap is `pts`, so
        # weighting must be passed by keyword. Passing it positionally fed
        # the weighting function in as the evaluation grid.
        data_dgm[i, :, :] = make_H_wrap(p, weighting=weighting, n_theta=n_theta, n_phi=n_phi)

    return data_dgm


def make_PD_grid(DGMS, len_grid = 100):
    """
    Build a regular grid of points strictly above the diagonal.

    The interval [m, M] is the global min/max over all entries of all
    diagrams in DGMS. It is sampled with ``len_grid`` equally spaced values,
    and every pair (b, d) of samples with b < d is returned.

    Parameters
    ----------
    DGMS : iterable of array-like
        Persistence diagrams.
    len_grid : int
        Number of samples along each axis.

    Returns
    -------
    PD_grid : float array (K, 2)
        Rows (b, d) ordered by b, then by d. When no pair satisfies b < d
        the result has shape (0, 2).

    Raises
    ------
    ValueError
        If DGMS contains no diagrams.
    """
    # Single pass, so generators are consumed only once.
    extrema = [(np.min(dgm), np.max(dgm)) for dgm in DGMS]
    if not extrema:
        raise ValueError("make_PD_grid: DGMS must contain at least one diagram")

    m = np.min([lo for lo, _ in extrema])
    M = np.max([hi for _, hi in extrema])

    aux = np.linspace(m, M, len_grid)

    # 'ij' indexing plus C-order boolean masking yields pairs with b as the
    # outer index, matching itertools.product(aux, aux).
    births, deaths = np.meshgrid(aux, aux, indexing='ij')
    above = births < deaths

    return np.column_stack((births[above], deaths[above]))

@jit(nopython=True, parallel=True, fastmath=True)
def meshgrid(x, y):
    """
    Numba-friendly equivalent of ``np.meshgrid(x, y)`` (default 'xy' indexing).

    Parameters
    ----------
    x : 1D array, length n_cols
    y : 1D array, length n_rows

    Returns
    -------
    (xx, yy) : two arrays of shape (n_rows, n_cols)
        xx[r, c] = x[c] and yy[r, c] = y[r]. Each keeps the dtype of its
        source array.
    """
    n_rows = y.size
    n_cols = x.size
    xx = np.empty(shape=(n_rows, n_cols), dtype=x.dtype)
    yy = np.empty(shape=(n_rows, n_cols), dtype=y.dtype)

    # Only the outer loop is parallelised; numba runs nested prange
    # sequentially, so a plain range is used inside.
    for row in prange(n_rows):
        y_val = y[row]
        for col in range(n_cols):
            xx[row, col] = x[col]
            yy[row, col] = y_val

    return xx, yy


@jit(nopython=True, parallel=False, fastmath=True)
def make_dV(n_theta = 100, n_phi = 200):
    """
    Area elements sin(θ) Δθ Δφ on the default (θ, φ) grid.

    The grid is ``linspace(0, π, n_theta)`` × ``linspace(0, 2π, n_phi)``,
    with rows indexed by θ. Both φ endpoints are included, so the φ = 0 and
    φ = 2π columns describe the same meridian. Requires n_theta, n_phi >= 2.

    Returns
    -------
    array (n_theta, n_phi)
    """
    thetas = np.linspace(0, np.pi, n_theta)
    phis = np.linspace(0, 2*np.pi, n_phi)

    step_theta = thetas[1] - thetas[0]
    step_phi = phis[1] - phis[0]

    theta_mesh = meshgrid(phis, thetas)[1]

    return np.sin(theta_mesh) * step_theta * step_phi


@jit(nopython=True, parallel=False, fastmath=True)
def spherical_to_card(p):
    """
    Convert one (theta, phi) pair to a Cartesian unit vector.

    ``theta`` is the polar angle measured from +z, and ``phi`` is the
    azimuth in the xy-plane. Returns the array
    [sin θ cos φ, sin θ sin φ, cos θ].
    """
    theta, phi = p
    sin_theta = np.sin(theta)
    return np.array([sin_theta * np.cos(phi), sin_theta * np.sin(phi), np.cos(theta)])


@jit(nopython=True, parallel=False, fastmath=True)
def spherical_to_card_vec(grid):
    """
    Vectorised spherical -> Cartesian conversion over a 2D grid.

    Parameters
    ----------
    grid : array (n_rows, n_cols, 2)
        Channel 0 is theta (polar angle from +z) and channel 1 is phi
        (azimuth).

    Returns
    -------
    array (n_rows, n_cols, 3)
        Unit vectors [sin θ cos φ, sin θ sin φ, cos θ].
    """
    theta = grid[:, :, 0]
    phi = grid[:, :, 1]
    sin_theta = np.sin(theta)

    cart = np.zeros((grid.shape[0], grid.shape[1], 3))
    cart[:, :, 0] = sin_theta * np.cos(phi)
    cart[:, :, 1] = sin_theta * np.sin(phi)
    cart[:, :, 2] = np.cos(theta)

    return cart




