from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp
from chex import Array

from memento.environments.cvrp.types import Observation as CVRPObservation
from memento.environments.knapsack.types import \
    Observation as KnapsackObservation
from memento.environments.tsp.types import Observation as TSPObservation

if TYPE_CHECKING:
    from dataclasses import dataclass

else:
    from chex import dataclass


@dataclass
class Networks:  # type: ignore
    """Container grouping the encoder and decoder Haiku transforms."""

    # Haiku-transformed encoder function.
    encoder_fn: hk.Transformed
    # Haiku-transformed decoder function.
    decoder_fn: hk.Transformed




class DummyDecoderBase(ABC, hk.Module):
    """
    Base decoder with an extra attention layer over a supplied matrix.

    The context returned by ``get_context`` first attends over the embeddings
    (module ``mha_dec``) and then over ``dummy_matrix`` (module
    ``mha_dec_memory``). The output of the second layer is added back to the
    context before the logits are computed.

    Subclasses implement ``get_context`` and
    ``get_transformed_attention_mask``.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        model_size: int = 128,
        name: str = "decoder",
    ) -> None:
        """
        Store the attention hyper-parameters.

        Args:
            num_heads: forwarded as ``num_heads`` to both attention layers.
            key_size: forwarded as ``key_size`` to both attention layers.
            model_size: forwarded as ``model_size`` to both attention layers
                and used to scale the logits.
            name: Haiku module name.
        """
        super().__init__(name=name)
        self.num_heads = num_heads
        self.key_size = key_size
        self.model_size = model_size

    def __call__(
        self,
        observation: Union[TSPObservation, KnapsackObservation, CVRPObservation],
        embeddings: Array,
        dummy_matrix: Array,
    ) -> Tuple[Array, Array]:
        """
        Compute the attention logits for the current observation.

        Args:
            observation: environment observation; its ``action_mask`` is used
                to mask the first attention layer.
            embeddings: embeddings used as keys/values of the first attention
                layer and to compute the final logits.
            dummy_matrix: matrix used as keys and values of the second
                attention layer.

        Returns:
            A pair ``(attn_logits, context)``. ``attn_logits`` are squashed
            into [-10, 10]. ``context`` is the output of the first attention
            layer, i.e. *before* the ``dummy_matrix`` term is added.
        """
        context = self.get_context(observation, embeddings)

        embedding_attention = hk.MultiHeadAttention(
            num_heads=self.num_heads,
            key_size=self.key_size,
            model_size=self.model_size,
            w_init_scale=1,
            name="mha_dec",
        )

        # Prepend two singleton axes to the action mask before handing it to
        # the subclass hook and then to the attention layer.
        attention_mask = jnp.expand_dims(observation.action_mask, (0, 1))
        context = embedding_attention(
            query=context,
            key=embeddings,
            value=embeddings,
            mask=self.get_transformed_attention_mask(attention_mask),
        )  # presumably [1, model_size] - TODO confirm against get_context

        # Second, independent attention layer over ``dummy_matrix``.
        memory_attention = hk.MultiHeadAttention(
            num_heads=self.num_heads,
            key_size=self.key_size,
            model_size=self.model_size,
            w_init_scale=1,
            name="mha_dec_memory",
        )

        # No mask is passed to this layer.
        memory_context = memory_attention(
            query=context,
            key=dummy_matrix,
            value=dummy_matrix,
        )

        # Skip connection: the second layer's output is added to the context.
        new_context = context + memory_context

        # Scaled dot product between every embedding and the updated context.
        attn_logits = embeddings @ new_context.squeeze() / jnp.sqrt(self.model_size)
        attn_logits = 10 * jnp.tanh(attn_logits)  # clip to [-10,10]

        return attn_logits, context

    @abstractmethod
    def get_context(
        self,
        observation: Union[TSPObservation, KnapsackObservation, CVRPObservation],
        embeddings: Array,
    ) -> Array:
        """
        Build the query passed to the first attention layer.

        Args:
            observation: environment observation.
            embeddings: embeddings the context is built from.

        Returns:
            The context array used as the attention query.
        """

    @abstractmethod
    def get_transformed_attention_mask(self, attention_mask: Array) -> Array:
        """
        Adapt the expanded action mask for the first attention layer.

        Args:
            attention_mask: ``observation.action_mask`` with two leading
                singleton axes added.

        Returns:
            The mask passed as ``mask`` to the first attention layer.
        """


class TSPDummyDecoder(DummyDecoderBase):
    """
    TSP variant of the dummy decoder.

    The context is the concatenation of the mean embedding, the embedding at
    ``observation.position`` and the embedding at
    ``observation.start_position``. The attention mask is used unchanged.
    """

    def get_context(self, observation: TSPObservation, embeddings: Array) -> Array:  # type: ignore[override]
        """
        Concatenate three embedding vectors into the attention query.

        Args:
            observation: TSP observation providing ``position`` and
                ``start_position`` indices into ``embeddings``.
            embeddings: embeddings indexed along axis 0.

        Returns:
            The concatenated context with a new leading axis; presumably
            [1, 3 * embedding_dim] for 2-D embeddings and scalar indices -
            TODO confirm against caller.
        """
        mean_embedding = embeddings.mean(0)
        position_embedding = embeddings[observation.position]
        start_embedding = embeddings[observation.start_position]

        context = jnp.concatenate(
            [mean_embedding, position_embedding, start_embedding],
            axis=0,
        )
        # Add a leading singleton axis.
        return context[None]

    def get_transformed_attention_mask(self, attention_mask: Array) -> Array:
        """Return ``attention_mask`` unchanged."""
        return attention_mask