"""
Implementation of Echo State Network.
"""

import jax
import jax.random as random
import jax.numpy as jnp
import einops as e
import flax.linen as nn
from dynamax.nn import ESN


class EchoStateNetwork(nn.Module):
    """Single-step ESN reservoir followed by a two-layer dense readout.

    Each call advances the reservoir once through ``ESN`` and then maps the
    new reservoir state through ``Dense -> sigmoid -> Dense``.

    Attributes:
        num_nodes: Passed to ``ESN`` as ``features``; also the width of the
            hidden readout layer.
        output_dim: Width of the final readout layer.
        projection_scale: Forwarded unchanged to ``ESN``.
        dynamics_scale: Forwarded unchanged to ``ESN``.
        spectral_radius: Forwarded unchanged to ``ESN``.
        sparsity: Forwarded unchanged to ``ESN``.
        alpha: Forwarded unchanged to ``ESN``.
    """

    num_nodes: int
    output_dim: int = 22
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8

    @nn.compact
    def __call__(self, inputs, dynamics_state):
        """Advance the reservoir one step and read out a prediction.

        Args:
            inputs: Driving input handed to ``ESN`` as its first argument.
            dynamics_state: Current reservoir state handed to ``ESN`` as its
                second argument.

        Returns:
            A ``(prediction, next_state)`` tuple: the readout output and the
            value returned by ``ESN`` (to be fed back in on the next call).
        """
        reservoir = ESN(
            features=self.num_nodes,
            projection_scale=self.projection_scale,
            dynamics_scale=self.dynamics_scale,
            spectral_radius=self.spectral_radius,
            sparsity=self.sparsity,
            alpha=self.alpha,
        )
        next_state = reservoir(inputs, dynamics_state)

        # Readout head. The two Dense layers are created in this order so
        # their auto-generated Flax names stay Dense_0 then Dense_1.
        hidden = nn.Dense(features=self.num_nodes)(next_state)
        activated = nn.sigmoid(hidden)
        prediction = nn.Dense(features=self.output_dim)(activated)

        return prediction, next_state


def init_esn_model(cfg):
    """Build an ``EchoStateNetwork`` and initialise its variables.

    Args:
        cfg: Mapping with two entries that are read here:
            ``cfg["model_params"]`` is expanded as keyword arguments to
            ``EchoStateNetwork``; ``cfg["init_shapes"]["input"]`` and
            ``cfg["init_shapes"]["state"]`` are the shapes of the all-ones
            dummy arrays used to trace the model during initialisation.

    Returns:
        A ``(model, variables)`` tuple. ``variables`` is placed on the first
        GPU when a GPU backend is available, otherwise on the first device of
        JAX's default backend.
    """
    import warnings

    key = random.key(42)  # fixed seed, so initialisation is reproducible
    model = EchoStateNetwork(**cfg["model_params"])
    dummy_input = jnp.ones(cfg["init_shapes"]["input"])
    dummy_state = jnp.ones(cfg["init_shapes"]["state"])
    variables = model.init(key, dummy_input, dummy_state)

    # jax.devices("gpu") raises RuntimeError when no GPU backend can be
    # initialised. Fall back to the default backend rather than making model
    # construction impossible on CPU-only (or TPU) hosts.
    try:
        device = jax.devices("gpu")[0]
    except (RuntimeError, IndexError):
        device = jax.devices()[0]
        warnings.warn(
            f"No GPU backend available; placing ESN variables on {device}.",
            RuntimeWarning,
            stacklevel=2,
        )

    variables = jax.device_put(variables, device)
    return model, variables


class ESNMultiTask(nn.Module):
    """ESN reservoir driven by a vector input plus a convolutional image code.

    The image is rescaled, passed through a strided convolutional stem and two
    residual down-sampling stages, flattened and projected to a 1024-wide
    embedding. That embedding is stacked with ``inputs`` (``jnp.hstack``) and
    fed to ``ESN``; the new reservoir state then goes through a
    ``Dense -> sigmoid -> Dense`` readout.

    Attributes:
        num_nodes: Passed to ``ESN`` as ``features``; also the width of the
            hidden readout layer.
        output_dim: Width of the final readout layer.
        projection_scale: Forwarded unchanged to ``ESN``.
        dynamics_scale: Forwarded unchanged to ``ESN``.
        spectral_radius: Forwarded unchanged to ``ESN``.
        sparsity: Forwarded unchanged to ``ESN``.
        alpha: Forwarded unchanged to ``ESN``.
    """

    num_nodes: int
    output_dim: int = 22
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8

    @nn.compact
    def __call__(self, inputs, dynamics_state, image):
        """Encode the image, advance the reservoir one step, and read out.

        Args:
            inputs: Vector input stacked with the image embedding before it
                is handed to ``ESN``.
            dynamics_state: Current reservoir state handed to ``ESN``.
            image: Image batch. Must be 4-D after the convolutions, because
                the flattening pattern below names four axes.

        Returns:
            A ``(prediction, next_state)`` tuple: the readout output and the
            value returned by ``ESN``.
        """
        # Shared keyword sets for the two kinds of 3x3 convolution used below.
        downsample_conv = dict(
            features=32,
            kernel_size=(3, 3),
            strides=(3, 3),
            padding="VALID",
        )
        refine_conv = dict(
            features=32,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding="SAME",
        )

        # Step 1: rescale pixels as x / 127.5 - 1 (sends 0 to -1 and 255 to 1;
        # presumably the caller supplies 0..255 values -- confirm).
        feats = image / 127.5 - 1.0

        # Step 2: strided 6x6 stem convolution.
        stem = nn.Conv(
            features=32,
            kernel_size=(6, 6),
            strides=(3, 3),
            padding="VALID",
        )
        feats = stem(feats)

        # Step 3: two residual stages. Modules are created in the same order
        # on every pass so the auto-generated Flax parameter names are stable.
        for _stage in range(2):
            skip = feats

            feats = nn.Conv(**downsample_conv)(feats)
            feats = nn.LayerNorm()(feats)
            feats = nn.gelu(feats)

            feats = nn.Conv(**refine_conv)(feats)
            feats = nn.LayerNorm()(feats)

            # The main path was down-sampled, so project the skip branch with
            # a matching strided convolution whenever the shapes disagree.
            if skip.shape != feats.shape:
                skip = nn.Conv(**downsample_conv)(skip)

            feats = nn.gelu(feats + skip)

        # Step 4: flatten everything but the leading axis and embed.
        # NOTE(review): Flax convolutions are documented as channels-last, so
        # the axis names in this pattern look nominal. The flatten keeps the
        # existing axis order, so the result does not depend on the labels.
        flat = e.rearrange(feats, "batch filters h w -> batch (filters h w)")
        image_code = nn.Dense(1024)(flat)

        # Step 5: drive the reservoir with [inputs, image_code].
        reservoir = ESN(
            features=self.num_nodes,
            projection_scale=self.projection_scale,
            dynamics_scale=self.dynamics_scale,
            spectral_radius=self.spectral_radius,
            sparsity=self.sparsity,
            alpha=self.alpha,
        )
        drive = jnp.hstack([inputs, image_code])
        next_state = reservoir(drive, dynamics_state)

        # Step 6: readout head.
        hidden = nn.Dense(features=self.num_nodes)(next_state)
        activated = nn.sigmoid(hidden)
        prediction = nn.Dense(features=self.output_dim)(activated)

        return prediction, next_state


def init_esn_multi_model(cfg):
    """Build an ``ESNMultiTask`` model and initialise its variables.

    Args:
        cfg: Mapping with two entries that are read here:
            ``cfg["model_params"]`` is expanded as keyword arguments to
            ``ESNMultiTask``; ``cfg["init_shapes"]["input"]``,
            ``cfg["init_shapes"]["state"]`` and ``cfg["init_shapes"]["image"]``
            are the shapes of the all-ones dummy arrays used to trace the
            model during initialisation.

    Returns:
        A ``(model, variables)`` tuple. ``variables`` is placed on the first
        GPU when a GPU backend is available, otherwise on the first device of
        JAX's default backend.
    """
    import warnings

    key = random.key(42)  # fixed seed, so initialisation is reproducible
    model = ESNMultiTask(**cfg["model_params"])
    dummy_input = jnp.ones(cfg["init_shapes"]["input"])
    dummy_state = jnp.ones(cfg["init_shapes"]["state"])
    dummy_image = jnp.ones(cfg["init_shapes"]["image"])
    variables = model.init(key, dummy_input, dummy_state, dummy_image)

    # jax.devices("gpu") raises RuntimeError when no GPU backend can be
    # initialised. Fall back to the default backend rather than making model
    # construction impossible on CPU-only (or TPU) hosts.
    try:
        device = jax.devices("gpu")[0]
    except (RuntimeError, IndexError):
        device = jax.devices()[0]
        warnings.warn(
            f"No GPU backend available; placing ESN variables on {device}.",
            RuntimeWarning,
            stacklevel=2,
        )

    variables = jax.device_put(variables, device)
    return model, variables
