import chex
import flax.linen as nn
import jax
import jax.numpy as jnp

from tabular_mvdrl.utils.support_init import SupportInitializer


class EWPModel(nn.Module):
    """Tabular model storing ``num_atoms`` vectors of size ``reward_dim`` per state.

    The parameters are a single ``nn.Embed`` table with ``num_states`` rows
    and ``reward_dim * num_atoms`` columns; a looked-up row is reshaped into
    a ``(num_atoms, reward_dim)`` block.
    """

    num_states: int
    reward_dim: int
    num_atoms: int
    initializer: jax.nn.initializers.Initializer = nn.linear.default_embed_init

    def init_with_support(
        self,
        rng: chex.PRNGKey,
        data: chex.Array,
        support_initializer: SupportInitializer,
    ) -> chex.ArrayTree:
        """Build a parameter tree whose table is filled by ``support_initializer``.

        The module is initialised normally only to obtain the structure and
        shape of its parameters; the initial values themselves are discarded.

        Args:
            rng: PRNG key passed both to ``self.init`` and to
                ``support_initializer``.
            data: Example input forwarded to ``self.init``.
            support_initializer: Callable invoked as ``support_initializer(rng)``;
                its result must be reshapeable to the embedding table's shape.

        Returns:
            A tree with the same structure as the normally initialised
            parameters, holding the reshaped support as its only leaf.
        """
        placeholder = self.init(rng, data)
        support = support_initializer(rng)
        leaves, treedef = jax.tree_util.tree_flatten(placeholder)
        # The model has exactly one parameter leaf (the embedding table), so a
        # one-element tuple is enough to rebuild the tree.
        table_shape = leaves[0].shape
        replacement = jnp.reshape(support, table_shape)
        return jax.tree_util.tree_unflatten(treedef, (replacement,))

    @nn.compact
    def __call__(self, i):
        """Look up the stored vectors for state index/indices ``i``.

        The looked-up rows are reshaped to ``(-1, num_atoms, reward_dim)`` and
        then squeezed, so *every* size-1 axis is removed from the result.
        NOTE(review): this also drops the atom or reward axis when
        ``num_atoms`` or ``reward_dim`` is 1 -- confirm callers expect that.
        """
        table = nn.Embed(
            num_embeddings=self.num_states,
            features=self.reward_dim * self.num_atoms,
            embedding_init=self.initializer,
        )
        flat_rows = table(i)
        target_shape = (-1, self.num_atoms, self.reward_dim)
        stacked = jnp.reshape(flat_rows, target_shape)
        return jnp.squeeze(stacked)


class TabularProbabilityModel(nn.Module):
    """Tabular model storing ``num_atoms`` scalar weights per state.

    The parameters are a single ``nn.Embed`` table with ``num_states`` rows
    and ``num_atoms`` columns. When ``logits`` is True the stored values are
    passed through a softmax over the last axis before being returned;
    otherwise the raw table entries are returned. With the default zero
    initializer the softmax output is uniform.
    """

    num_states: int
    num_atoms: int
    logits: bool = True
    initializer: jax.nn.initializers.Initializer = nn.initializers.zeros

    @nn.compact
    def __call__(self, i):
        """Return the per-atom weights for state index/indices ``i``.

        The looked-up rows are reshaped to ``(-1, num_atoms)`` and squeezed
        (all size-1 axes removed) before the optional softmax is applied.
        """
        table = nn.Embed(
            num_embeddings=self.num_states,
            features=self.num_atoms,
            embedding_init=self.initializer,
        )
        rows = jnp.reshape(table(i), (-1, self.num_atoms))
        values = jnp.squeeze(rows)
        # `logits=True` means the table holds unnormalised scores that must be
        # normalised here; `logits=False` returns the stored values untouched.
        return jax.nn.softmax(values, axis=-1) if self.logits else values


# class CategoricalModel(nn.Module):
#     num_states: int
#     reward_dim: int
#     loc_initializer: jax.nn.initializers.Initializer = nn.linear.default_embed_init
#     prob_initializer: jax.nn.initializers.Initializer = nn.linear.default_embed_init

#     @nn.compact
#     def __call__(self, i):
#         locs = nn.Embed(
#             self.num_states, self.reward_dim, embedding_init=self.loc_initializer
#         )(i)
#         probs = nn.Embed(self.num_states, 1, embedding_init=self.prob_initializer)
