import jax.numpy as jnp
import jax.random as jr
from jax import grad, jit, vmap, jacfwd
import scipy as sp
from scipy import fft, sparse, interpolate
import numpy as np
import pde
from typing import NamedTuple

def make_rbf(bandwidth, operator):
    """Build an isotropic ``RBF`` kernel bound to ``operator``.

    :param bandwidth: forwarded to ``RBF`` as its ``bandwidth`` argument.
    :param operator: forwarded to ``RBF`` as its ``op`` argument.
    :return: the constructed ``RBF`` instance.
    """
    kernel = RBF(op=operator, bandwidth=bandwidth)
    return kernel

def make_aniso_rbf(bandwidth, scale, theta, operator):
    """Build an ``AnisoRBF`` kernel bound to ``operator``.

    All arguments are passed by keyword, so the differing positional order
    of ``AnisoRBF.__init__`` (``theta`` before ``scale``) does not matter.

    :param bandwidth: forwarded as ``bandwidth``.
    :param scale: forwarded as ``scale``.
    :param theta: forwarded as ``theta``.
    :param operator: forwarded as ``op``.
    :return: the constructed ``AnisoRBF`` instance.
    """
    settings = {
        "bandwidth": bandwidth,
        "scale": scale,
        "theta": theta,
        "op": operator,
    }
    return AnisoRBF(**settings)
def make_diffusion_field(n_spikes, dim, key):
    """Create a random sinusoidal field.

    A ``(n_spikes, dim)`` matrix is drawn once from a standard normal using
    ``key``. The returned callable reshapes its input to ``(dim, -1)``,
    projects it onto every row of that matrix, sums the projections over the
    rows and returns ``sin(2 * pi * sum)`` -- one value per column of the
    reshaped input.

    :param n_spikes: number of random rows drawn.
    :param dim: length of each random row; the input is reshaped to ``(dim, -1)``.
    :param key: JAX PRNG key used for the draw.
    :return: function mapping an array to the field values.
    """
    directions = jr.normal(key, shape=(n_spikes, dim))

    def field(x):
        columns = x.reshape((dim, -1))
        phase = jnp.sum(directions @ columns, axis=0)
        return jnp.sin(2 * jnp.pi * phase)

    return field

class Convection:
    """First-order operator ``beta * d/dx[0] + d/dx[1]`` with zero forcing.

    The reference solution is ``boundary_condition`` evaluated along
    ``X[:, 0] - beta * X[:, 1]``.
    NOTE(review): presumably column 0 is space and column 1 is time -- confirm
    against the callers.
    """

    def __init__(self, boundary_condition, beta=1):
        """
        :param boundary_condition: callable applied to the shifted coordinate
            in ``eval_solution``.
        :param beta: coefficient multiplying the first partial derivative.
        """
        self.bc = boundary_condition
        self.beta = beta

    def apply(self, func):
        """Return ``x -> grad(func)(x) . [beta, 1]`` for a scalar function ``func``."""

        def convected(x_):
            direction = jnp.array([self.beta, 1])
            return jnp.dot(grad(func)(x_), direction)

        return convected

    def eval_solution(self, X):
        """Evaluate the reference solution at the rows of ``X``; returns a column."""
        shifted = X[:, 0] - self.beta * X[:, 1]
        values = self.bc(shifted)
        return values.reshape((-1, 1))

    def eval_forcing(self, X):
        """Return a zero column with one entry per row of ``X``."""
        n_points = X.shape[0]
        return jnp.zeros(n_points).reshape((-1, 1))

class Poisson2D:
    """Laplacian operator on ``[-1, 1] x [-1, 1]`` with a grid-based reference solution.

    The reference solution is computed once in ``__init__`` with
    ``pde.solve_poisson_equation`` and queried afterwards through a
    nearest-neighbour interpolator over the grid cell coordinates.
    """

    def __init__(self, forcing_expr=None, N_grid_fem=200):
        '''
        Solve the reference problem with value-0 boundary conditions on every side.

        :param forcing_expr: expression string handed to
            ``pde.ScalarField.from_expression``; ``None`` selects
            ``"10*(1 + sin(6.28*x) * sin(6.28*y))"``.
        :param N_grid_fem: resolution passed to ``pde.CartesianGrid``; the
            solution is stored on ``N_grid_fem**2`` cells.
        '''
        if forcing_expr is None:
            forcing_expr = "10*(1 + sin(6.28*x) * sin(6.28*y))"
        self.forcing_expr = forcing_expr

        domain = pde.CartesianGrid([[-1, 1], [-1, 1]], N_grid_fem)
        self.forcing = pde.ScalarField.from_expression(domain, forcing_expr)

        # The same value-0 condition pair is used for both axes.
        zero_bc = [{"value": 0}, {"value": 0}]
        solved = pde.solve_poisson_equation(self.forcing, [zero_bc, zero_bc])

        # Flatten cell coordinates and solution values into matching point lists.
        nodes = solved.grid.cell_coords.reshape((N_grid_fem**2, 2))
        values = solved.data.flatten()

        self._sol = values
        self._sol_interpolator = sp.interpolate.NearestNDInterpolator(x=nodes, y=values)

    def apply(self, func):
        """Return ``x -> trace(Hessian(func)(x))``, i.e. the Laplacian of ``func``."""

        def laplacian(x_):
            hessian = jacfwd(grad(func))(x_)
            return jnp.trace(hessian)

        return laplacian

    def eval_solution(self, X):
        """Nearest-neighbour lookup of the stored solution at the rows of ``X``; returns a column."""
        looked_up = self._sol_interpolator(X)
        return looked_up.reshape((-1, 1))

    def eval_forcing(self, X):
        """Evaluate the default forcing at the rows of ``X``; returns a column.

        :raises ValueError: if the instance was built with any forcing
            expression other than the default one.
        """
        if self.forcing_expr != '10*(1 + sin(6.28*x) * sin(6.28*y))':
            raise ValueError("THE CURRENT IMPLEMENTATION DOES NOT SUPPORT ARBITRARY FORCING FUNCTIONS.")
        modulation = 1 + jnp.sin(6.28 * X[:, 0]) * jnp.sin(6.28 * X[:, 1])
        return 10 * modulation.reshape((-1, 1))




class Kernel:
    """Scalar kernel ``k(x, y)`` paired with a differential operator ``op``.

    ``op`` must expose ``apply(func)``, returning a function that evaluates
    the operator applied to ``func`` at a point. From ``kfunc`` and ``op``
    two derived point-wise functions are built:

    * ``hfunc(x, y)`` -- ``op`` applied to ``y -> k(x, y)``, evaluated at ``y``;
    * ``gfunc(x, y)`` -- ``op`` applied in both arguments (inner in ``y``,
      outer in ``x``).

    Instances are callable: ``kernel(X, Y)`` evaluates ``kfunc`` on paired
    rows of ``X`` and ``Y`` (a ``vmap`` over the leading axis of both).
    """

    def __init__(self, kfunc, op):
        """
        :param kfunc: function of two points returning a scalar.
        :param op: object with an ``apply(func)`` method.
        """
        self.kfunc = kfunc
        self.op = op
        # Paired batched evaluation. It is exposed through a real ``__call__``
        # method: Python looks special methods up on the type, so assigning
        # ``__call__`` on the instance would not make the object callable.
        self._paired_kfunc = vmap(self.kfunc)

        kfy = lambda x: lambda y: self.kfunc(x, y)
        self.hfunc = lambda x, y: op.apply(kfy(x))(y)
        self.gfunc = lambda x, y: op.apply(lambda x_: op.apply(lambda y_: self.kfunc(x_, y_))(y))(x)

    def __call__(self, X, Y):
        """Evaluate ``kfunc`` on corresponding entries of ``X`` and ``Y``."""
        return self._paired_kfunc(X, Y)

    def _gramify(self, X, Y):
        """List all ``(X[i], Y[j])`` pairs as rows, with the ``X`` index varying fastest.

        ``jnp.meshgrid`` is used with its default ``'xy'`` indexing, so the
        grid has shape ``(len(Y), len(X))`` before it is flattened.
        """
        return jnp.stack(jnp.meshgrid(X, Y), axis=-1).reshape((len(X) * len(Y), -1))

    def gram(self, X, Y):
        """Return the ``(len(X), len(Y))`` matrix with entries ``k(X[i], Y[j])``.

        NOTE(review): relies on ``jnp.meshgrid``, so this is intended for 1-D
        arrays of scalar points; use ``K`` for multi-dimensional points.
        """
        g = self._gramify(X, Y)
        flat = self(g[:, 0], g[:, 1])
        # The pairs are ordered row-major over (len(Y), len(X)); reshape to
        # that layout first and transpose, otherwise entries are scrambled
        # whenever len(X) != len(Y).
        return flat.reshape((len(Y), len(X))).T

    def K(self, X, Y):
        """Matrix of ``kfunc(X[i], Y[j])`` over all row pairs."""
        return vmap(lambda x_: vmap(lambda y_: self.kfunc(x_, y_))(Y))(X)

    def H(self, X, Y):
        """Matrix of ``hfunc(X[i], Y[j])`` over all row pairs."""
        return vmap(lambda x_: vmap(lambda y_: self.hfunc(x_, y_))(Y))(X)

    def G(self, X, Y):
        """Matrix of ``gfunc(X[i], Y[j])`` over all row pairs."""
        return vmap(lambda x_: vmap(lambda y_: self.gfunc(x_, y_))(Y))(X)


class AnisoRBF(Kernel):
    """Two-dimensional anisotropic Gaussian kernel.

    The kernel is ``exp(-0.5 * (x - y)^T P (x - y) / bandwidth**2)`` where
    ``P`` is the inverse of ``R diag(scale**2, 1 / scale**2) R^T`` and ``R``
    is the 2x2 rotation matrix for angle ``theta``.
    """

    def __init__(self, bandwidth, theta, scale, op):
        """
        :param bandwidth: length scale dividing the quadratic form (squared).
        :param theta: rotation angle used to build ``R``.
        :param scale: anisotropy factor; the diagonal is ``(scale**2, 1 / scale**2)``.
        :param op: operator object passed on to ``Kernel``.
        """
        cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
        rotation = jnp.array([[cos_t, -sin_t], [sin_t, cos_t]])
        stretch = jnp.diag(jnp.array([scale**2, 1/scale**2]))
        precision = jnp.linalg.inv(rotation @ stretch @ rotation.T)

        def kfunc(x, y):
            diff = x - y
            quad_form = jnp.dot(diff, precision @ diff)
            return jnp.exp(- 0.5 * quad_form / bandwidth**2)

        super().__init__(kfunc, op)

class Cauchy(Kernel):
    """Logarithmic kernel ``(1 / (2*pi)) * log(sqrt(sum((x - y)**2 / variance + eps)))``.

    ``eps`` is added to every component before the sum, which keeps the
    argument of the logarithm positive when ``x == y``.
    """

    def __init__(self, variance, op, eps=1e-8):
        """
        :param variance: divisor applied to the squared differences.
        :param op: operator object passed on to ``Kernel``.
        :param eps: small offset added component-wise before summing.
        """

        def kfunc(x, y):
            shifted_sq = (x - y)**2 / variance + eps
            radius = jnp.sum(shifted_sq)**(0.5)
            return (2 * jnp.pi)**(-1) * jnp.log(radius)

        super().__init__(kfunc, op)
        self.variance = variance

class RBF(Kernel):
    """Isotropic Gaussian kernel ``exp(-0.5 * sum((x - y)**2 / bandwidth**2))``."""

    def __init__(self, bandwidth, op):
        """
        :param bandwidth: length scale; stored on the instance as ``bandwidth``.
        :param op: operator object passed on to ``Kernel``.
        """

        def kfunc(x, y):
            scaled_sq = (x - y)**2 / bandwidth**2
            return jnp.exp(- 0.5 * jnp.sum(scaled_sq))

        super().__init__(kfunc, op)
        self.bandwidth = bandwidth
        # Disabled closed-form specialisation for Poisson2D, kept for reference.
        # The autodiff-based hfunc/gfunc set up by Kernel are what is used.
        #
        # if type(op) == Poisson2D:
        #     def hfunc(x, y):
        #         sdiff = jnp.linalg.norm(x-y)**2 / bandwidth**2
        #         return (sdiff - 2) * kfunc(x, y) / bandwidth**2
        #     def gfunc(x, y):
        #         sdiff = jnp.linalg.norm(x-y)**2 / bandwidth**2
        #         return ((sdiff - 2)**2 - 4 * sdiff + 4) * kfunc(x, y) / bandwidth**4
        #     self.hfunc = hfunc
        #     self.gfunc = gfunc
