from typing import Sequence, Optional
import distrax
import jax
import jax.numpy as jnp
import flax.linen as nn

from utils.networks import default_init, MLP, LengthNormalize


class MixerBlock(nn.Module):
    """MLP-Mixer style block with an extra lower-triangular token-mixing projection.

    Applies, in order: a token-mixing MLP (across the token axis), a linear
    token projection masked so each output token only aggregates tokens at or
    before its own position, a residual add, then a channel-mixing MLP with a
    second residual add.

    Note that the token-mixing MLP itself is not masked; only the extra
    ``tm_weights`` projection is.

    Attributes:
        num_tokens: Number of tokens; sizes the token-MLP output and the square
            ``tm_weights`` matrix.
        embed_dim: Output width of the channel-mixing MLP.
        hidden_dim_tokens: Hidden width of the token-mixing MLP.
        hidden_dim_channels: Hidden width of the channel-mixing MLP.
        init_scale: Not referenced by this block.
        decay_alpha: Not referenced by this block.
    """
    num_tokens: int
    embed_dim: int
    hidden_dim_tokens: int
    hidden_dim_channels: int
    init_scale: float = 1e-2  # NOTE(review): unused in this block.

    decay_alpha: float = 0.9  # NOTE(review): unused in this block.

    def setup(self):
        self.token_dense1 = nn.Dense(self.hidden_dim_tokens, kernel_init=default_init())
        self.token_dense2 = nn.Dense(self.num_tokens, kernel_init=default_init())
        self.channel_dense1 = nn.Dense(self.hidden_dim_channels, kernel_init=default_init())
        self.channel_dense2 = nn.Dense(self.embed_dim, kernel_init=default_init())

        # Raw learnable token-mixing weights. The lower-triangular mask is
        # applied in __call__ so this attribute always holds the parameter
        # itself rather than being shadowed by a derived array.
        self.tm_weights = self.param(
            'tm_weights',
            nn.initializers.normal(stddev=0.02),
            (self.num_tokens, self.num_tokens)
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply the mixer block.

        Args:
            x: Input tokens, assumed ``(B, num_tokens, embed_dim)``.

        Returns:
            ``x`` plus the token-mixing and channel-mixing residual updates.
        """
        # Token mixing: move tokens to the last axis so Dense mixes across them.
        y = jnp.transpose(x, (0, 2, 1))
        y = self.token_dense1(y)
        y = nn.gelu(y)
        y = self.token_dense2(y)
        y = jnp.transpose(y, (0, 2, 1))

        # tril keeps entry [s, t] only for t <= s, and contracting over t means
        # output token s only aggregates tokens at or before position s
        # (no contributions from future tokens).
        causal_weights = jnp.tril(self.tm_weights)
        y = jnp.einsum('st,btd->bsd', causal_weights, y)

        x = x + y  # residual connection

        # Channel mixing within each token.
        z = self.channel_dense1(x)
        z = nn.gelu(z)
        z = self.channel_dense2(z)

        return x + z
    
class HierarchicalPlannerNetwork(nn.Module):
    """Autoregressive planner: predicts subgoal representations, then an action.

    Observation (and goal) features are concatenated with ``num_subgoals + 1``
    placeholder tokens and passed through a stack of ``MixerBlock`` layers.
    Step ``k`` reads the token at position ``T + k`` (``T`` = number of feature
    tokens). For ``k < num_subgoals`` that token feeds ``high_actor_head``; the
    final step feeds ``low_actor_head``. Before each step after the first, the
    leading ``k`` placeholders are replaced with already-decided subgoals. These
    are ``subgoal_reps`` when supplied (teacher forcing), otherwise the subgoals
    sampled at earlier steps.

    Attributes:
        num_tokens: Number of state tokens (sequence length). Not referenced here.
        state_dim: Token embedding width. It is used by the feature MLP output,
            the placeholder tokens, the mixer blocks and the subgoal buffer.
        num_action_dims: Total number of action dimensions. Not referenced here.
        joint_embed_dim: Joint embedding dimension for mapping actions. Not
            referenced here. Subgoal tokens are concatenated with state tokens,
            so they must have width ``state_dim``.
        num_mixer_blocks: Number of MixerBlock layers.
        mixer_token_hidden: Hidden dimension for token mixing.
        mixer_channel_hidden: Hidden dimension for channel mixing.
        gc_encoder: Optional encoder. When set, it produces the observation and
            goal features (called with ``goal_encoded=False, listwise=True``).
        layer_norm: Whether the feature-embedding MLP uses layer normalization.
        final_fc_init_scale: Not referenced here.
        high_actor_head: Called as ``head(token, temperature=...)``. Expected to
            return a distribution exposing ``sample(seed=...)``.
        low_actor_head: Same calling convention as ``high_actor_head``; used for
            the final (action) step.
        enc_hidden: Hidden sizes of the feature-embedding MLP.
        num_subgoals: Number of subgoals predicted before the action.
    """
    num_tokens: int
    state_dim: int
    num_action_dims: int
    joint_embed_dim: int = 128
    num_mixer_blocks: int = 2
    mixer_token_hidden: int = 64
    mixer_channel_hidden: int = 64
    gc_encoder: nn.Module = None
    layer_norm: bool = True
    final_fc_init_scale: float = 1e-2

    high_actor_head: nn.Module = None
    low_actor_head: nn.Module = None
    enc_hidden: Sequence[int] = (128, 128)

    num_subgoals: int = 1

    def setup(self):
        # Learned placeholder tokens, one per autoregressive step.
        self.prev_tokens = self.param("prev_tokens",
                                      nn.initializers.normal(stddev=0.1),
                                      (1, self.num_subgoals + 1, self.state_dim))

        # The mixer sequence is 2 feature tokens (observation + goal) plus
        # num_subgoals + 1 placeholder tokens.
        self.mixer_blocks = [MixerBlock(num_tokens=self.num_subgoals + 3,
                                        embed_dim=self.state_dim,
                                        hidden_dim_tokens=self.mixer_token_hidden,
                                        hidden_dim_channels=self.mixer_channel_hidden)
                             for _ in range(self.num_mixer_blocks)]

        feature_embed = [MLP(hidden_dims=(*self.enc_hidden, self.state_dim),
                             activate_final=False,
                             layer_norm=self.layer_norm)]
        feature_embed.append(LengthNormalize())
        self.feature_embed = nn.Sequential(feature_embed)

    def __call__(self,
                 observations: jnp.ndarray,
                 goals: jnp.ndarray,
                 seed: Optional[jnp.ndarray] = None,
                 subgoal_reps: Optional[jnp.ndarray] = None,
                 action_seq: Optional[jnp.ndarray] = None,
                 temperature: float = 1.0):
        """Run the autoregressive subgoal/action prediction.

        Args:
            observations: Batch of observations; a token axis is inserted at
                axis 1 before embedding.
            goals: Batch of goals (a token axis is inserted at axis 1). The
                mixer blocks are sized for two feature tokens, so goals are
                required in practice.
            seed: JAX PRNG key. It is split into a key for subgoal sampling and
                a key for action sampling.
            subgoal_reps: Optional ground-truth subgoal tokens for teacher
                forcing, assumed ``(B, num_subgoals, state_dim)``. At step ``k``
                only the first ``k`` entries are read.
            action_seq: Not referenced.
            temperature: Forwarded to both actor heads.

        Returns:
            Tuple ``(high_dist_list, low_dist, predicted_actions)``:
            - ``high_dist_list``: one subgoal distribution per subgoal step.
            - ``low_dist``: the action distribution.
            - ``predicted_actions``: a sample drawn from ``low_dist``.

        Raises:
            ValueError: If ``seed`` is None, or if the number of feature tokens
                does not match what the mixer blocks were built for.
        """
        if seed is None:
            raise ValueError("HierarchicalPlannerNetwork requires a PRNG key via `seed`.")
        high_seed, low_seed = jax.random.split(seed)

        observations = jnp.expand_dims(observations, axis=1)  # (B, 1, obs_dim)
        if goals is not None:
            goals = jnp.expand_dims(goals, axis=1)  # (B, 1, goal_dim)

        if self.gc_encoder is not None:
            features = self.gc_encoder(observations, goals, goal_encoded=False, listwise=True)
            obs_feature = self.feature_embed(features[0])
            goal_feature = features[1]
            features = jnp.concatenate([obs_feature, goal_feature], axis=1)
        else:
            features = [self.feature_embed(observations)]
            if goals is not None:
                features.append(self.feature_embed(goals))
            features = jnp.concatenate(features, axis=1)

        B, T, _ = features.shape
        num_steps = self.num_subgoals + 1

        # The mixer blocks were built for num_subgoals + 3 tokens. Any other
        # count would fail later inside the residual add with an opaque
        # broadcasting error, so fail early with a clear message instead.
        expected_tokens = self.num_subgoals + 3
        if T + num_steps != expected_tokens:
            raise ValueError(
                f"MixerBlock expects {expected_tokens} tokens but got {T} feature "
                f"token(s) + {num_steps} placeholder token(s); provide both "
                f"observations and goals."
            )

        # Buffer of sampled subgoals. Its width must match the placeholder
        # tokens it is concatenated with (state_dim).
        predicted_subgoals = jnp.zeros((B, self.num_subgoals, self.state_dim), dtype=jnp.float32)
        prev_embed_tokens = jnp.tile(self.prev_tokens, (B, 1, 1))

        high_dist_list = []
        low_dist = None
        predicted_actions = None
        for step in range(num_steps):
            if step == 0:
                prev_embeds = prev_embed_tokens
            else:
                # Replace the first `step` placeholders with decided subgoals.
                decided = subgoal_reps if subgoal_reps is not None else predicted_subgoals
                prev_embeds = jnp.concatenate(
                    [decided[:, :step, :], prev_embed_tokens[:, step:, :]], axis=1)

            x = jnp.concatenate([features, prev_embeds], axis=1)

            for mixer_block in self.mixer_blocks:
                x = mixer_block(x)

            # Read the first still-undecided placeholder position.
            target_token = x[:, T + step, :]

            if step < self.num_subgoals:
                high_dist = self.high_actor_head(target_token, temperature=temperature)
                high_dist_list.append(high_dist)
                # Samples are only consumed when not teacher forcing. Each step
                # gets its own key; reusing one key would correlate the
                # sampling noise across subgoals.
                if subgoal_reps is None:
                    step_seed = jax.random.fold_in(high_seed, step)
                    goal_reps = high_dist.sample(seed=step_seed)
                    predicted_subgoals = predicted_subgoals.at[:, step, :].set(goal_reps)
            else:
                low_dist = self.low_actor_head(target_token, temperature=temperature)
                predicted_actions = low_dist.sample(seed=low_seed)

        return high_dist_list, low_dist, predicted_actions