from collections.abc import Callable

import jax
from flax import nnx
from jax import (
    numpy as jnp,
    vmap,  # noqa
)
from jax.debug import breakpoint, print as jprint  # noqa
from jaxtyping import Array, Float

### TODO: is initialising parameters via `inverse_softplus_perturbed_ones` actually necessary? Probably not.
from .utils import create_lifted_module, inverse_softplus_perturbed_ones


### kernels :)  ##########################################################################################


class ProductNNKernel(nnx.Module):
    """Base kernel modulated by learned scalar feature maps on each argument.

    Computes ``k(x, y) = gelu(f1(x)) * k_base(x, y) * gelu(f2(y))`` where
    ``f1`` and ``f2`` are independent affine maps ``R^ndims -> R``. Because
    ``f1`` and ``f2`` differ, the resulting kernel matrix is generally not
    symmetric even when ``x is y``.
    """

    def __init__(
        self,
        base_kernel: Callable[..., nnx.Module],
        ndims: int,
        *,
        rngs: nnx.Rngs,
    ):
        """
        Args:
            base_kernel: kernel factory (e.g. a kernel class or a
                ``functools.partial`` of one) called as ``base_kernel(rngs=rngs)``.
            ndims: dimension of each input point (input size of ``f1``/``f2``).
            rngs: random streams used to initialise all parameters.
        """
        self.base_kernel = base_kernel(rngs=rngs)

        self.f1 = nnx.Linear(ndims, 1, rngs=rngs)
        self.f2 = nnx.Linear(ndims, 1, rngs=rngs)

    def __call__(self, x, y):
        """Return the modulated kernel matrix between ``x`` and ``y``.

        Args:
            x: points of shape (n_x, ndims).
            y: points of shape (n_y, ndims).

        Returns:
            Array of shape (n_x, n_y), assuming ``base_kernel`` returns
            (n_x, n_y) — TODO confirm for every base kernel used.

        Raises:
            ValueError: if ``x`` or ``y`` is not at least 2-D (1-D inputs were
                never supported: the feature maps contract the last axis).
        """
        if x.ndim < 2 or y.ndim < 2:
            raise ValueError(
                "ProductNNKernel expects point arrays of shape (n, ndims); "
                f"got x.shape={x.shape}, y.shape={y.shape}"
            )
        k_xy = self.base_kernel(x, y)
        # nnx.Linear already maps over leading axes, so no explicit vmap is needed.
        left = nnx.gelu(self.f1(x))  # (n_x, 1)
        right = nnx.gelu(self.f2(y)).T  # (1, n_y)
        return left * k_xy * right


class DiagonalGaussianSpectralMixtureKernel(nnx.Module):
    """Spectral-mixture-style kernel with ``q`` axis-aligned Gaussian components.

    For ``tau = x - y`` the kernel value is::

        sum_j w_j * prod_d exp(-2 (pi * tau_d * s_jd)^2) * cos(2 pi * tau_d * mu_jd)

    where ``w = softplus(weights)``, ``mu = softplus(freqs)`` and
    ``s = softplus(inv_scales)``, so all three stay positive during training.
    """

    def __init__(
        self,
        ndims: int,
        q: int,
        *,
        rngs: nnx.Rngs,
    ):
        # Draw the three keys in the same order as before: weights, freqs, inv_scales.
        k_w, k_mu, k_s = (rngs.params() for _ in range(3))
        self.q = q
        self.weights = nnx.Param(inverse_softplus_perturbed_ones(k_w, (q,)))
        self.freqs = nnx.Param(inverse_softplus_perturbed_ones(k_mu, (q, ndims)))
        self.inv_scales = nnx.Param(inverse_softplus_perturbed_ones(k_s, (q, ndims)))

    def __call__(self, x, y):
        """Evaluate the kernel matrix between two point sets.

        Args:
            x: points of shape (n_x, ndims), or (n_x,) for scalar points.
            y: points of shape (n_y, ndims), or (n_y,) for scalar points.
                If either input is 1-D, both are read as sets of 1-D points.

        Returns:
            Array of shape (n_x, n_y).
        """
        point_dim = 1 if (x.ndim == 1 or y.ndim == 1) else x.shape[-1]
        n_x, n_y = len(x), len(y)
        xs = x.reshape(n_x, point_dim)
        ys = y.reshape(n_y, point_dim)

        w = nnx.softplus(self.weights.value)  # (q,)
        mu = nnx.softplus(self.freqs.value)  # (q, ndims)
        s = nnx.softplus(self.inv_scales.value)  # (q, ndims)

        def single_pair(a, b):
            # (ndims,) broadcasts against the (q, ndims) component parameters.
            tau = a - b
            envelope = jnp.exp(-2 * (jnp.pi * tau * s) ** 2)
            oscillation = jnp.cos(2 * jnp.pi * tau * mu)
            per_component = jnp.prod(envelope * oscillation, axis=-1)  # (q,)
            return (w * per_component).sum()

        # Inner vmap: one x against every y. Outer vmap: every x.
        all_pairs = jax.vmap(
            jax.vmap(single_pair, in_axes=(None, 0)),
            in_axes=(0, None),
        )
        return all_pairs(xs, ys).reshape(n_x, n_y)


class WendlandC4Kernel(nnx.Module):
    """Compactly supported Wendland C4 kernel with a learnable radius.

    ``k(x, y) = (1 - r)^6 (3 + 18 r + 35 r^2)`` for ``r < 1`` and ``0``
    otherwise, where ``r = ||x - y|| / softplus(scale)``.
    """

    def __init__(self, *, rngs: nnx.Rngs):
        key = rngs.params()
        # Stored pre-softplus; the effective support radius is softplus(scale).
        self.scale = nnx.Param(inverse_softplus_perturbed_ones(key, (1,)))

    def __call__(self, x, y):
        """Evaluate the kernel matrix between two point sets.

        Args:
            x: points of shape (n_x, ndims), or (n_x,) for scalar points.
            y: points of shape (n_y, ndims), or (n_y,) for scalar points.
                If either input is 1-D, both are read as sets of 1-D points.

        Returns:
            Array of shape (n_x, n_y).
        """
        point_dim = 1 if (x.ndim == 1 or y.ndim == 1) else x.shape[-1]
        # The reshape also rejects inputs whose point dimensions disagree.
        xs = x.reshape(len(x), point_dim)
        ys = y.reshape(len(y), point_dim)

        scale = nnx.softplus(self.scale.value)  # (1,), broadcasts below
        diff = xs[:, None, :] - ys[None, :, :]  # (n_x, n_y, ndims)
        sq_dist = jnp.sum(diff**2, axis=-1)  # (n_x, n_y)

        # sqrt'(0) is infinite. Applied directly, gradients w.r.t. x/y at
        # coincident points (e.g. the diagonal of K(x, x)) would be inf * 0 = NaN.
        # Feed sqrt a safe value there, then select the exact zero distance.
        nonzero = sq_dist > 0
        dist = jnp.where(nonzero, jnp.sqrt(jnp.where(nonzero, sq_dist, 1.0)), 0.0)

        r = dist / scale
        return jnp.where(r < 1, ((1 - r) ** 6) * (3 + 18 * r + 35 * r**2), 0)


### Layers built on the kernels above. Each takes a kernel factory whose constructor arguments are already bound (e.g. via functools.partial) except `rngs`.  ####################################################################################


class KernelInterpolate(nnx.Module):
    """Kernel interpolation with a learnable diagonal nugget.

    Fits coefficients ``c`` from ``(K(x, x) + nugget * I) c = f(x)`` and
    evaluates the interpolant at ``y`` as ``K(y, x) c``.
    """

    def __init__(self, kernel: Callable[..., nnx.Module], *, rngs: nnx.Rngs):
        """
        Args:
            kernel: kernel factory called as ``kernel(rngs=rngs)``.
            rngs: random streams used to initialise the kernel.
        """
        self.kernel = kernel(rngs=rngs)
        # Small diagonal regulariser. It is a plain trainable Param with no
        # positivity constraint (unlike the softplus-parameterised kernel params).
        self.nugget = nnx.Param(jnp.ones((1,)) * 1e-6)

    def __call__(
        self,
        f_x: Float[Array, "n_x n_func_dim"],
        x: Float[Array, "n_x n_dim"],
        y: Float[Array, "n_y n_dim"],
    ) -> Float[Array, "n_y n_func_dim"]:
        """Interpolate ``f`` from samples at ``x`` to new locations ``y``.

        Args:
            f_x: function values at ``x``; one row per point of ``x``.
            x: sample locations.
            y: evaluation locations.

        Returns:
            Interpolated values at ``y``.
        """
        K_xx = self.kernel(x, x) + jnp.identity(x.shape[0]) * self.nugget.value
        K_yx = self.kernel(y, x)
        # General (LU) solve: K_xx is not guaranteed to be symmetric (e.g.
        # ProductNNKernel applies different feature maps to each argument),
        # so a Cholesky solve would be wrong in general.
        c = jnp.linalg.solve(K_xx, f_x)
        return K_yx @ c


class KernelInterpolateLeastSquares(nnx.Module):
    """Kernel regression onto a fixed set of centers via least squares.

    Solves ``min_c ||K(x, centers) c - f(x)||`` and evaluates the fit at ``y``
    as ``K(y, centers) c``.
    """

    def __init__(
        self,
        kernel: Callable[..., nnx.Module],
        centers: Float[Array, "n_c n_dim"],
        *,
        rngs: nnx.Rngs,
    ):
        """
        Args:
            kernel: kernel factory called as ``kernel(rngs=rngs)``.
            centers: fixed expansion centers.
            rngs: random streams used to initialise the kernel.
        """
        self.kernel = kernel(rngs=rngs)
        # Stored as a plain array, so not a trainable parameter. Wrap it in
        # nnx.Param to learn the center locations as well.
        self.centers = centers

    def __call__(
        self,
        f_x: Float[Array, "n_x n_func_dim"],
        x: Float[Array, "n_x n_dim"],
        y: Float[Array, "n_y n_dim"],
    ) -> Float[Array, "n_y n_func_dim"]:
        """Fit coefficients on the centers from samples at ``x``; evaluate at ``y``.

        Args:
            f_x: function values at ``x``; one row per point of ``x``.
            x: sample locations.
            y: evaluation locations.

        Returns:
            Fitted values at ``y``.
        """
        K_xc = self.kernel(x, self.centers)  # (n_x, n_c)
        K_yc = self.kernel(y, self.centers)  # (n_y, n_c)

        coeffs, *_ = jnp.linalg.lstsq(K_xc, f_x)  # (n_c, n_func_dim)
        return K_yc @ coeffs


class IntegrateAgainstDiagonalMatrixValuedKernel(nnx.Module):
    """Integrate a p-channel function against a diagonal matrix-valued kernel.

    Only ``q`` scalar kernels are learned. Channel ``c = g * q + j`` is
    integrated against kernel ``j``, so the ``p`` diagonal entries cycle
    through the ``q`` kernels and ``p`` must be a multiple of ``q``.
    """

    def __init__(self, base_kernel: nnx.Module, p: int, q: int, *, rngs: nnx.Rngs):
        """
        Args:
            base_kernel: kernel passed to ``create_lifted_module``.
            p: number of function channels.
            q: number of distinct scalar kernels; must divide ``p``.
            rngs: random streams used to initialise the kernels.

        Raises:
            ValueError: if ``q`` is not positive or does not divide ``p``.
        """
        # Fail fast: __call__ regroups the p channels as (p // q, q).
        if q <= 0 or p % q != 0:
            raise ValueError(f"p must be a positive multiple of q; got p={p}, q={q}")
        # Presumably q stacked kernel copies with a leading axis that nnx.vmap
        # maps over in __call__ (in_axes=0). TODO confirm in create_lifted_module.
        self.lifted_kernel = create_lifted_module(base_kernel, q, rngs=rngs)
        self.p = p
        self.q = q

    def __call__(
        self,
        f_t: Float[Array, "n_t p"],
        t: Float[Array, "n_t n_d"],
        w: Float[Array, "n_t 1"],
        eval_locations: Float[Array, "n_eval n_d"],
    ) -> Float[Array, "n_eval p"]:
        """Compute ``out[e, c] = sum_s K_{c % q}(e, t_s) * w_s * f_t[s, c]``.

        Args:
            f_t: function values at the sample locations ``t``.
            t: sample locations.
            w: per-sample weights (presumably quadrature weights — TODO confirm).
            eval_locations: points at which the integral is evaluated.

        Returns:
            Array of shape (n_eval, p).
        """
        # Out-of-place: `f_t *= w` writes into the caller's array when f_t is a NumPy array.
        weighted = f_t * w  # w is (n_t, 1) and broadcasts over channels

        apply_lifted = nnx.vmap(
            lambda kernel, x, y: kernel(x, y), in_axes=(0, None, None)
        )
        kernel_diag = apply_lifted(self.lifted_kernel, eval_locations, t)  # (q, n_eval, n_t)

        # (n_t, p) -> (p, n_t) -> (p // q, q, n_t); channel g * q + j sits at [g, j]
        grouped = weighted.T.reshape(-1, self.q, len(t))
        # out[g, j, e] = sum_s K_j(e, t_s) * grouped[g, j, s]
        out = jnp.einsum("jes,gjs->gje", kernel_diag, grouped)
        return out.reshape(self.p, -1).T


class IntegrateAgainstTridiagonalMVK(nnx.Module):
    """Integrate a p-channel function against a tridiagonal matrix-valued kernel.

    The p x p kernel matrix has independent kernels on the diagonal (``p``),
    super-diagonal (``p - 1``) and sub-diagonal (``p - 1``); every other entry
    is zero. Output channel ``i`` therefore mixes input channels ``i - 1``,
    ``i`` and ``i + 1`` only.
    """

    def __init__(self, base_kernel: nnx.Module, p: int, *, rngs: nnx.Rngs):
        """
        Args:
            base_kernel: kernel passed to ``create_lifted_module``.
            p: number of function channels.
            rngs: random streams used to initialise the kernels.
        """
        self.p = p
        # (row, col) index pairs of each band. __call__ no longer reads these;
        # they are kept so the module's attributes/state layout is unchanged.
        self.super_diagonal_kernel = create_lifted_module(base_kernel, p - 1, rngs=rngs)
        self.super_diagonal_indices = jnp.array([jnp.arange(p - 1), jnp.arange(1, p)]).T

        self.sub_diagonal_kernel = create_lifted_module(base_kernel, p - 1, rngs=rngs)
        self.sub_diagonal_indices = jnp.array([jnp.arange(1, p), jnp.arange(p - 1)]).T

        self.diagonal_kernel = create_lifted_module(base_kernel, p, rngs=rngs)
        self.diagonal_indices = jnp.array([jnp.arange(p), jnp.arange(p)]).T

    def __call__(
        self,
        f_t: Float[Array, "n_t p"],
        t: Float[Array, "n_t n_d"],
        w: Float[Array, "n_t 1"],
        eval_locations: Float[Array, "n_eval n_d"],
    ) -> Float[Array, "n_eval p"]:
        """Compute ``out[e, i] = sum_j sum_s K_ij(e, t_s) * w_s * f_t[s, j]``.

        Args:
            f_t: function values at the sample locations ``t``.
            t: sample locations.
            w: per-sample weights (presumably quadrature weights — TODO confirm).
            eval_locations: points at which the integral is evaluated.

        Returns:
            Array of shape (n_eval, p).
        """
        # Out-of-place: `f_t *= w` writes into the caller's array when f_t is a NumPy array.
        weighted = f_t * w  # (n_t, p)

        apply_lifted = nnx.vmap(
            lambda kernel, x, y: kernel(x, y), in_axes=(0, None, None)
        )
        super_k = apply_lifted(self.super_diagonal_kernel, eval_locations, t)  # (p-1, n_eval, n_t)
        sub_k = apply_lifted(self.sub_diagonal_kernel, eval_locations, t)  # (p-1, n_eval, n_t)
        diag_k = apply_lifted(self.diagonal_kernel, eval_locations, t)  # (p, n_eval, n_t)

        # Contract each band directly. Scattering into a dense (p, p, n_eval, n_t)
        # tensor (mostly zeros) would cost O(p^2) memory and work.
        # Diagonal entry (i, i): output i <- input i.
        out = jnp.einsum("pes,sp->ep", diag_k, weighted)
        # Super-diagonal entry (i, i + 1): output i <- input i + 1.
        out = out.at[:, :-1].add(jnp.einsum("pes,sp->ep", super_k, weighted[:, 1:]))
        # Sub-diagonal entry (i + 1, i): output i + 1 <- input i.
        out = out.at[:, 1:].add(jnp.einsum("pes,sp->ep", sub_k, weighted[:, :-1]))
        return out


class IntegrateAgainstDenseMVK(nnx.Module):
    """Integrate a p-channel function against a dense p x p matrix-valued kernel.

    Every entry ``K_ij`` has its own kernel, so ``p * p`` kernels are learned.
    """

    def __init__(self, base_kernel: nnx.Module, p: int, *, rngs: nnx.Rngs):
        """
        Args:
            base_kernel: kernel passed to ``create_lifted_module``.
            p: number of function channels.
            rngs: random streams used to initialise the kernels.
        """
        self.p = p
        # One kernel per matrix entry. __call__ reshapes row-major, so entry
        # (i, j) is lifted copy number i * p + j.
        self.dense_kernel = create_lifted_module(base_kernel, p * p, rngs=rngs)

    def __call__(
        self,
        f_t: Float[Array, "n_t p"],
        t: Float[Array, "n_t n_d"],
        w: Float[Array, "n_t 1"],
        eval_locations: Float[Array, "n_eval n_d"],
    ) -> Float[Array, "n_eval p"]:
        """Compute ``out[e, i] = sum_j sum_s K_ij(e, t_s) * w_s * f_t[s, j]``.

        Args:
            f_t: function values at the sample locations ``t``.
            t: sample locations.
            w: per-sample weights (presumably quadrature weights — TODO confirm).
            eval_locations: points at which the integral is evaluated.

        Returns:
            Array of shape (n_eval, p).
        """
        # Out-of-place: `f_t *= w` writes into the caller's array when f_t is a NumPy array.
        weighted = f_t * w  # (n_t, p)

        apply_lifted = nnx.vmap(
            lambda kernel, x, y: kernel(x, y), in_axes=(0, None, None)
        )
        entries = apply_lifted(self.dense_kernel, eval_locations, t)  # (p*p, n_eval, n_t)
        matrix_kernel = entries.reshape(self.p, self.p, len(eval_locations), len(t))
        # Produce (n_eval, p) directly instead of (p, n_eval) followed by .T
        return jnp.einsum("ijes,sj->ei", matrix_kernel, weighted)
