import numpy as np
from scipy.special import sph_harm
import scipy.sparse
import scipy.sparse.linalg
import matplotlib.pyplot as plt

def cart2sp(x):
    """Convert Cartesian points to spherical angles (scipy ``sph_harm`` naming).

    Parameters
    ----------
    x : array_like, shape (n, 3)
        Cartesian coordinates, one point per row. Each row is rescaled to
        unit length first, so rows must be nonzero.

    Returns
    -------
    theta : ndarray, shape (n,)
        Azimuth angle in [-pi, pi] (equivalent to [0, 2pi) modulo 2pi).
        At the poles, where the azimuth is undefined, 0 is returned.
    phi : ndarray, shape (n,)
        Elevation/inclination angle in [0, pi], measured from the +z axis.
    """
    x = np.asarray(x, dtype=float)
    x = np.divide(x, np.linalg.norm(x, axis=1, keepdims=True))
    # Clipping protects arccos from a z-component that ends up marginally
    # outside [-1, 1] after the floating-point normalisation.
    phi = np.arccos(np.clip(x[:, 2], -1.0, 1.0))
    # arctan2 resolves the quadrant directly. The former
    # sign(y) * arccos(x / sqrt(x^2 + y^2)) form returned 0 instead of pi on
    # the negative x axis (sign(0) == 0) and NaN at the poles (0 / 0).
    theta = np.arctan2(x[:, 1], x[:, 0])
    return theta, phi

def real_sph_harm(m, n, x):
    """Evaluate the real spherical harmonic of order ``m`` and degree ``n``.

    Parameters
    ----------
    m : int
        Order of the harmonic; its sign selects which combination of the
        complex harmonics of order ``+|m|`` and ``-|m|`` is returned.
    n : int
        Degree of the harmonic.
    x : array_like, shape (N, 3)
        Cartesian coordinates of the evaluation points; they are converted
        to angles with ``cart2sp``.

    Returns
    -------
    ndarray, shape (N,)
        Real part of the selected combination at every point.
    """
    theta, phi = cart2sp(x)

    # Zonal case: only the real part of the complex harmonic is kept.
    if m == 0:
        return np.real(sph_harm(m, n, theta, phi))

    order = np.abs(m)
    parity = np.power(-1, order)
    inv_sqrt2 = np.sqrt(0.5)
    y_neg = sph_harm(-order, n, theta, phi)
    y_pos = sph_harm(order, n, theta, phi)

    if m < 0:
        return np.real(1j * inv_sqrt2 * (y_neg - parity * y_pos))
    return np.real(inv_sqrt2 * (y_neg + parity * y_pos))

def get_Ln(x, eps, d, float_tol=1e-12):
    """Build the sparse Gaussian-kernel graph Laplacian of a point cloud.

    Parameters
    ----------
    x : ndarray, shape (N, m)
        Cartesian coordinates of the N sample points (m = ambient dimension).
    eps : float
        Kernel bandwidth; weights are ``eps**(-d/2) * exp(-|xi - xj|^2 / eps)``.
    d : int or float
        Dimension used in the ``eps**(-d/2)`` kernel normalisation.
    float_tol : float, optional
        Weights below this value are set to zero before sparsification.

    Returns
    -------
    scipy sparse matrix, shape (N, N)
        ``(D - W) / (N * eps)``, where ``W`` is the thresholded weight matrix
        and ``D`` is the diagonal matrix of its row sums.
    """
    n_pts = np.shape(x)[0]

    # Squared Euclidean distances from the Gram matrix:
    # |xi - xj|^2 = |xi|^2 + |xj|^2 - 2 <xi, xj>
    gram = np.matmul(x, x.T)
    sq_norms = np.diag(gram)
    sq_dists = np.reshape(sq_norms, (1, -1)) + np.reshape(sq_norms, (-1, 1)) - 2 * gram

    weights = np.power(eps, -d/2) * np.exp(-sq_dists / eps)
    weights[weights < float_tol] = 0

    # Keep only the surviving entries in the sparse weight matrix.
    rows, cols = np.nonzero(weights)
    W_sparse = scipy.sparse.coo_matrix(
        (weights[rows, cols], (rows, cols)), shape=(n_pts, n_pts)
    ).tocsr()

    # Degree matrix: row sums of the thresholded weights on the diagonal.
    degrees = np.sum(weights, axis=1, keepdims=False)
    diag_idx = np.arange(n_pts)
    D_sparse = scipy.sparse.coo_matrix(
        (degrees, (diag_idx, diag_idx)), shape=(n_pts, n_pts)
    ).tocsr()

    return (D_sparse - W_sparse) / (n_pts * eps)

# ---- experiment configuration ----
N_trials = 100   # independent random point clouds per sample size
N_powers = 10    # number of sample sizes tested
init_pow = 10    # smallest sample size is 2**init_pow
final_pow = 14   # largest sample size is 2**final_pow
N_list = np.floor(np.logspace(init_pow, final_pow, num=N_powers, base=2)).astype(int)
float_tol = 1e-12
kappa = 9        # eigenpairs requested; one per harmonic of degree <= 2 (1 + 3 + 5)
d = 2            # dimension used for the kernel normalisation and bandwidth rate

# Reference eigenvalues paired with the harmonics below: degree n contributes
# the value n * (n + 1) with multiplicity 2n + 1.
vals_actual = np.array([0, 2, 2, 2, 6, 6, 6, 6, 6])
heat_weights_actual = np.exp(-vals_actual)

# (m, n) index pairs in the same order as vals_actual.
n_list = [0, 1, 2]
harmonic_indices = [(m, n) for n in n_list for m in range(-n, n + 1, 1)]

# NOTE(review): magic constant. It appears intended to rescale the graph
# Laplacian so that its spectrum matches vals_actual; a leading-order estimate
# for this Gaussian kernel with uniform sampling on the unit sphere suggests
# 16 -- confirm whether 17 is a deliberate empirical correction.
laplacian_scale = 17

err = np.zeros((N_powers, N_trials))
for j, N in enumerate(N_list):
    # kernel bandwidth shrinks with the sample size like N^(-2 / (d + 6))
    eps = 0.1 * np.power(N, -2/(d+6))
    print("\nStarting N =", N, "...", end=' ')
    for i in range(N_trials):
        # generate sampling points: normalised Gaussian vectors on the unit sphere
        x = np.random.randn(N, 3)
        x = np.divide(x, np.linalg.norm(x, axis=1, keepdims=True))

        # Evaluate every reference harmonic once; the test signal f is their sum.
        vecs_actual = np.zeros((N, kappa))
        for col, (m, n) in enumerate(harmonic_indices):
            vecs_actual[:, col] = real_sph_harm(m, n, x)
        f = vecs_actual.sum(axis=1)

        Ln = laplacian_scale * get_Ln(x, eps, d, float_tol)
        vals, vecs = scipy.sparse.linalg.eigsh(Ln, k=kappa, which='SM')
        if np.abs(vals[0]) < float_tol:
            vals[0] = 0 # manually set

        # Discrete result: |V diag(exp(-vals)) V^T f|. Contract V^T f first so
        # that no dense N x N matrix is ever formed.
        zn = np.abs(np.matmul(np.exp(-vals) * vecs, np.matmul(vecs.T, f)))
        # Reference result built from the harmonics and vals_actual.
        z = np.abs(np.sum(heat_weights_actual * vecs_actual, axis=1))

        err[j, i] = np.linalg.norm(zn - z) / np.sqrt(N)
        if (i%10 == 0):
            print(i, end = ' ')
    print("done", end = ' ')

# per-sample-size mean and spread of the error over the trials
err_avg = np.mean(err, axis=1)
err_std = np.std(err, axis=1)

# fit error data
# Least-squares fit of log(err) = slope * log(N) + intercept using every
# individual trial. err is flattened row by row, which matches the order
# produced by np.repeat(N_list, N_trials).
log_err_vec = np.reshape(np.log(err), -1)
log_N_all_vec = np.reshape(np.log(np.repeat(N_list, N_trials)), (-1, 1))
N_all_vec = np.reshape(np.repeat(N_list, N_trials), (-1, 1))
# rcond=None selects the current default cutoff explicitly and avoids the
# FutureWarning raised by NumPy 1.x when rcond is omitted.
ls_soln_polynomial_model = np.linalg.lstsq(
    np.hstack((log_N_all_vec, np.ones(np.shape(log_N_all_vec)))),
    log_err_vec,
    rcond=None,
)

print(ls_soln_polynomial_model)

# evaluate the fitted line (in log space) at the tested sample sizes
log_N_vec = np.reshape(np.log(N_list), (-1, 1))
recons = np.matmul(np.hstack((log_N_vec, np.ones(np.shape(log_N_vec)))), ls_soln_polynomial_model[0])

# make plot

plt.rcParams['text.usetex'] = True  # the labels below are rendered through LaTeX
plt.figure()
# First least-squares coefficient = slope of the log-log fit; shown in the
# legend instead of a hard-coded exponent so the label always matches the data.
fit_slope = ls_soln_polynomial_model[0][0]
# Keyword names are case-sensitive: "LineStyle" is rejected by Matplotlib >= 3.3.
plt.errorbar(N_list, err_avg, yerr=err_std, capsize=4, marker="o", linestyle="none",
             label=r"average error $\pm$ 1 standard deviation")
# Raw strings keep the LaTeX backslashes from being read as Python escapes.
plt.plot(N_list, np.exp(recons),
         label=r"$\mathcal{O}(n^{%.2f})$ fit" % fit_slope)
ax = plt.gca()
ax.set_xscale("log")
ax.set_yscale("log")
plt.xlabel("Number of sampled points", fontsize=14)
plt.ylabel("Discretization error", fontsize=14)
ax.legend()
plt.savefig('error_rate.png', bbox_inches='tight')
plt.show()
