'''
Implement the hypernetwork (HN) in this file.
When initializing the HN, it takes a base network as input to initialize its own learnable parameters correspondingly. 
At inference time, the HN takes language instructions as input to generate base network parameters. 
Then the base network takes image observations as inputs to predict actions. 
'''
from typing import Dict
import re

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from hypervla.components.base_octo import OctoTransformer
from hypervla.components.transformer import Transformer
from octo.model.components.tokenizers import ImageTokenizer
from octo.model.components.vit_encoders import SmallStem16
from octo.utils.spec import ModuleSpec

from enum import IntEnum

class InitOptions(IntEnum):
    '''
    Initialization strategies for the hypernetwork output heads.

    The values are read by ``HyperNetwork._create_output_head`` (through the
    ``init_strategy`` entry of each head's info dict) to choose the kernel
    initializer of a generated head.
    '''
    # The output-head kernel is zero-initialized (the head bias is
    # zero-initialized as well).
    BIAS_INIT = 0
    # For heads whose name does not end in "bias", the kernel is drawn from a
    # truncated normal with stddev = sqrt(init_variance); otherwise the head
    # falls back to the zero-initialized kernel.
    VARIANCE_INIT = 1

class HyperNetwork(nn.Module):
    '''
    Hypernetwork (HN) that maps task context to the parameters of a base network.

    The language-instruction token embeddings (optionally together with tokens
    from an initial image and a goal image) are concatenated with learnable
    "layer tokens" and encoded by a Transformer. The encoder outputs at the
    layer-token positions are the context embeddings; per-parameter output
    heads turn them into base-network parameter blocks.

    Attributes:
        base_net_metadata: description of the base network. Keys read here:
            ``output_head_info`` (name -> dict with ``output_dim``,
            ``generation_flag``, ``init_strategy``, ``init_variance``),
            ``token_index_dict``, ``generation_flag``, ``param_shape``,
            ``block_num`` (only for the ``'block'`` generation strategy) and,
            optionally, ``layer_token_mask``.
        hypernet_kwargs: HN configuration. Required keys:
            ``context_embedding_dim`` and ``context_encoder_kwargs``. Optional
            keys (default in parentheses): ``generation_strategy`` ('full'),
            ``use_initial_image`` (False), ``use_all_image_tokens`` (False),
            ``image_dropout`` (0.0), ``include_goal_image`` (False),
            ``attend_to_padding`` (False), ``task_attend_to_layer`` (False),
            ``scale_context_embedding`` (False), ``embedding_dropout_rate``
            (0.), ``output_head_bias`` (True), ``share_TF_output_head``
            (False), ``final_dropout_rate`` (None).
    '''
    base_net_metadata: Dict
    hypernet_kwargs: Dict

    def setup(self):
        '''
        Create the projections, the context encoder and the output heads.

        Raises:
            ValueError: if ``generation_strategy`` is neither 'full' nor 'block'.
        '''
        context_dim = self.hypernet_kwargs["context_embedding_dim"]

        # project language token embeddings to the context embedding dimension
        self.token_projection = nn.Dense(
            context_dim,
            name="task_token_projection"
        )

        if self.hypernet_kwargs.get("use_initial_image", False):
            self.image_projection = nn.Dense(
                context_dim,
                name="initial_image_projection"
            )

        # number of layer tokens appended to the context: a single token for
        # 'full', one token per base-network block for 'block'
        self.generation_strategy = self.hypernet_kwargs.get("generation_strategy", "full")
        if self.generation_strategy == 'full':
            self.layer_token_num = 1
        elif self.generation_strategy == 'block':
            self.layer_token_num = self.base_net_metadata['block_num']
        else:
            # fail early instead of hitting a missing `layer_token_num` later
            raise ValueError(
                f"Unknown generation_strategy {self.generation_strategy!r}; "
                "expected 'full' or 'block'."
            )

        # Transformer context encoder
        self.context_encoder = Transformer(
            embedding_dim=context_dim,
            **self.hypernet_kwargs["context_encoder_kwargs"]
        )

        # Dropout applied to generated parameters. It has to be created here:
        # `_call_output_head` is reached from the non-compact `__call__`, where
        # Flax does not allow constructing submodules.
        final_dropout_rate = self.hypernet_kwargs.get("final_dropout_rate", None)
        if final_dropout_rate is not None:
            self.final_dropout = nn.Dropout(rate=final_dropout_rate)

        self.output_head = {
            name: self._create_output_head(name, head_info)
            for name, head_info in self.base_net_metadata["output_head_info"].items()
        }

    def _create_output_head(self, path, head_info):
        '''
        Build the output head for one base-network parameter block.

        Args:
            path: name of the parameter block, either a string or a sequence of
                key entries (objects with a ``.key`` attribute) that is joined
                with '_'.
            head_info: dict with ``output_dim``, ``generation_flag``,
                ``init_strategy`` and ``init_variance``.

        Returns:
            A ``nn.Dense`` head when ``generation_flag`` is truthy; otherwise a
            learnable parameter vector of shape ``(output_dim,)`` that is shared
            across tasks.
        '''
        output_dim = head_info["output_dim"]
        generation_flag = head_info["generation_flag"]
        init_strategy = head_info["init_strategy"]
        init_variance = head_info["init_variance"]

        # normalise the path to a flat string name before inspecting it
        if isinstance(path, str):
            name = path
        else:
            name = '_'.join([p.key for p in path])
        is_bias_block = name.split('_')[-1] == "bias"

        if init_strategy == InitOptions.VARIANCE_INIT and not is_bias_block:
            kernel_init = nn.initializers.truncated_normal(stddev=np.sqrt(init_variance))
        else:
            kernel_init = nn.initializers.zeros
        bias_init = nn.initializers.zeros

        # generation_flag determines if the base param block is generated by HN or shared across tasks
        if generation_flag:
            return nn.Dense(
                output_dim,
                use_bias=self.hypernet_kwargs.get("output_head_bias", True),
                kernel_init=kernel_init,
                bias_init=bias_init,
            )
        return self.param(
            name,
            nn.initializers.truncated_normal(stddev=0.02),
            (output_dim, ),
        )

    @nn.compact
    def generate_context_embedding(self, tasks, train: bool, initial_states=None):
        '''
        Encode the task context and return the embeddings at the layer tokens.

        Args:
            tasks: dict read at ``['language_instruction']['token_embedding']``
                (task tokens, shape = batch_size * token_num *
                token_embedding_size), ``['language_instruction']['attention_mask']``,
                ``['pad_mask_dict']['language_instruction']`` and, when
                ``include_goal_image`` is set, ``['image_primary']`` and
                ``['pad_mask_dict']['image_primary']``.
            train: enables dropout when True.
            initial_states: dict providing ``['patch_embeddings']``; required
                when ``use_initial_image`` is set.

        Returns:
            The encoder outputs at the last ``layer_token_num`` positions,
            optionally scaled by ``1 / sqrt(context_embedding_dim)`` and passed
            through dropout.

        Raises:
            ValueError: if ``use_initial_image`` is set but ``initial_states``
                is None.
        '''
        context_dim = self.hypernet_kwargs["context_embedding_dim"]
        use_initial_image = self.hypernet_kwargs.get("use_initial_image", False)
        include_goal_image = self.hypernet_kwargs.get("include_goal_image", False)

        task_tokens = tasks['language_instruction']['token_embedding']
        token_mask = tasks['language_instruction']['attention_mask']
        batch_size, instruction_token_len = task_tokens.shape[0], task_tokens.shape[1]

        # projection layer for the task tokens
        task_tokens = self.token_projection(task_tokens)
        # add PE to task tokens
        # TODO: learnable PE or pre-defined?
        task_tokens = task_tokens + self._create_positional_embedding('task', task_tokens)

        # initial image
        if use_initial_image:
            if initial_states is None:
                raise ValueError(
                    "`initial_states` must be provided when `use_initial_image` is enabled."
                )
            if self.hypernet_kwargs.get("image_dropout", 0.0) > 0:
                initial_image = nn.Dropout(rate=self.hypernet_kwargs["image_dropout"])(
                    initial_states["patch_embeddings"], deterministic=not train
                )
            else:
                initial_image = initial_states["patch_embeddings"]
            if self.hypernet_kwargs.get("use_all_image_tokens", False):
                initial_image_tokens = self.image_projection(initial_image)
            else:
                # only the first image token is used
                initial_image_tokens = self.image_projection(initial_image[:, :1])
            initial_image_tokens = initial_image_tokens + self._create_positional_embedding(
                "initial_image", initial_image_tokens
            )
            task_tokens = jnp.concatenate([task_tokens, initial_image_tokens], axis=1)

        # image goal tokens
        if include_goal_image:
            goal_images = tasks["image_primary"]
            goal_image_tokens = SmallStem16(learnable_norm=False)(goal_images)
            # flatten everything between the batch and feature axes into one token axis
            goal_image_tokens = goal_image_tokens.reshape(batch_size, -1, goal_image_tokens.shape[-1])
            goal_image_tokens = nn.Dense(
                context_dim,
                name="goal_image_token_projection",
            )(goal_image_tokens)
            # add PE
            goal_image_tokens = goal_image_tokens + self._create_positional_embedding(
                'goal_image', goal_image_tokens
            )
            task_tokens = jnp.concatenate([task_tokens, goal_image_tokens], axis=1)

        # layer tokens: zero content, distinguished only by their learnable PE
        layer_tokens = jnp.zeros((batch_size, self.layer_token_num, context_dim))
        layer_tokens += self._create_positional_embedding('layer', layer_tokens)
        # context input to context encoder
        context_tokens = jnp.concatenate([task_tokens, layer_tokens], axis=1)
        context_len = context_tokens.shape[1]

        # define attention mask: each row of the mask determines how a token attends to the other tokens
        # 1. determine if padding tokens in the instruction will be attended to (no by default)
        instruction_mask_shape = (batch_size, 1, context_len, instruction_token_len)
        if self.hypernet_kwargs.get("attend_to_padding", False):
            instruction_attention_mask = jnp.ones(instruction_mask_shape, dtype=bool)
        else:
            instruction_attention_mask = jnp.broadcast_to(
                jnp.expand_dims(token_mask, (1, 2)), instruction_mask_shape
            ).astype(bool)
        # pad mask shape: (batch_size, ) -> (batch_size, 1, context length, language token length)
        instruction_pad_mask = jnp.broadcast_to(
            tasks["pad_mask_dict"]["language_instruction"][:, None, None, None],
            instruction_mask_shape,
        ).astype(bool)
        instruction_attention_mask = instruction_attention_mask & instruction_pad_mask
        attention_mask = [instruction_attention_mask]

        # 2. initial image mask: always attended to
        if use_initial_image:
            initial_image_mask = jnp.ones(
                (batch_size, 1, context_len, initial_image_tokens.shape[-2]), dtype=bool
            )
            attention_mask.append(initial_image_mask)

        # 3. goal image mask: attended to only where the goal image is not padding
        if include_goal_image:
            image_goal_mask = jnp.broadcast_to(
                tasks["pad_mask_dict"]["image_primary"][:, None, None, None],
                (batch_size, 1, context_len, goal_image_tokens.shape[1]),
            ).astype(bool)
            attention_mask.append(image_goal_mask)

        # 4. determine if task tokens attend to layer tokens (no by default)
        layer_mask_shape = (batch_size, 1, context_len, self.layer_token_num)
        if "layer_token_mask" in self.base_net_metadata:
            layer_attention_mask = jnp.array(self.base_net_metadata["layer_token_mask"])
            layer_attention_mask = jnp.broadcast_to(
                layer_attention_mask[None, None, None, :], layer_mask_shape
            )
        else:
            layer_attention_mask = jnp.ones(layer_mask_shape, dtype=bool)
        if not self.hypernet_kwargs.get("task_attend_to_layer", False):
            # rows of all non-layer tokens may not see the layer-token columns
            layer_attention_mask = layer_attention_mask.at[:, :, :-self.layer_token_num, :].set(False)
        attention_mask.append(layer_attention_mask)

        # concat along the key axis -> (batch_size, 1, context length, context length)
        attention_mask = jnp.concatenate(attention_mask, axis=-1)

        # context encoder by transformer
        output, _ = self.context_encoder(
            context_tokens, attention_mask, train=train
        )
        # get the context embedding (outputs at the layer-token positions)
        context_embedding = output[:, -self.layer_token_num:]

        # scale the final context embedding
        if self.hypernet_kwargs.get("scale_context_embedding", False):
            context_embedding = context_embedding / jnp.sqrt(context_dim)
        # apply dropout to the final context embedding
        embedding_dropout_rate = self.hypernet_kwargs.get("embedding_dropout_rate", 0.)
        context_embedding = nn.Dropout(rate=embedding_dropout_rate)(context_embedding, deterministic=not train)

        return context_embedding

    def __call__(self, tasks, train: bool, initial_states=None):
        '''
        Generate the base-network parameters for a batch of tasks.

        Args:
            tasks: task dict, see ``generate_context_embedding``.
            train: enables dropout when True.
            initial_states: forwarded to ``generate_context_embedding``.

        Returns:
            A tuple ``(base_params, context_embedding)``. ``base_params`` has
            the tree structure of ``base_net_metadata['token_index_dict']`` and
            each leaf is reshaped to ``(batch_size, *param_shape)``.
        '''
        context_embedding = self.generate_context_embedding(tasks, train, initial_states)

        # HN output head
        # allocate the context embedding vector used to generate each base param block
        head_token_embeddings = jax.tree_util.tree_map(
            lambda idx: context_embedding[:, idx],
            self.base_net_metadata['token_index_dict'],
        )
        # the tree path of every leaf selects the output head to apply
        base_params = jax.tree_util.tree_map_with_path(
            lambda path, embedding, flag: self._call_output_head(path, embedding, flag, train),
            head_token_embeddings,
            self.base_net_metadata['generation_flag'],
        )

        batch_size = tasks['language_instruction']['token_embedding'].shape[0]
        base_params = jax.tree_util.tree_map(
            lambda p, shape: p.reshape(batch_size, *shape),
            base_params,
            self.base_net_metadata['param_shape'],
        )

        return base_params, context_embedding

    def _call_output_head(self, path, context_embedding, generation_flag, train: bool):
        '''
        Produce one base-network parameter block.

        Args:
            path: tree path of the parameter block; its ``.key`` entries joined
                with '_' give the output-head name.
            context_embedding: context embedding assigned to this block; its
                leading axis is the batch axis.
            generation_flag: truthy if the block is generated by its head,
                falsy if it is a shared learnable parameter.
            train: enables the final dropout when True.

        Returns:
            The generated parameters, or the shared parameter broadcast along a
            new leading batch axis.
        '''
        layer_name = '_'.join([p.key for p in path])
        if self.hypernet_kwargs.get("share_TF_output_head", False):
            # drop the block index so that all encoder blocks use the same head
            layer_name = re.sub(r'encoderblock_\d+', 'encoderblock', layer_name)
        layer = self.output_head[layer_name]

        if not generation_flag:
            return jnp.broadcast_to(layer[None], (context_embedding.shape[0], *layer.shape))

        generated_params = layer(context_embedding)
        if self.hypernet_kwargs.get("final_dropout_rate", None) is not None:
            # `final_dropout` is created in `setup` under the same condition
            generated_params = self.final_dropout(generated_params, deterministic=not train)
        return generated_params

    def _create_positional_embedding(self, name: str, tokens: jax.Array):
        '''
        Create a learnable positional embedding matching ``tokens``.

        The parameter is named ``f"{name}_pos_embedding"``, has shape
        ``(1, *tokens.shape[-2:])`` and is broadcast to ``tokens.shape``.
        '''
        shape = (1, *tokens.shape[-2:])
        embedding = self.param(
            f"{name}_pos_embedding",
            nn.initializers.normal(stddev=0.02),
            shape,
        )
        return jnp.broadcast_to(embedding, tokens.shape)
