"""
Implementation of our model.
"""

import warnings
import chex
import jax
import jax.random as random
import jax.numpy as jnp
import einops as e

import flax.linen as nn
from dynamax.nn import AdditiveESN


class Ours(nn.Module):
    """Learnt input encoder -> ``AdditiveESN`` dynamics -> learnt readout.

    One call advances the dynamics by a single step: the inputs are encoded by
    a two-layer MLP, handed to ``AdditiveESN`` together with the raw inputs and
    the previous state, and the resulting state is decoded by a second
    two-layer MLP.
    """

    # Width of both MLP hidden layers and of the AdditiveESN state.
    num_nodes: int
    # Size of the final readout layer.
    output_dim: int = 22
    # The five fields below are forwarded unchanged to AdditiveESN.
    # NOTE(review): their exact semantics are defined by AdditiveESN, which is
    # not visible in this file.
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8

    @nn.compact
    def __call__(self, inputs, dynamics_state):
        """Run one step of the model.

        Args:
            inputs: Array fed both to the learnt encoder and, unmodified, to
                ``AdditiveESN``.
            dynamics_state: Previous state passed to ``AdditiveESN``.

        Returns:
            A ``(prediction, next_state)`` tuple: the readout of the new state
            (last dimension ``output_dim``) and the new state itself.
        """
        width = self.num_nodes

        # NOTE: submodules are created in a fixed order on purpose; Flax derives
        # their auto-generated names (Dense_0, Dense_1, ...) from it.

        # Learnt projection of the inputs: Dense -> sigmoid -> Dense.
        encoder_hidden = nn.sigmoid(nn.Dense(features=width)(inputs))
        learnt_drive = nn.Dense(features=width)(encoder_hidden)

        # Dynamics update; receives the raw inputs as well as the learnt drive.
        reservoir = AdditiveESN(
            features=width,
            projection_scale=self.projection_scale,
            dynamics_scale=self.dynamics_scale,
            spectral_radius=self.spectral_radius,
            sparsity=self.sparsity,
            alpha=self.alpha,
        )
        next_state = reservoir(inputs, learnt_drive, dynamics_state)

        # Output projection of the new state: Dense -> sigmoid -> Dense.
        readout_hidden = nn.sigmoid(nn.Dense(features=width)(next_state))
        prediction = nn.Dense(features=self.output_dim)(readout_hidden)

        return prediction, next_state


def init_our_model(cfg, seed: int = 42):
    """Build an ``Ours`` model and initialise its variables.

    Args:
        cfg: Mapping with two entries:
            ``"model_params"`` - keyword arguments for ``Ours``;
            ``"init_shapes"`` - mapping holding the ``"input"`` and ``"state"``
            shapes used to build the all-ones dummy arrays for initialisation.
        seed: Seed for the PRNG key passed to ``model.init``. Defaults to 42,
            the value that used to be hard-coded.

    Returns:
        A ``(model, variables)`` tuple. ``variables`` are placed on the first
        GPU when one is available, otherwise on JAX's first default device.
    """
    model = Ours(**cfg["model_params"])

    init_shapes = cfg["init_shapes"]
    dummy_input = jnp.ones(init_shapes["input"])
    dummy_state = jnp.ones(init_shapes["state"])
    variables = model.init(random.key(seed), dummy_input, dummy_state)

    # jax.devices("gpu") raises RuntimeError when no GPU backend exists; fall
    # back to the default backend so the helper also works on CPU/TPU hosts.
    try:
        device = jax.devices("gpu")[0]
    except RuntimeError:
        device = jax.devices()[0]
        warnings.warn(
            f"No GPU backend available; placing model variables on {device}.",
            stacklevel=2,
        )

    return model, jax.device_put(variables, device)


class OursMultiTask(nn.Module):
    """Image-conditioned variant of the model.

    The image is cut into square patches that are embedded as tokens, one more
    token is built from ``inputs``, and the token sequence is mixed by a small
    stack of attention + MLP blocks. The flattened result is projected to
    ``num_nodes`` features and used as the learnt drive of ``AdditiveESN``;
    the new state is decoded by a two-layer MLP.
    """

    # Width of the MLP hidden layers and of the AdditiveESN state.
    num_nodes: int
    # Size of the final readout layer.
    output_dim: int = 22
    # The five fields below are forwarded unchanged to AdditiveESN.
    # NOTE(review): their exact semantics are defined by AdditiveESN, which is
    # not visible in this file.
    projection_scale: float = 0.1
    dynamics_scale: float = 0.8
    spectral_radius: float = 0.9
    sparsity: float = 0.01
    alpha: float = 0.8

    @nn.compact
    def __call__(self, inputs, dynamics_state, image):
        """Run one step of the image-conditioned model.

        Args:
            inputs: Array embedded into a single token and also passed,
                unmodified, to ``AdditiveESN``. Presumably ``(batch, features)``
                so that its token can be stacked with the patch tokens -
                TODO confirm against callers.
            dynamics_state: Previous state passed to ``AdditiveESN``.
            image: 4-D array unpacked as ``(batch, height, width, channels)``;
                height and width must be equal. Assumed to hold pixel values
                in ``[0, 255]`` - TODO confirm.

        Returns:
            A ``(prediction, next_state)`` tuple: the readout of the new state
            (last dimension ``output_dim``) and the new state itself.
        """
        patch_size = 32
        embed_dim = 512
        num_blocks = 4

        # NOTE: submodules are created in a fixed order on purpose; Flax derives
        # their auto-generated names (Dense_0, Dense_1, ...) from it.

        # Rescale so that 0 maps to -1 and 255 maps to +1.
        image = (image / 127.5) - 1.0

        # Only square images are supported.
        b, h, w, c = image.shape
        chex.assert_equal(h, w)

        # If the side length is not a multiple of the patch size, resize to the
        # largest multiple that fits (but never below one full patch).
        if h % patch_size != 0:
            warnings.warn(
                "The image is not divisible by the patch size. Automatically resizing image.",
                stacklevel=1,
            )
            new_dim = max(patch_size, (h // patch_size) * patch_size)
            # jax.image.resize expects the complete output shape, including the
            # batch and channel dimensions, not only the spatial ones.
            image = jax.image.resize(image, (b, new_dim, new_dim, c), method="nearest")

        # Split into non-overlapping patches:
        # (batch, num_patches, patch_size, patch_size, channels).
        patches = e.rearrange(
            image,
            "batch (h p1) (w p2) c -> batch (h w) (p1) (p2) (c)",
            p1=patch_size,
            p2=patch_size,
        )

        # Per-patch convolution, then flatten each patch's feature map and
        # project it to one token of size embed_dim.
        image_embeddings = nn.Conv(features=64, kernel_size=(2, 2), strides=(1, 1), padding="VALID")(patches)
        image_embeddings = image_embeddings.reshape((image_embeddings.shape[0], image_embeddings.shape[1], -1))
        image_embeddings = nn.Dense(embed_dim)(image_embeddings)

        # Embed the non-image inputs as one additional token.
        state_embedding = jax.nn.sigmoid(nn.Dense(self.num_nodes)(inputs))
        state_embedding = nn.Dense(embed_dim)(state_embedding)

        # Append the input token after the patch tokens (hstack joins along
        # axis 1, the token axis, for these 3-D arrays).
        rc_in_learnt = jnp.hstack([image_embeddings, jnp.expand_dims(state_embedding, axis=1)])

        for _ in range(num_blocks):
            # Self-attention sub-block with a residual connection.
            residual = rc_in_learnt
            rc_in_learnt = nn.LayerNorm()(rc_in_learnt)
            rc_in_learnt = nn.MultiHeadAttention(num_heads=4, qkv_features=embed_dim)(rc_in_learnt)
            rc_in_learnt = rc_in_learnt + residual

            # MLP sub-block with a residual connection.
            residual = rc_in_learnt
            rc_in_learnt = nn.LayerNorm()(rc_in_learnt)
            rc_in_learnt = jax.nn.sigmoid(nn.Dense(embed_dim)(rc_in_learnt))
            rc_in_learnt = nn.Dense(embed_dim)(rc_in_learnt)
            rc_in_learnt = rc_in_learnt + residual

            # Additional normalisation at the end of every block.
            rc_in_learnt = nn.LayerNorm()(rc_in_learnt)

        # Flatten all tokens and project to the width of the dynamics state.
        rc_in_learnt = rc_in_learnt.reshape((rc_in_learnt.shape[0], -1))
        rc_in_learnt = jax.nn.sigmoid(nn.Dense(self.num_nodes)(rc_in_learnt))
        rc_in_learnt = nn.Dense(self.num_nodes)(rc_in_learnt)

        # Dynamics update; receives the raw inputs as well as the learnt drive.
        next_state = AdditiveESN(
            features=self.num_nodes,
            projection_scale=self.projection_scale,
            dynamics_scale=self.dynamics_scale,
            spectral_radius=self.spectral_radius,
            sparsity=self.sparsity,
            alpha=self.alpha,
        )(inputs, rc_in_learnt, dynamics_state)

        # Output projection of the new state: Dense -> sigmoid -> Dense.
        x = nn.Dense(features=self.num_nodes)(next_state)
        x = nn.sigmoid(x)
        x = nn.Dense(features=self.output_dim)(x)

        return x, next_state


def init_our_multi_model(cfg, seed: int = 42):
    """Build an ``OursMultiTask`` model and initialise its variables.

    Args:
        cfg: Mapping with two entries:
            ``"model_params"`` - keyword arguments for ``OursMultiTask``;
            ``"init_shapes"`` - mapping holding the ``"input"``, ``"state"``
            and ``"image"`` shapes used to build the all-ones dummy arrays for
            initialisation.
        seed: Seed for the PRNG key passed to ``model.init``. Defaults to 42,
            the value that used to be hard-coded.

    Returns:
        A ``(model, variables)`` tuple. ``variables`` are placed on the first
        GPU when one is available, otherwise on JAX's first default device.
    """
    model = OursMultiTask(**cfg["model_params"])

    init_shapes = cfg["init_shapes"]
    dummy_input = jnp.ones(init_shapes["input"])
    dummy_state = jnp.ones(init_shapes["state"])
    dummy_image = jnp.ones(init_shapes["image"])
    variables = model.init(random.key(seed), dummy_input, dummy_state, dummy_image)

    # jax.devices("gpu") raises RuntimeError when no GPU backend exists; fall
    # back to the default backend so the helper also works on CPU/TPU hosts.
    try:
        device = jax.devices("gpu")[0]
    except RuntimeError:
        device = jax.devices()[0]
        warnings.warn(
            f"No GPU backend available; placing model variables on {device}.",
            stacklevel=2,
        )

    return model, jax.device_put(variables, device)
