import jax.numpy as jnp
import jax.random as jr
from jax import grad, jit, vmap, jacfwd
import scipy as sp
from scipy import fft, special
import numpy as np
#import orthax as ox
import math as mt
import scipy.linalg as scla
import sklearn as sk
import sklearn.neighbors

'''
def gridpts(N, with_weights=False):
    # chebyshev nodes of second kind
    # see "A MATLAB Differentiation Matrix Suite" by Weideman and Reddy
    x = jnp.sin(np.pi * ((N - 1) - 2 * jnp.linspace(N - 1, 0, N)) / (2 * (N - 1)))  # W&R way
    if with_weights:
        weights = (jnp.pi / (N-1)) * jnp.sin( jnp.pi * jnp.arange(N) / (N-1) )**2
        return x[::-1], weights
    else:
        return x[::-1]
    '''


def gridpts(N, with_weights=False):
    """Return the N Chebyshev-Gauss (first-kind) nodes, optionally with weights.

    The nodes are x_k = cos(pi * (2k + 1) / (2N)) for k = 0..N-1, i.e. the
    roots of T_N, listed in decreasing order.

    NOTE(review): these are not the Gauss-Lobatto points cos(k*pi/(N-1))
    on which collocate_D, dcht and idcht in this module operate.

    Args:
        N: number of nodes.
        with_weights: if True, also return one weight per node.

    Returns:
        ``x`` (jax array of shape (N,)), or the tuple ``(x, weights)`` where
        ``weights`` is a NumPy array (same shape/dtype as ``x``) filled with
        2/N.  NOTE(review): 2/N is presumably the factor of the discrete
        expansion c_m = (2/N) * sum_k f(x_k) T_m(x_k); Gauss-Chebyshev
        quadrature would use pi/N -- confirm the intended normalization.
    """
    angles = jnp.pi * (2 * jnp.arange(N) + 1) / (2 * N)
    nodes = jnp.cos(angles)
    if not with_weights:
        return nodes
    return nodes, np.full_like(nodes, 2.0 / N)


def collocate_D(N: int):
    """Return the first-derivative Chebyshev collocation matrix on N points.

    Python port of ``chebdif`` from Weideman & Reddy, "A MATLAB
    Differentiation Matrix Suite", restricted to the first derivative
    (M = 1).  The matrix acts on samples taken at the Chebyshev
    Gauss-Lobatto points x_k = cos(k*pi/(N-1)), k = 0..N-1, ordered from
    +1 down to -1: ``(collocate_D(N) @ f(x))[k]`` approximates f'(x_k).

    NOTE(review): these are *not* the nodes returned by ``gridpts`` in this
    module (those are cos((2k+1)*pi/(2N))).

    Args:
        N: number of collocation points.  N - 1 is used as a divisor for
            the node angles, so N must be at least 2.

    Returns:
        (N, N) jax array.  It is assembled in NumPy float64 and converted
        with ``jnp.asarray``, so it is float32 unless jax x64 is enabled.
    """
    M = 1  # highest derivative order computed (first derivative only)
    DM = np.zeros((M, N, N))  # DM[ell] holds the (ell+1)-th derivative matrix

    # n1 = (N/2); n2 = round(N/2.)     # indices used for flipping trick [Original]
    n1 = mt.floor(N / 2)
    n2 = mt.ceil(N / 2)  # indices used for flipping trick [Corrected]; n1 + n2 == N
    k = np.arange(N)  # compute theta vector
    th = k * np.pi / (N - 1)  # node angles; the nodes are x_k = cos(th_k)

    # Assemble the differentiation matrices
    T = np.tile(th / 2, (N, 1))  # T[i, j] = th_j / 2
    # 2 sin((th_i + th_j)/2) sin((th_i - th_j)/2) == x_j - x_i, computed via
    # the trigonometric identity instead of subtracting cosines directly.
    DX = 2 * np.sin(T.T + T) * np.sin(T.T - T)  # trigonometric identity
    # Since x_{N-1-k} = -x_k, the bottom rows are minus the top rows with both
    # axes reversed; copying them this way keeps DX exactly antisymmetric.
    DX[n1:, :] = -np.flipud(np.fliplr(DX[0:n2, :]))  # flipping trick
    DX[range(N), range(N)] = 1.  # placeholder so 1/DX below stays finite
    DX = DX.T  # now DX[i, j] = x_i - x_j

    # C[i, j] = (-1)**(i - j) * c_i / c_j with c_0 = c_{N-1} = 2, else 1.
    C = scla.toeplitz((-1.) ** k)  # matrix with entries c(k)/c(j)
    C[0, :] *= 2
    C[-1, :] *= 2
    C[:, 0] *= 0.5
    C[:, -1] *= 0.5

    Z = 1. / DX  # Z contains entries 1/(x(k)-x(j))
    Z[range(N), range(N)] = 0.  # with zeros on the diagonal.

    D = np.eye(N)  # D contains differentiation matrices.
    for ell in range(M):
        # W&R recurrence; for ell = 0 the off-diagonal entries become
        # C[i, j] / (x_i - x_j) (Z has a zero diagonal).
        D = (ell + 1) * Z * (C * np.tile(np.diag(D), (N, 1)).T - D)  # off-diagonals
        # Diagonal chosen so each row sums to zero (derivative of a constant).
        D[range(N), range(N)] = -np.sum(D, axis=1)  # negative sum trick
        DM[ell, :, :] = D  # store current D in DM
    return jnp.asarray(DM[0])


def a(x):
    """Convert ``x`` (list, scalar or array) to a JAX array via ``jnp.asarray``."""
    as_array = jnp.asarray(x)
    return as_array


def to_symmetric(gridpts):
    """Map points from [0, 1] onto [-1, 1] with the affine map t -> 2t - 1.

    This undoes ``to_unit``.  The parameter name shadows the module-level
    ``gridpts`` function inside this body only.
    """
    doubled = 2 * gridpts
    return doubled - 1


def to_unit(gridpts):
    """Map points from [-1, 1] onto [0, 1] with the affine map x -> 1/2 + x/2.

    This undoes ``to_symmetric``.  The parameter name shadows the
    module-level ``gridpts`` function inside this body only.
    """
    halved = gridpts / 2
    return 0.5 + halved


def collocate_M(a):
    """Return the collocation matrix for pointwise multiplication by ``a``.

    For a 1-D vector of samples ``a`` this is ``diag(a)``, so
    ``collocate_M(a) @ u`` equals ``a * u``.  The parameter name shadows the
    module-level helper ``a`` inside this body only.
    """
    return jnp.diag(a, k=0)


def dcht(u):
    """Forward discrete Chebyshev transform on the Gauss-Lobatto grid.

    Given samples u_k = f(x_k) at x_k = cos(k*pi/(N-1)), k = 0..N-1, return
    the coefficients c_m for which f(x_k) = sum_m c_m T_m(x_k).  This is the
    inverse of ``idcht``.  Requires N = len(u) >= 2.
    """
    n_pts = len(u)
    # DCT-I of the samples, pre-normalized by (N - 1).
    raw = sp.fft.dct(u / (n_pts - 1), norm="backward", type=1)
    # The first and last coefficients carry an extra factor of 2.
    endpoint_flags = jnp.concatenate((a([1]), a([0] * (n_pts - 2)), a([1])))
    return raw / (endpoint_flags + jnp.ones((n_pts,)))


def idcht(u):
    """Inverse discrete Chebyshev transform onto the Gauss-Lobatto grid.

    Given coefficients c_m, return the samples
    f(x_k) = sum_m c_m T_m(x_k) at x_k = cos(k*pi/(N-1)), k = 0..N-1.
    This is the inverse of ``dcht``.  Requires N = len(u) >= 2.
    """
    n_pts = len(u)
    # Double the first and last coefficients to undo DCT-I's half weighting
    # of the endpoints.
    endpoint_flags = jnp.concatenate((a([1]), a([0] * (n_pts - 2)), a([1])))
    weighted = u * (endpoint_flags + jnp.ones((n_pts,)))
    return (n_pts - 1) * sp.fft.idct(weighted, norm="backward", type=1)


def eval_on_grid(func, N_grid, dim, flattened=True, use_grid=None):
    """Evaluate ``func`` at every point of a ``dim``-dimensional tensor grid.

    The same 1-D grid is used along every axis.  It is ``gridpts(N_grid)``
    unless ``use_grid`` is supplied, in which case ``N_grid`` is ignored and
    replaced by ``len(use_grid)``.

    Args:
        func: function of a single point (an array of length ``dim``); it is
            vectorized over all grid points with ``vmap``.
        N_grid: number of 1-D points passed to ``gridpts``.
        dim: number of spatial dimensions.
        flattened: if True, return one result per point, with the points
            in row-major order of the mesh.  If False, reshape the result
            to ``dim * [N_grid]``, which assumes ``func`` returns a scalar.
        use_grid: optional explicit 1-D grid.

    Note:
        ``jnp.meshgrid`` is called with its default ``indexing="xy"``.  For
        dim >= 2 the first two axes are therefore swapped relative to "ij"
        ordering: the unflattened result[i, j, ...] is
        func((g[j], g[i], ...)).
    """
    if use_grid is not None:
        axis_pts = use_grid
        N_grid = len(axis_pts)
    else:
        axis_pts = gridpts(N_grid)

    mesh = jnp.meshgrid(*([axis_pts] * dim))
    points = jnp.stack(mesh, axis=-1).reshape((-1, dim))
    values = vmap(func, in_axes=0)(points)
    return values if flattened else values.reshape(dim * [N_grid])


if __name__ == "__main__":
    # Self-checks for the spectral utilities above.
    #
    # They run in double precision.  collocate_D has entries of size
    # O(N**2), so at N = 50 float32 round-off alone exceeds jnp.allclose's
    # default tolerances.  Nothing has created a jax array yet, so the flag
    # can still be switched here.
    import jax

    jax.config.update("jax_enable_x64", True)

    key = jr.key(641)
    N = 50

    # gridpts returns the Chebyshev-Gauss points, i.e. the roots of T_N.
    xn = gridpts(N)
    assert jnp.allclose(jnp.cos(N * jnp.arccos(xn)), 0.0, atol=1e-10)

    # dcht, idcht and collocate_D all work on the Chebyshev-Gauss-Lobatto
    # points x_k = cos(k*pi/(N-1)) (ordered +1 -> -1), NOT on the gridpts
    # nodes.  Using xn here would make every check below fail.
    xl = jnp.cos(jnp.pi * jnp.arange(N) / (N - 1))

    # idcht of the m-th unit coefficient vector must reproduce T_m on xl.
    for m in range(N):
        uhat = jnp.zeros((N,)).at[m].set(1)
        tm = jnp.cos(m * jnp.arccos(xl))
        tm_ = idcht(uhat)
        err = jnp.max(jnp.abs(tm - tm_))
        print(f"Mode {m} absolute error {err}")
        assert err < 1e-10

    # Spectral differentiation of a smooth function is exact to round-off.
    # collocate_D takes the number of points, not the grid itself.
    D = collocate_D(N)
    u = jnp.cos(xl)
    upr = -jnp.sin(xl)
    upr_ = D @ u
    assert jnp.allclose(upr, upr_)

    # dcht and idcht are mutual inverses.
    u = jr.normal(key, shape=(N,))
    assert jnp.allclose(u, idcht(dcht(u)))