"""
Implementations of neural network modules inspired by Echo State Networks.
"""

from typing import Callable
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import initializers
from flax.typing import (
    Array,
    PRNGKey as PRNGKey,
    Dtype,
    Shape as Shape,
)

import dynamax.nn.initializers as dyn_initializers


class ESN(nn.Module):
    """
    Echo State Network (leaky-integrator reservoir) cell.

    Each call performs one reservoir update::

        x_tilde    = tanh(inputs @ W_in + dynamics_state @ W)
        next_state = (1 - alpha) * dynamics_state + alpha * x_tilde

    ``W_in`` ("input_projection") and ``W`` ("network_dynamics") are created
    lazily on the first call and stored in the ``"dynamics_params"`` variable
    collection rather than ``"params"``. Presumably this keeps the reservoir
    fixed while only a separate readout is trained; confirm with the training
    code.

    Attributes:
        features: Number of reservoir nodes (size of the state's last axis).
        projection_scale: Argument passed to ``projection_init`` when building
            the input projection. NOTE(review): with the default
            ``initializers.uniform`` this is the upper bound of a [0, scale)
            distribution, so all input weights are non-negative.
        dynamics_scale: Forwarded to ``dynamics_init``.
        spectral_radius: Forwarded to ``dynamics_init``.
        sparsity: Forwarded to ``dynamics_init``.
        alpha: Leak rate. ``alpha=1`` replaces the state with ``x_tilde``;
            ``alpha=0`` returns the previous state unchanged.
        param_dtype: dtype in which both reservoir matrices are stored.
        dynamics_init: Called as ``dynamics_init(key, features, dynamics_scale,
            spectral_radius, sparsity)``. It presumably returns a
            ``(features, features)`` matrix; its first axis must match the
            state's last axis.
        projection_init: Initializer factory called as
            ``projection_init(projection_scale)(key, (input_dim, features))``.
    """

    features: int
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8
    param_dtype: Dtype = jnp.float32
    dynamics_init: Callable = dyn_initializers.random_uniform_dynamics
    projection_init: Callable = initializers.uniform

    @nn.compact
    def __call__(self, inputs: Array, dynamics_state: Array) -> Array:
        """
        Advance the reservoir by one step.

        Args:
            inputs: Input array whose last axis is the input dimension. Any
                leading axes are carried through as batch axes.
            dynamics_state: Current reservoir state; its last axis must have
                size ``features``.

        Returns:
            The next reservoir state. It has the same shape as
            ``dynamics_state`` when the leading axes of ``inputs`` and
            ``dynamics_state`` agree; otherwise it has their broadcast shape.
        """
        # The input dimension is read from the trailing axis of ``inputs`` on
        # the first call only; afterwards the stored matrix is reused as-is.
        # The cast honours ``param_dtype``, which was previously ignored.
        projection = self.variable(
            "dynamics_params",
            "input_projection",
            lambda num_nodes, input_dim: jnp.asarray(
                self.projection_init(self.projection_scale)(
                    self.make_rng("params"), (input_dim, num_nodes)
                ),
                dtype=self.param_dtype,
            ),
            self.features,
            jnp.shape(inputs)[-1],
        )

        # Recurrent reservoir matrix, also stored in ``param_dtype``.
        dynamics = self.variable(
            "dynamics_params",
            "network_dynamics",
            lambda num_nodes, scale, spectral_radius, sparsity: jnp.asarray(
                self.dynamics_init(
                    self.make_rng("params"), num_nodes, scale, spectral_radius, sparsity
                ),
                dtype=self.param_dtype,
            ),
            self.features,
            self.dynamics_scale,
            self.spectral_radius,
            self.sparsity,
        )

        # Contract the last axis of each operand with the first axis of its
        # matrix (same contraction as the former lax.dot_general). tensordot
        # also promotes mixed dtypes, e.g. float32 inputs with bf16 params.
        input_projection = jnp.tensordot(inputs, projection.value, axes=1)
        dynamics_state_updated = jnp.tensordot(
            dynamics_state, dynamics.value, axes=1
        )

        # Non-linearity applied to the combined drive.
        x_tilde = jax.nn.tanh(input_projection + dynamics_state_updated)

        # Leaky-integrator update: blend the previous state with the new one.
        next_state = ((1 - self.alpha) * dynamics_state) + (self.alpha * x_tilde)

        return next_state


class AdditiveESN(nn.Module):
    """
    Echo State Network cell with an additive external drive term.

    Identical to a leaky-integrator ESN except that ``embeddings`` is added
    inside the non-linearity::

        x_tilde    = tanh(inputs @ W_in + embeddings + dynamics_state @ W)
        next_state = (1 - alpha) * dynamics_state + alpha * x_tilde

    ``W_in`` ("input_projection") and ``W`` ("network_dynamics") are created
    lazily on the first call and stored in the ``"dynamics_params"`` variable
    collection rather than ``"params"``. Presumably this keeps the reservoir
    fixed while only a separate readout is trained; confirm with the training
    code.

    Attributes:
        features: Number of reservoir nodes (size of the state's last axis).
        projection_scale: Argument passed to ``projection_init`` when building
            the input projection. NOTE(review): with the default
            ``initializers.uniform`` this is the upper bound of a [0, scale)
            distribution, so all input weights are non-negative.
        dynamics_scale: Forwarded to ``dynamics_init``.
        spectral_radius: Forwarded to ``dynamics_init``.
        sparsity: Forwarded to ``dynamics_init``.
        alpha: Leak rate. ``alpha=1`` replaces the state with ``x_tilde``;
            ``alpha=0`` returns the previous state unchanged.
        param_dtype: dtype in which both reservoir matrices are stored.
        dynamics_init: Called as ``dynamics_init(key, features, dynamics_scale,
            spectral_radius, sparsity)``. It presumably returns a
            ``(features, features)`` matrix; its first axis must match the
            state's last axis.
        projection_init: Initializer factory called as
            ``projection_init(projection_scale)(key, (input_dim, features))``.
    """

    features: int
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8
    param_dtype: Dtype = jnp.float32
    dynamics_init: Callable = dyn_initializers.random_uniform_dynamics
    projection_init: Callable = initializers.uniform

    @nn.compact
    def __call__(self, inputs: Array, embeddings: Array, dynamics_state: Array) -> Array:
        """
        Advance the reservoir by one step with an additive embedding drive.

        Args:
            inputs: Input array whose last axis is the input dimension. Any
                leading axes are carried through as batch axes.
            embeddings: Term added to the projected inputs before ``tanh``.
                It must broadcast against a ``(..., features)`` array.
                Presumably it is an external embedding of width
                ``features``; confirm with callers.
            dynamics_state: Current reservoir state; its last axis must have
                size ``features``.

        Returns:
            The next reservoir state. It has the same shape as
            ``dynamics_state`` when all operand shapes agree; otherwise it has
            their broadcast shape.
        """
        # The input dimension is read from the trailing axis of ``inputs`` on
        # the first call only; afterwards the stored matrix is reused as-is.
        # The cast honours ``param_dtype``, which was previously ignored.
        projection = self.variable(
            "dynamics_params",
            "input_projection",
            lambda num_nodes, input_dim: jnp.asarray(
                self.projection_init(self.projection_scale)(
                    self.make_rng("params"), (input_dim, num_nodes)
                ),
                dtype=self.param_dtype,
            ),
            self.features,
            jnp.shape(inputs)[-1],
        )

        # Recurrent reservoir matrix, also stored in ``param_dtype``.
        dynamics = self.variable(
            "dynamics_params",
            "network_dynamics",
            lambda num_nodes, scale, spectral_radius, sparsity: jnp.asarray(
                self.dynamics_init(
                    self.make_rng("params"), num_nodes, scale, spectral_radius, sparsity
                ),
                dtype=self.param_dtype,
            ),
            self.features,
            self.dynamics_scale,
            self.spectral_radius,
            self.sparsity,
        )

        # Contract the last axis of each operand with the first axis of its
        # matrix (same contraction as the former lax.dot_general). tensordot
        # also promotes mixed dtypes, e.g. float32 inputs with bf16 params.
        input_projection = jnp.tensordot(inputs, projection.value, axes=1)
        dynamics_state_updated = jnp.tensordot(
            dynamics_state, dynamics.value, axes=1
        )

        # Non-linearity over the combined drive, including the embedding term.
        x_tilde = jax.nn.tanh(input_projection + embeddings + dynamics_state_updated)

        # Leaky-integrator update: blend the previous state with the new one.
        next_state = ((1 - self.alpha) * dynamics_state) + (self.alpha * x_tilde)

        return next_state
