# Curvature and local density estimation functions
import numpy as np
import warnings
from scipy.special import gamma

import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# from curv_utils import *
from manifolds import *
from data_utils import *

# Manifold identifiers understood by data_gen_for_curv. Update as needed.
manifolds = [
    'gman_proj', 'gman_vec',
    'steifel_proj', 'steifel_vec1',
    'flag_vec', 'sun_pauli',
]

# Neighbourhood sizes k to sweep for each manifold: 8 values spaced
# logarithmically from 20 up to 100 (flag / Pauli families) or up to 200
# (all others), cast to int by np.logspace.
kvals_list = dict.fromkeys(manifolds)
for mn in kvals_list:
    kvals_list[mn] = np.logspace(
        np.log10(20),
        np.log10(100 if ('flag' in mn or 'pauli' in mn) else 200),
        num=8,
        dtype=int,
    )

def build_quadratic_design_matrix(U):
    """
    Build the design matrix of a full quadratic regression in ``U``.

    U: (k, d) array of inputs in tangent coordinates.

    Returns:
        Phi: (k, p) float array with p = 1 + d + d*(d+1)//2. Its columns are
        a constant 1, the d linear terms U[:, i], and the products
        U[:, i] * U[:, j] for i <= j, ordered row by row
        ((0,0), (0,1), ..., (0,d-1), (1,1), ...).
    """
    n_samples, dim = U.shape
    # Upper-triangle index pairs (i <= j), enumerated in row-major order.
    rows, cols = np.triu_indices(dim)
    intercept = np.ones((n_samples, 1))
    quadratic = U[:, rows] * U[:, cols]
    return np.concatenate([intercept, U, quadratic], axis=1)

def mean_curvature_quadratic_fit(X, knns, kmax, energy_thresh=0.95, di=None):
    """
    Estimate mean curvature vectors via second-order polynomial fitting.

    For every sample, its first ``kmax`` neighbours are centred on their
    centroid and split into tangent / normal directions by PCA (SVD). Each
    normal coordinate v_m is then fitted as a quadratic in the tangent
    coordinates u:

        v_m(u) ~ c_m + b_m . u + sum_{a <= b} q_{m,ab} u_a u_b

    The Hessian of this fit (the second fundamental form along normal m) has
    diagonal entries 2 * q_{m,aa}. The mean curvature vector is

        H = (1/d) * sum_m trace(Hessian_m) * n_m

    where n_m is the m-th normal basis vector, so a sphere of radius R gives
    |H| ~ 1/R.

    Parameters:
    - X: (N, D) array of samples
    - knns: (distances, indices), each (N, K). Only ``indices`` is used.
    - kmax: number of neighbors to use per point
    - energy_thresh: fraction of squared singular-value energy the tangent
      space must capture. Used only to estimate d when ``di`` is None.
    - di: intrinsic dimension. If None, it is estimated per point from the
      SVD spectrum using ``energy_thresh``.

    Returns:
    - H_norms: (N,) array of mean curvature magnitudes
    - H_vectors: (N, D) array of mean curvature vectors (in ambient space)

    Points whose neighbourhood is degenerate (all neighbours coincide), or
    that have no normal directions left, keep zero curvature.
    """
    _, indices = knns
    N, D = X.shape
    H_vectors = np.zeros((N, D))
    H_norms = np.zeros(N)

    for i in range(N):
        neighbors = X[indices[i, :kmax]]
        # Centre on the neighbourhood centroid rather than X[i]; the constant
        # term of the quadratic fit absorbs the resulting offset.
        centered = neighbors - neighbors.mean(axis=0)

        # Step 1: PCA to get tangent and normal spaces
        _, S, Vt = np.linalg.svd(centered, full_matrices=False)

        if di is None:
            total_energy = np.sum(S**2)
            if total_energy == 0:
                continue  # all neighbours coincide: nothing to fit
            energy = np.cumsum(S**2) / total_energy
            # Smallest d capturing energy_thresh. Clipped because rounding can
            # leave energy[-1] marginally below 1.
            d = min(int(np.searchsorted(energy, energy_thresh)) + 1, len(S))
        else:
            d = int(di)

        if d < 1:
            continue  # no tangent directions: curvature undefined

        T = Vt[:d]     # tangent basis (d x D)
        Nrm = Vt[d:]   # normal basis (n_normal x D)
        if Nrm.shape[0] == 0:
            continue  # skip if no normal space found

        # Step 2: Project to local coordinates
        u = centered @ T.T          # shape (k, d)
        v = centered @ Nrm.T        # shape (k, n_normal)

        # Step 3: Fit the quadratic model for all normal components at once
        Phi = build_quadratic_design_matrix(u)                 # (k, p)
        coeffs, _, _, _ = np.linalg.lstsq(Phi, v, rcond=None)  # (p, n_normal)

        # Quadratic coefficients follow the constant and the d linear terms,
        # in the row-major (a <= b) order of build_quadratic_design_matrix.
        rows, cols = np.triu_indices(d)
        diag_idx = 1 + d + np.flatnonzero(rows == cols)
        # Trace of each Hessian: the second derivative of q_aa * u_a**2 is 2 * q_aa.
        traces = 2.0 * coeffs[diag_idx].sum(axis=0)            # (n_normal,)

        # Mean curvature vector in ambient space: sum_m (trace_m / d) * Nrm[m]
        H_ambient = (traces / d) @ Nrm                         # (D,)
        H_vectors[i] = H_ambient
        H_norms[i] = np.linalg.norm(H_ambient)

    return H_norms, H_vectors

def sa(d):
    """
    Return the surface area of the unit sphere S^{d-1} in R^d.

    Evaluates 2 * pi**(d/2) / Gamma(d/2). For example, sa(2) == 2*pi (unit
    circle circumference) and sa(3) == 4*pi (unit 2-sphere area).
    """
    half_d = d / 2
    return 2 * np.pi**half_d / gamma(half_d)

# print(sa(2), 2*np.pi)
# print(sa(3), 4*np.pi)

def local_density_estimation(X, knns, kmax, di=2, comb='mean'):
    """
    Estimate local density of samples using inverse power of distance.

    For neighbour ranks j in [kmin, kmax), with kmin = max(kmax // 4, 4),
    the per-rank estimate is

        rho_j = j / (sa(di) * dists[:, j]**di)

    where sa(di) is the surface area of the unit (di-1)-sphere. The per-rank
    estimates are then combined according to ``comb``.

    NOTE(review): the standard kNN estimate is k / (N * V_di * r**di), where
    V_di = sa(di) / di is the unit-ball volume. The formula above differs from
    it by the factor N / di. That is fine for relative comparisons at fixed di
    but matters across different di; confirm this is intended. Also presumably
    dists[:, 0] is the query point itself (distance 0) — TODO confirm, since it
    affects whether rank j counts j or j + 1 points.

    Parameters:
    - X : np.ndarray of shape (N, D)
        The data samples. Only ``X.shape[0]`` is used.
    - knns : tuple of (dists, indices)
        Output of kneighbors: distances and indices of shape (N, K). Only
        ``dists`` is used.
    - kmax : int
        The maximum number of neighbors considered. If kmax > K, only the K
        available neighbour columns are used (kmin is still derived from
        the requested kmax).
    - di : int or None, optional
        Intrinsic dimension. If None, defaults to 1.
    - comb : str, optional
        How to combine densities across neighbor ranks: 'mean' or 'median'.
        Any other value returns the uncombined per-rank matrix.

    Returns:
    - densities : np.ndarray
        Shape (N,) for 'mean' / 'median', else shape (N, n_ranks). If the rank
        window [kmin, min(kmax, K)) is empty, an all-NaN array of shape (N,)
        is returned.
    """
    if kmax < 5:
        warnings.warn("kmax < 5: may be too low for reliable density estimation.")

    if di is None:
        di = 1

    dists, _ = knns
    dists = np.asarray(dists)
    N = X.shape[0]

    # kmin comes from the requested kmax, as before. The upper bound is clipped
    # to the available columns so the rank vector below always matches the
    # slice width, instead of failing to broadcast when kmax > K.
    kmin = max(kmax // 4, 4)
    kstop = min(kmax, dists.shape[1])
    slice_dists = dists[:, kmin:kstop]  # shape (N, kstop - kmin)

    if slice_dists.shape[1] == 0:
        return np.full((N,), np.nan)

    ranks = np.arange(kmin, kstop)[np.newaxis, :]
    densities = ranks / sa(di) / slice_dists**di  # shape (N, kstop - kmin)

    if comb == 'mean':
        return np.mean(densities, axis=1)
    if comb == 'median':
        return np.median(densities, axis=1)
    return densities

def data_gen_for_curv(mn='', params=(1,2,0), N=10):
    """
    Generate samples for manifold ``mn`` and return them with its dimension.

    The family is chosen by substring, checked in this order: 'gman',
    'steifel', 'flag', 'pauli'. Unknown families raise ValueError. Within the
    gman / steifel / flag families, a name without a dedicated sampler yields
    a placeholder ``{'lides': all-NaN array of length N}``.

    ``params`` layout per family (the trailing entry is a seed forwarded to
    the sampler):
    - gman / steifel: (d1, d2, seed)
    - flag:           (d1, d2, d3, seed)
    - pauli:          (d1, seed)

    Returns:
        (info_dict, di): the sampler's output dict and the intrinsic
        dimension computed from ``params``. Presumably info_dict carries a
        'lides' entry like the placeholder — confirm against data_utils.
    """
    def so_dim(n):
        # Dimension of the orthogonal group O(n): n(n-1)/2.
        return n * (n - 1) // 2

    # Exact names with a dedicated sampler. The lambdas defer looking up each
    # data_gen_* helper until that manifold is actually requested.
    samplers = {
        'gman_vec': lambda: data_gen_gman_vec(N=N, params=params),
        'gman_proj': lambda: data_gen_gman_proj(N=N, params=params),
        'steifel_vec1': lambda: data_gen_steifel_vec1(N=N, params=params),
        'steifel_proj': lambda: data_gen_steifel_proj(N=N, params=params),
        'flag_vec': lambda: data_gen_flag_vec(N=N, params=params),
    }

    if 'gman' in mn:
        d1, d2, _ = params
        # O(d1) / (O(d2) x O(d1-d2)); algebraically d2 * (d1 - d2).
        di = so_dim(d1) - so_dim(d2) - so_dim(d1 - d2)
    elif 'steifel' in mn:
        d1, d2, _ = params
        # NOTE(review): O(d1)/O(d2) treats d2 as the complementary block
        # size; if d2 is the number of frame vectors (as for gman), this
        # should be so_dim(d1) - so_dim(d1 - d2). Verify against the samplers.
        di = so_dim(d1) - so_dim(d2)
    elif 'flag' in mn:
        d1, d2, d3, _ = params
        di = so_dim(d1) - so_dim(d2) - so_dim(d3) - so_dim(d1 - d2 - d3)
    elif 'pauli' in mn:
        d1, _ = params
        # NOTE(review): real dim of SU(n) is n**2 - 1 (n**2 is U(n)); confirm
        # what d1 denotes in data_gen_sun_mod_pauli.
        di = d1**2
        samplers[mn] = lambda: data_gen_sun_mod_pauli(N=N, params=params)
    else:
        raise ValueError(f'Manifold {mn} not included.')

    sampler = samplers.get(mn)
    if sampler is None:
        # Family recognised, but no sampler for this specific variant.
        return {'lides' : np.full((N,), np.nan)}, di
    return sampler(), di
# data_gen_for_curv()