"""
Feedforward architecture that mimics ours without the dynamics components.
"""

import warnings
import chex
import einops as e
import jax
import jax.random as random
import jax.numpy as jnp
import flax.linen as nn


class FeedForward(nn.Module):
    """Plain MLP: three sigmoid hidden layers of width ``features``, then a linear output.

    Attributes:
        features: width of every hidden Dense layer.
        output_dim: size of the final (un-activated) output layer.
    """

    features: int
    output_dim: int = 22

    @nn.compact
    def __call__(self, inputs):
        """Apply the MLP to ``inputs`` and return the linear projection to ``output_dim``."""
        hidden = inputs
        # Three identical Dense -> sigmoid stages; creating them in a loop keeps the
        # same creation order, so Flax still names them Dense_0, Dense_1, Dense_2.
        for _ in range(3):
            hidden = nn.sigmoid(nn.Dense(features=self.features)(hidden))
        # Final linear read-out (Dense_3), no activation.
        return nn.Dense(features=self.output_dim)(hidden)


def init_feedforward_model(cfg):
    """Construct a FeedForward model and initialise its variables with a fixed seed.

    Args:
        cfg: mapping with ``"model_params"`` (keyword arguments for FeedForward) and
            ``"init_shapes"["input"]`` (shape of the dummy input used for init).

    Returns:
        Tuple ``(model, variables)``.
    """
    model = FeedForward(**cfg["model_params"])
    sample_input = jnp.ones(cfg["init_shapes"]["input"])
    # Fixed PRNG seed (42) so initialisation is reproducible across runs.
    init_variables = model.init(random.key(42), sample_input)
    return model, init_variables


class FeedForwardMultiTask(nn.Module):
    """Attention encoder over image patches plus a state token, followed by an MLP head.

    The image is split into non-overlapping ``patch_size`` x ``patch_size`` patches.
    Each patch is embedded (Conv -> flatten -> Dense(1024)). ``inputs`` is embedded
    into one extra token, and the token sequence is passed through four
    LayerNorm/attention/MLP residual blocks. The result is flattened and projected
    to ``output_dim``.

    Attributes:
        features: width of the Dense layers in the state embedding and the head.
        output_dim: size of the final output vector.
        patch_size: side length of the square image patches (32 by default, the
            value that was previously hard-coded).
    """

    features: int
    output_dim: int = 22
    patch_size: int = 32

    @nn.compact
    def __call__(self, inputs, image):
        """Run the model.

        Args:
            inputs: state vector; it is embedded and concatenated as one token along
                axis 1, so presumably shaped (batch, state_dim) — TODO confirm.
            image: (batch, height, width, channels) array with height == width.
                Presumably raw pixel values in [0, 255]; it is mapped via x/127.5 - 1.

        Returns:
            (batch, output_dim) array.

        Note:
            The head flattens all tokens, so the (possibly resized) image size must
            match the size used at ``init`` time.
        """
        # Normalize the image (maps [0, 255] to [-1, 1]).
        image = (image / 127.5) - 1.0

        # Patchify: the image must be square.
        b, h, w, c = image.shape
        chex.assert_equal(h, w)

        patch_size = self.patch_size
        if h % patch_size != 0:
            warnings.warn("The image is not divisible by the patch size. Automatically resizing image.", stacklevel=1)
            # Snap the side length down to a multiple of the patch size (at least one
            # patch). jax.image.resize needs a target shape of the same rank as the
            # input, so batch and channel dimensions are included explicitly.
            new_dim = max(1, h // patch_size) * patch_size
            image = jax.image.resize(image, (b, new_dim, new_dim, c), method="nearest")

        # (batch, num_patches, patch_size, patch_size, channels)
        patches = e.rearrange(image, "batch (h p1) (w p2) c -> batch (h w) (p1) (p2) (c)", p1=patch_size, p2=patch_size)

        # Per-patch embedding: conv over each patch, flatten, project to 1024.
        image_embeddings = nn.Conv(features=64, kernel_size=(2, 2), strides=(1, 1), padding="VALID")(patches)
        image_embeddings = image_embeddings.reshape((image_embeddings.shape[0], image_embeddings.shape[1], -1))
        image_embeddings = nn.Dense(1024)(image_embeddings)

        # State embedding, appended as one extra token.
        state_embedding = jax.nn.sigmoid(nn.Dense(self.features)(inputs))
        state_embedding = nn.Dense(1024)(state_embedding)

        tokens = jnp.concatenate([image_embeddings, jnp.expand_dims(state_embedding, axis=1)], axis=1)

        # Four residual blocks. Submodule creation order is unchanged from earlier
        # versions, so auto-generated parameter names (and checkpoints) still match.
        for _ in range(4):
            residual = tokens
            tokens = nn.LayerNorm()(tokens)
            tokens = nn.MultiHeadAttention(num_heads=4, qkv_features=512)(tokens)
            tokens = tokens + residual
            residual = tokens
            tokens = nn.LayerNorm()(tokens)
            tokens = jax.nn.sigmoid(nn.Dense(1024)(tokens))
            tokens = nn.Dense(1024)(tokens)
            tokens = tokens + residual
            tokens = nn.LayerNorm()(tokens)

        # Flatten all tokens into one vector per batch element, then the MLP head.
        combined_embedding = tokens.reshape((tokens.shape[0], -1))
        combined_embedding = jax.nn.sigmoid(nn.Dense(self.features)(combined_embedding))
        combined_embedding = nn.Dense(self.features)(combined_embedding)
        combined_embedding = jax.nn.sigmoid(combined_embedding)

        x = nn.Dense(features=self.features)(combined_embedding)
        x = nn.sigmoid(x)
        x = nn.Dense(features=self.output_dim)(x)

        return x


def init_feedforward_multi_model(cfg):
    """Construct a FeedForwardMultiTask model and initialise its variables.

    Args:
        cfg: mapping with ``"model_params"`` (keyword arguments for
            FeedForwardMultiTask) and ``"init_shapes"`` containing the ``"input"``
            and ``"image"`` shapes used for the dummy initialisation inputs.

    Returns:
        Tuple ``(model, variables)``. ``variables`` are placed on the first GPU when
        a GPU backend is available, otherwise on the first default-backend device.
    """
    key = random.key(42)
    model = FeedForwardMultiTask(**cfg["model_params"])
    dummy_input = jnp.ones(cfg["init_shapes"]["input"])
    dummy_image = jnp.ones(cfg["init_shapes"]["image"])
    variables = model.init(key, dummy_input, dummy_image)
    try:
        device = jax.devices("gpu")[0]
    except RuntimeError:
        # jax.devices("gpu") raises when no GPU backend exists; fall back instead of
        # crashing so the model can still be initialised on CPU-only hosts.
        warnings.warn("No GPU backend available; placing variables on the default device.", stacklevel=2)
        device = jax.devices()[0]
    variables = jax.device_put(variables, device)
    return model, variables
