import typing as tp
from functools import partial

import jax.numpy as jnp
from flax import nnx
from jax.debug import print as jprint  ### noqa
from jaxtyping import Array, Float

from .layers import (
    DiagonalGaussianSpectralMixtureKernel,
    IntegrateAgainstDiagonalMatrixValuedKernel,
    KernelInterpolate,
    KernelInterpolateLeastSquares,
    ProductNNKernel,
    WendlandC4Kernel,
)


# Lookup table from the ``config.model.activation`` string to the activation
# function of the same name exposed by ``flax.nnx``.
activation_mapping = {
    name: getattr(nnx, name) for name in ("gelu", "relu", "tanh", "leaky_relu")
}


def build_model(config, dynamic_config):
    """Build a :class:`KNO` model from the static and dynamic configuration.

    Args:
        config: Static experiment config. Reads ``seed``, ``model.p``,
            ``model.q``, ``model.activation``, ``model.depth``,
            ``model.interpolation_scheme``, ``kernel.interpolation_kernel``,
            ``kernel.integration_kernel``, ``kernel.eval_kernel`` and
            ``kernel.add_product_kernel`` (plus ``kernel.spectral_mixture.q``
            whenever a ``"spectral_mixture"`` kernel is selected).
        dynamic_config: Mapping with keys ``"ndims"``, ``"in_features"``,
            ``"out_features"`` and ``"centers"``.

    Returns:
        A freshly constructed ``KNO`` module. Every sub-module draws from the
        single ``nnx.Rngs(config.seed)`` created here, so reordering the
        constructions below would likely change the initial parameter values.

    Raises:
        ValueError: if ``model.q`` is not a positive divisor of ``model.p``
            (with ``q <= p``), or if a kernel name is neither
            ``"spectral_mixture"`` nor ``"wendland"``.
        KeyError: if ``model.activation`` is not a key of ``activation_mapping``.
    """
    ### pull some variables from config for conciseness
    ndims = dynamic_config["ndims"]
    in_features = dynamic_config["in_features"]
    out_features = dynamic_config["out_features"]
    centers = dynamic_config["centers"]

    p = config.model.p
    q = config.model.q

    # Validate with an explicit raise: ``assert`` is stripped under ``python -O``,
    # and evaluating ``p % q`` with ``q == 0`` would surface as an opaque
    # ZeroDivisionError instead of a config error.
    if q <= 0 or p % q != 0 or q > p:
        raise ValueError(
            "config.model.q must be a positive divisor of config.model.p "
            f"with q <= p (got p={p!r}, q={q!r})"
        )

    activation = activation_mapping[config.model.activation]
    depth = config.model.depth

    ### need this thing: one RNG source shared by every sub-module built below
    rngs = nnx.Rngs(config.seed)

    ########### kernel setup  ##########################################################

    def get_kernel_class_and_init_args(kernel_name: str):
        """Return ``(kernel_class, init_kwargs)`` for ``kernel_name``; ``rngs`` is supplied later."""
        if kernel_name == "spectral_mixture":
            return DiagonalGaussianSpectralMixtureKernel, {
                "q": config.kernel.spectral_mixture.q,
                "ndims": ndims,
            }
        if kernel_name == "wendland":
            return WendlandC4Kernel, {}
        raise ValueError(
            f"kernel class not found: {kernel_name!r} "
            "(expected 'spectral_mixture' or 'wendland')"
        )

    add_product_kernel = config.kernel.add_product_kernel

    def make_kernel_factory(kernel_name: str):
        """Partially apply the named kernel class, optionally wrapped in ProductNNKernel."""
        kernel_cls, init_args = get_kernel_class_and_init_args(kernel_name)
        factory = partial(kernel_cls, **init_args)
        ### overparameterize until we die
        if add_product_kernel:
            factory = partial(ProductNNKernel, base_kernel=factory, ndims=ndims)
        return factory

    # These are factories (partials), not instances; the layers below
    # instantiate them with ``rngs``.
    interpolation_kernel = make_kernel_factory(config.kernel.interpolation_kernel)
    integration_kernel = make_kernel_factory(config.kernel.integration_kernel)
    eval_kernel = make_kernel_factory(config.kernel.eval_kernel)

    ### interpolate thing #############################################################
    if config.model.interpolation_scheme == "kernel_least_squares":
        interpolate_module = KernelInterpolateLeastSquares(
            interpolation_kernel, centers, rngs=rngs
        )
    else:
        interpolate_module = KernelInterpolate(interpolation_kernel, rngs=rngs)

    ### integrate against kernel things ###############################################
    # NOTE: IntegrateAgainstTridiagonalMVK / IntegrateAgainstDenseMVK, constructed
    # as (integration_kernel, p, rngs=rngs), are alternative hidden layers; they
    # are not imported in this module.
    hidden_integration_modules = [
        IntegrateAgainstDiagonalMatrixValuedKernel(integration_kernel, p, q, rngs=rngs)
        for _ in range(depth)
    ]
    pointwise_convs = [
        nnx.Conv(p, p, kernel_size=(1,), rngs=rngs) for _ in range(depth)
    ]
    # The final integral uses eval_kernel; KNO evaluates it at the output grid.
    last_integration_module = IntegrateAgainstDiagonalMatrixValuedKernel(
        eval_kernel, p, q, rngs=rngs
    )
    integration_modules = [*hidden_integration_modules, last_integration_module]

    ### lift thing ####################################################################
    # KNO concatenates the quadrature-node coordinates onto the interpolated
    # function before this layer, so in_features presumably equals
    # n_func_dims + ndims — TODO confirm against the caller.
    lift_module = [nnx.Linear(in_features, p, rngs=rngs), activation]

    ### project thing #################################################################
    proj_module = [
        nnx.Linear(p, p, rngs=rngs),
        activation,
        nnx.Linear(p, p, rngs=rngs),
        activation,
        nnx.Linear(p, out_features, rngs=rngs),
    ]

    ### final model ###################################################################
    return KNO(
        interpolate_module,
        lift_module,
        integration_modules,
        pointwise_convs,
        proj_module,
        activation=activation,
        rngs=rngs,
    )


class KNO(nnx.Module):
    """Model composed of interpolation, lift, kernel-integral and projection stages.

    The sub-modules are passed in already constructed (``build_model`` in this
    module is one place that does so).

    Args:
        interpolation_module: called as ``(f_x, x_grid, nodes)`` to move the
            input function onto the quadrature nodes.
        lift_module: callables applied in order after the node coordinates are
            appended to the interpolated values.
        integration_modules: integral layers called as
            ``(values, nodes, weights, eval_points)``; all but the last are
            evaluated at the nodes, the last at ``y_grid``.
        pointwise_convs: skip-path callables, indexed in step with all but the
            last entry of ``integration_modules``.
        proj_module: callables applied in order to the output-grid values.
        activation: nonlinearity used after each integral stage.
        rngs: accepted for API symmetry; not used by this constructor.
    """

    def __init__(
        self,
        interpolation_module: nnx.Module,
        lift_module: tp.List[nnx.Module],
        integration_modules: tp.List[nnx.Module],
        pointwise_convs: tp.List[nnx.Module],
        proj_module: tp.List[nnx.Module],
        activation: tp.Callable,
        *,
        rngs: nnx.Rngs,
    ):
        # Keep attribute assignment order stable; it mirrors the pipeline order.
        self.interpolation_module = interpolation_module
        self.lift_module = lift_module
        self.integration_modules = integration_modules
        self.pointwise_convs = pointwise_convs
        self.proj_module = proj_module
        self.activation = activation

    def __call__(
        self,
        f_x: Float[Array, "n_x n_func_dims"],
        quadrature_nodes: Float[Array, "n_t n_dims"],
        quadrature_weights: Float[Array, "n_t 1"],
        x_grid: Float[Array, "n_x n_dims"],
        y_grid: Float[Array, "n_y n_dims"],
    ) -> Float[Array, "n_y n_func_dims"]:
        """Map samples of a function on ``x_grid`` to predictions on ``y_grid``.

        Args:
            f_x: input function values sampled at ``x_grid``.
            quadrature_nodes: nodes shared by every integral layer.
            quadrature_weights: weights paired with ``quadrature_nodes``.
            x_grid: locations where ``f_x`` is sampled.
            y_grid: locations where the output is evaluated.

        Returns:
            The result of the projection layers evaluated at ``y_grid``.
        """
        # Wrapping these inputs in stop_gradient did not speed anything up,
        # so they are used directly.
        nodes = quadrature_nodes
        weights = quadrature_weights

        # Move the input onto the quadrature nodes, append the node
        # coordinates as extra channels, then lift.
        hidden = self.interpolation_module(f_x, x_grid, nodes)
        hidden = jnp.concatenate((hidden, nodes), axis=-1)
        for lift_layer in self.lift_module:
            hidden = lift_layer(hidden)

        # Hidden integral layers map nodes -> nodes; each result is summed with
        # a skip path applied to that layer's input before the activation.
        for idx, integral in enumerate(self.integration_modules[:-1]):
            integrated = integral(hidden, nodes, weights, nodes)
            hidden = self.activation(integrated + self.pointwise_convs[idx](hidden))

        # The final integral is evaluated at the output locations.
        final_integral = self.integration_modules[-1]
        out = self.activation(final_integral(hidden, nodes, weights, y_grid))

        for proj_layer in self.proj_module:
            out = proj_layer(out)
        return out
