from typing import Optional, Tuple, List
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import lax
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.gpt2.modeling_flax_gpt2 import (
    FlaxGPT2Attention,
    FlaxGPT2MLP,
    FlaxGPT2Block,
    FlaxGPT2PreTrainedModel,
    GPT2Config, 
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
)
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from typing import Callable, Literal

def print_model(flax_params, file=None):
    """Print each parameter path in ``flax_params`` together with its shape.

    Every leaf produces one line of the form ``"<key>/<key>/... <shape>"``.

    Args:
        flax_params: nested dict / FrozenDict of arrays (anything ``flatten_dict``
            accepts), e.g. a model's ``params`` tree.
        file: optional writable text stream; when falsy, output goes to stdout.
    """
    # print(file=None) writes to sys.stdout, matching the "no file" case.
    destination = file if file else None
    for key_path, leaf in flatten_dict(flax_params).items():
        print(f"{'/'.join(key_path)} {leaf.shape}", file=destination)


class FlaxGPT2MoE(nn.Module):
    """Mixture-of-experts replacement for the GPT-2 feed-forward (MLP) sub-layer.

    Each token is passed through all ``num_routed_experts`` routed experts (each a
    ``FlaxGPT2MLP``). Their outputs are mixed with gate weights that are non-zero
    only for the token's ``topk`` highest-scoring experts; those weights are a
    softmax over the selected ``topk`` gate logits. If ``num_shared_experts > 0``,
    an ungated expert is added on top. It is implemented as a single
    ``FlaxGPT2MLP`` whose hidden width is ``intermediate_size * num_shared_experts``.

    Config attributes read: ``hidden_size``, ``num_routed_experts``, ``topk``, and
    optionally ``num_shared_experts`` (defaults to 0 when absent).

    Note: routing is computed densely. Every routed expert runs on every token and
    unselected experts are multiplied by a zero weight, so compute scales with
    ``num_routed_experts`` rather than with ``topk``.
    """

    config: GPT2Config
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size
        self.num_routed_experts = self.config.num_routed_experts
        # Shared experts are optional; a config that omits the field simply gets none.
        self.num_shared_experts = getattr(self.config, "num_shared_experts", 0)
        self.topk = self.config.topk
        self.num_experts = self.num_routed_experts

        # lax.top_k cannot select more entries than exist (or a negative count);
        # reject such configs up front with a readable message.
        if not 0 <= self.topk <= self.num_routed_experts:
            raise ValueError(
                f"`topk` must be between 0 and `num_routed_experts` "
                f"({self.num_routed_experts}), got {self.topk}"
            )

        # NOTE: the attribute names below (routed_experts, shared_experts, gate)
        # determine the parameter-tree names; keep them stable for checkpoints.
        self.routed_experts = [
            FlaxGPT2MLP(config=self.config, intermediate_size=self.intermediate_size, dtype=self.dtype)
            for _ in range(self.num_routed_experts)
        ]
        if self.num_shared_experts > 0:
            self.shared_experts = FlaxGPT2MLP(
                config=self.config,
                intermediate_size=self.intermediate_size * self.num_shared_experts,
                dtype=self.dtype,
            )
        else:
            self.shared_experts = None
        # Gating network scores only the routed experts.
        self.gate = nn.Dense(self.num_routed_experts, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True, flags: dict = None):
        """Apply the MoE feed-forward to ``hidden_states``.

        Args:
            hidden_states: activations of shape ``[batch, seq, dim]``. The shape is
                unpacked into three values, so rank 3 is required.
            deterministic: forwarded unchanged to every expert MLP (presumably
                toggles their dropout).
            flags: optional dict for diagnostics. When given, it receives
                ``"expert_outs"`` (stacked routed-expert outputs,
                ``[B, T, num_routed_experts, D]``) and ``"gate_logits"``
                (``[B, T, num_routed_experts]``). NOTE(review): under ``jax.jit``
                these will be tracers; only read them inside the traced function.

        Returns:
            Array of shape ``[B, T, D]``: the gated sum of the routed experts plus the
            shared-expert output, if any.
        """
        B, T, _ = hidden_states.shape

        # --- Gating ---
        gate_logits = self.gate(hidden_states)  # [B, T, num_routed_experts]
        topk_values, topk_indices = jax.lax.top_k(gate_logits, self.topk)  # [B, T, topk] each

        # Scatter the softmax over the selected logits back into a dense
        # [B, T, num_routed_experts] map; unselected experts keep weight 0.
        gate_weights = jnp.zeros_like(gate_logits).at[
            jnp.arange(B)[:, None, None],
            jnp.arange(T)[None, :, None],
            topk_indices,
        ].set(nn.softmax(topk_values, axis=-1))

        # --- Routed expert computation (dense) ---
        routed_expert_outputs = jnp.stack(
            [expert(hidden_states, deterministic=deterministic) for expert in self.routed_experts],
            axis=-2,
        )  # [B, T, num_routed_experts, D]
        routed_outputs = jnp.sum(routed_expert_outputs * gate_weights[..., None], axis=-2)  # [B, T, D]

        # --- Shared (ungated) expert computation ---
        if self.shared_experts is None:
            shared_outputs = 0
        else:
            shared_outputs = self.shared_experts(hidden_states, deterministic=deterministic)

        if flags is not None:
            flags["expert_outs"] = routed_expert_outputs
            flags["gate_logits"] = gate_logits
        return shared_outputs + routed_outputs


class FlaxGPT2MoEBlock(nn.Module):
    """GPT-2 transformer block whose feed-forward sub-layer is a ``FlaxGPT2MoE``.

    Layout (pre-LayerNorm, each sub-layer wrapped in a residual add):
    ``ln_1 -> attn``, then optionally ``ln_cross_attn -> crossattention``, then
    ``ln_2 -> mlp`` (the MoE layer).
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        eps = cfg.layer_norm_epsilon
        # Feed-forward width falls back to 4 * hidden_size when n_inner is unset.
        ffn_width = 4 * cfg.hidden_size if cfg.n_inner is None else cfg.n_inner

        self.ln_1 = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.attn = FlaxGPT2Attention(cfg, dtype=self.dtype)
        self.ln_2 = nn.LayerNorm(epsilon=eps, dtype=self.dtype)

        if cfg.add_cross_attention:
            self.crossattention = FlaxGPT2Attention(
                config=cfg, dtype=self.dtype, causal=False, is_cross_attention=True
            )
            self.ln_cross_attn = nn.LayerNorm(epsilon=eps, dtype=self.dtype)

        self.mlp = FlaxGPT2MoE(config=cfg, intermediate_size=ffn_width, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        flags: dict = None,
    ):
        """Run the block.

        Returns a tuple ``(hidden_states, *extras)``. ``extras`` holds whatever the
        self-attention call returned after its first element, followed by the same
        for cross-attention when ``encoder_hidden_states`` is given. ``flags`` is
        forwarded to the MoE layer.

        Raises:
            ValueError: if ``encoder_hidden_states`` is passed but the block was
                built without ``config.add_cross_attention``.
        """
        # Self-attention sub-layer.
        self_attn_result = self.attn(
            self.ln_1(hidden_states),
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self_attn_result[0]
        extras = self_attn_result[1:]

        # Optional cross-attention sub-layer.
        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} must be instantiated with "
                    "`config.add_cross_attention=True`"
                )
            cross_result = self.crossattention(
                self.ln_cross_attn(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = hidden_states + cross_result[0]
            extras = extras + cross_result[1:]

        # MoE feed-forward sub-layer.
        moe_out = self.mlp(self.ln_2(hidden_states), deterministic=deterministic, flags=flags)
        hidden_states = hidden_states + moe_out

        return (hidden_states, *extras)


class FlaxGPT2MoEBlockCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` GPT-2 blocks, some of them MoE blocks.

    ``config.moe_layer_indices`` selects which layers use ``FlaxGPT2MoEBlock``. It
    may be a single layer index (``int``) or an iterable of layer indices. When the
    attribute is missing or ``None``, every layer is a plain ``FlaxGPT2Block``.
    Blocks are named ``"0"``, ``"1"``, ... after their layer index.
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    @staticmethod
    def _resolve_moe_layers(moe_layer_indices) -> frozenset:
        """Normalize ``moe_layer_indices`` (None, an int, or an iterable of ints) to a frozenset."""
        if moe_layer_indices is None:
            return frozenset()
        if hasattr(moe_layer_indices, "__iter__"):
            return frozenset(int(i) for i in moe_layer_indices)
        return frozenset((int(moe_layer_indices),))

    def setup(self):
        # Missing/None means "no MoE layers". A single int keeps the original
        # one-MoE-layer behavior, and a list/tuple selects several layers.
        moe_layers = self._resolve_moe_layers(getattr(self.config, "moe_layer_indices", None))

        self.blocks = [
            FlaxGPT2MoEBlock(self.config, name=str(i), dtype=self.dtype)
            if i in moe_layers
            else FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        flags: dict = None,
    ):
        """Run all blocks in order.

        ``flags`` is forwarded only to MoE blocks. If several layers are MoE, each
        writes the same keys into it, so the last MoE layer's values remain.
        ``return_dict`` is accepted for signature compatibility and unused here.

        Returns:
            ``(hidden_states, all_hidden_states, all_attentions, all_cross_attentions)``.
            The last three entries are ``None`` when not requested.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Keyword arguments shared by every block (loop-invariant).
        block_kwargs = dict(
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        # Only MoE blocks take ``flags``; plain FlaxGPT2Block is called without it.
        moe_block_kwargs = {**block_kwargs, "flags": flags}

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Dispatch on the block's actual type so the MoE decision made in setup
            # is the single source of truth (no second config lookup here).
            kwargs = moe_block_kwargs if isinstance(block, FlaxGPT2MoEBlock) else block_kwargs
            layer_outputs = block(hidden_states, **kwargs)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        return (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
    
class FlaxGPT2MoEModule(nn.Module):
    """GPT-2 backbone with MoE layers: token + position embeddings, block stack, final LayerNorm."""

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.embed_dim = cfg.hidden_size
        # Both embedding tables share the same initializer configuration.
        embed_init = jax.nn.initializers.normal(stddev=cfg.initializer_range)

        self.wte = nn.Embed(
            num_embeddings=cfg.vocab_size,
            features=self.embed_dim,
            embedding_init=embed_init,
            dtype=self.dtype,
        )
        self.wpe = nn.Embed(
            num_embeddings=cfg.max_position_embeddings,
            features=self.embed_dim,
            embedding_init=embed_init,
            dtype=self.dtype,
        )

        self.dropout = nn.Dropout(rate=cfg.embd_pdrop)
        self.h = FlaxGPT2MoEBlockCollection(cfg, dtype=self.dtype)
        self.ln_f = nn.LayerNorm(epsilon=cfg.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        flags: dict = None,
    ):
        """Embed the inputs, run the block stack, and apply the final LayerNorm.

        With ``return_dict=False``, returns a tuple of the non-``None`` outputs.
        Otherwise returns a ``FlaxBaseModelOutputWithPastAndCrossAttentions``. When
        ``output_hidden_states`` is set, the post-LayerNorm hidden state is appended
        to the collected hidden states.
        """
        token_vecs = self.wte(input_ids.astype("i4"))
        position_vecs = self.wpe(position_ids.astype("i4"))
        hidden_states = self.dropout(token_vecs + position_vecs, deterministic=deterministic)

        stack_out = self.h(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flags=flags,
        )

        final_hidden = self.ln_f(stack_out[0])

        if output_hidden_states:
            collected = (final_hidden, stack_out[1] + (final_hidden,), *stack_out[2:])
        else:
            collected = (final_hidden, *stack_out[1:])

        if not return_dict:
            return tuple(item for item in collected if item is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=final_hidden,
            hidden_states=collected[1],
            attentions=collected[2],
            cross_attentions=collected[3],
        )


class FlaxGPT2MoEModel(FlaxGPT2PreTrainedModel):
    """Bare GPT-2 MoE backbone (no LM head) exposed through ``FlaxGPT2PreTrainedModel``."""

    module_class = FlaxGPT2MoEModule


class FlaxGPT2MoELMHeadModule(nn.Module):
    """Causal language-model head on top of ``FlaxGPT2MoEModule``.

    When ``config.tie_word_embeddings`` is set, the transposed token-embedding
    matrix (``transformer/wte/embedding``) is used as the output projection kernel
    instead of ``lm_head``'s own kernel.
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxGPT2MoEModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        flags: dict = None,
    ):
        """Compute vocabulary logits.

        Returns ``(logits, *backbone_extras)`` when ``return_dict`` is False.
        Otherwise returns a ``FlaxCausalLMOutputWithCrossAttentions``. ``flags`` is
        forwarded to the backbone (and from there to the MoE layers).
        """
        backbone_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flags=flags,
        )
        last_hidden = backbone_out[0]

        if not self.config.tie_word_embeddings:
            logits = self.lm_head(last_hidden)
        else:
            # Project with the (transposed) input embedding table instead of lm_head's kernel.
            tied_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            logits = self.lm_head.apply({"params": {"kernel": tied_kernel}}, last_hidden)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=logits,
                hidden_states=backbone_out.hidden_states,
                attentions=backbone_out.attentions,
                cross_attentions=backbone_out.cross_attentions,
            )
        return (logits,) + backbone_out[1:]


class FlaxGPT2MoELMHeadModel(FlaxGPT2PreTrainedModel):
    """GPT-2 MoE causal LM exposed through ``FlaxGPT2PreTrainedModel`` with generation hooks."""

    module_class = FlaxGPT2MoELMHeadModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Build the initial cache, a fixed-shape attention mask, and position ids.

        The mask has shape ``(batch, max_length)`` and is all ones. If
        ``attention_mask`` is given, it overwrites the leading columns, and position
        ids are its cumulative sum minus one. Otherwise positions are
        ``0..seq_length-1`` for every row. NOTE(review): positions past the prompt
        stay 1 in the mask; this presumably relies on the causal attention mask.
        """
        batch_size, seq_length = input_ids.shape

        cache = self.init_cache(batch_size, max_length)

        static_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is None:
            prompt_positions = jnp.arange(seq_length, dtype="i4")[None, :]
            position_ids = jnp.broadcast_to(prompt_positions, (batch_size, seq_length))
        else:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            static_mask = lax.dynamic_update_slice(static_mask, attention_mask.astype("i4"), (0, 0))

        return {
            "past_key_values": cache,
            "attention_mask": static_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance ``position_ids`` by one step.

        Mutates ``model_kwargs`` in place and returns it.
        """
        updated_cache = model_outputs.past_key_values
        model_kwargs["past_key_values"] = updated_cache
        # Keep only the last position per row and step it forward.
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
