from typing import Optional, Tuple, List
import flax.linen as nn
import jax
import copy 
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.vit.modeling_flax_vit import (
    FlaxViTPreTrainedModel,
    ViTConfig,
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxViTEmbeddings,
    FlaxViTPooler,
    ACT2FN,
    FlaxViTAttention,
    FlaxViTLayer,
    FlaxViTIntermediate,
    FlaxViTOutput,
    FlaxSequenceClassifierOutput,
)
from typing import Callable

def print_model(flax_params, file=None):
    """Print one ``<path> <shape>`` line for every leaf of a parameter tree.

    Args:
        flax_params: Nested parameter mapping. It is flattened with
            ``flatten_dict``; every leaf must expose a ``.shape`` attribute
            and every path component must be a string (they are joined
            with ``"/"``).
        file: Optional writable stream. When it is falsy the lines are
            written to standard output.
    """
    # ``print(..., file=None)`` writes to stdout, so a falsy ``file`` is
    # normalised to None and a single call covers both destinations.
    destination = file if file else None
    for key_path, leaf in flatten_dict(flax_params).items():
        print(f"{'/'.join(key_path)} {leaf.shape}", file=destination)
def print_model_with_prefix(flax_params, prefix: str, file=None):
    """Dump shape and value of every parameter whose dotted path starts with ``prefix``.

    Args:
        flax_params: Nested parameter mapping. It is flattened with
            ``flatten_dict``; path components are joined with ``"."``.
        prefix: Only parameters whose dotted path begins with this string
            are printed.
        file: Optional writable stream. When it is falsy the output goes to
            standard output.
    """
    # ``print(..., file=None)`` falls back to stdout.
    destination = file if file else None
    for key_path, leaf in flatten_dict(flax_params).items():
        dotted = ".".join(key_path)
        if not dotted.startswith(prefix):
            continue
        print(f"{dotted} {leaf.shape} \n {leaf} \n \n", file=destination)

class FlaxViTMoEBlock(nn.Module):
    """Mixture-of-experts replacement for the ViT feed-forward sub-block.

    The block owns ``config.num_routed_experts`` routed experts, each a
    ``FlaxViTIntermediate``/``FlaxViTOutput`` pair, a linear gate that scores
    every token against every routed expert, and (when
    ``config.num_shared_experts > 0``) one shared expert whose intermediate
    size is ``intermediate_size * num_shared_experts``.

    Every routed expert is evaluated on every token; the gate only decides
    the mixing weights (softmax over the ``config.topk`` largest logits,
    zero for all other experts).
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the routed experts, the gate and the optional shared expert."""
        cfg = self.config
        self.embed_dim = cfg.hidden_size
        self.num_routed = cfg.num_routed_experts
        self.num_shared = cfg.num_shared_experts
        self.topk = cfg.topk

        # Routed (gated) experts: one intermediate/output pair per expert.
        self.routed_intermediates = [
            FlaxViTIntermediate(config=cfg, dtype=self.dtype)
            for _ in range(self.num_routed)
        ]
        self.routed_outputs = [
            FlaxViTOutput(config=cfg, dtype=self.dtype)
            for _ in range(self.num_routed)
        ]

        # Router: one logit per routed expert, initialised close to zero.
        self.gate = nn.Dense(
            self.num_routed,
            dtype=self.dtype,
            kernel_init=nn.initializers.normal(stddev=1e-4),
            bias_init=nn.initializers.zeros,
        )

        if self.num_shared <= 0:
            self.shared_intermediate = None
            self.shared_output = None
            return

        # The shared expert is a single wider MLP; the config is copied so
        # the widened intermediate size does not leak into other modules.
        shared_cfg = copy.deepcopy(cfg)
        shared_cfg.intermediate_size = cfg.intermediate_size * self.num_shared
        self.shared_intermediate = FlaxViTIntermediate(config=shared_cfg, dtype=self.dtype)
        self.shared_output = FlaxViTOutput(config=shared_cfg, dtype=self.dtype)

    def __call__(self, hidden_states, residual,deterministic: bool = True, flags: dict = None):
        """Mix the routed expert outputs and add the shared expert output.

        Args:
            hidden_states: Rank-3 input to the experts and the gate, unpacked
                as (batch, tokens, features).
            residual: Second argument handed to every ``FlaxViTOutput``.
            deterministic: Forwarded to every ``FlaxViTOutput``.
            flags: Optional dict; when given, ``"expert_outs"`` (stacked
                per-expert outputs) and ``"gate_logits"`` (raw router logits)
                are written into it.

        Returns:
            The gate-weighted sum of routed expert outputs plus the shared
            expert output (or plus 0 when there is no shared expert).
        """
        batch, tokens, _ = hidden_states.shape

        # ---- gating: softmax over the top-k logits, scattered back to (B,T,E)
        router_logits = self.gate(hidden_states)                          # (B,T,E)
        topk_logits, topk_experts = jax.lax.top_k(router_logits, self.topk)  # (B,T,k)
        topk_probs = nn.softmax(topk_logits, axis=-1)
        batch_ix = jnp.arange(batch)[:, None, None]
        token_ix = jnp.arange(tokens)[None, :, None]
        mixing = jnp.zeros_like(router_logits).at[batch_ix, token_ix, topk_experts].set(topk_probs)

        # ---- routed experts: all of them run on all tokens
        per_expert = [
            project(expand(hidden_states), residual, deterministic=deterministic)
            for expand, project in zip(self.routed_intermediates, self.routed_outputs)
        ]
        expert_outs = jnp.stack(per_expert, axis=-2)                      # (B,T,E,D)
        routed = (expert_outs * mixing[..., None]).sum(axis=-2)           # (B,T,D)

        # ---- shared expert
        # NOTE(review): ``residual`` is passed to the shared FlaxViTOutput as
        # well as to every routed one. If FlaxViTOutput adds its second
        # argument to the result, the residual is counted twice whenever a
        # shared expert exists - confirm this is intended.
        if self.shared_intermediate is None:
            shared = 0
        else:
            shared = self.shared_output(
                self.shared_intermediate(hidden_states), residual, deterministic
            )

        if flags is not None:
            flags["expert_outs"] = expert_outs
            flags["gate_logits"] = router_logits
        return routed + shared
class FlaxViTMoELayer(nn.Module):
    """Transformer layer whose feed-forward part is a ``FlaxViTMoEBlock``.

    Data flow: LayerNorm -> self-attention -> add input -> LayerNorm ->
    MoE block (which also receives the un-normalised post-attention sum).
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate attention, the MoE block and the two layer norms."""
        eps = self.config.layer_norm_eps
        self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
        self.moe_block = FlaxViTMoEBlock(self.config, dtype=self.dtype)
        self.layernorm_before = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.layernorm_after = nn.LayerNorm(epsilon=eps, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False, flags: dict = None):
        """Run attention followed by the MoE feed-forward block.

        Args:
            hidden_states: Layer input.
            deterministic: Forwarded to attention and to the MoE block.
            output_attentions: When truthy, the second element returned by
                the attention module is appended to the result tuple.
            flags: Optional dict forwarded to the MoE block.

        Returns:
            ``(moe_output,)`` or ``(moe_output, attention_extra)``.
        """
        # Attention operates on the normalised input.
        attn_result = self.attention(
            self.layernorm_before(hidden_states),
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # First skip connection: attention output plus the layer input.
        skip = attn_result[0] + hidden_states
        # The MoE block sees the normalised sum and the raw sum.
        mixed = self.moe_block(
            self.layernorm_after(skip),
            skip,
            deterministic=deterministic,
            flags=flags,
        )
        if output_attentions:
            return (mixed, attn_result[1])
        return (mixed,)
class FlaxViTMoEBlockCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` layers with one MoE layer.

    The layer at index ``config.moe_idx`` is a ``FlaxViTMoELayer``; every
    other index is a plain ``FlaxViTLayer``. Layers are named by their index.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Build the layer list, swapping in the MoE layer at ``moe_idx``."""
        cfg = self.config
        stack = []
        for idx in range(cfg.num_hidden_layers):
            layer_cls = FlaxViTMoELayer if idx == cfg.moe_idx else FlaxViTLayer
            stack.append(layer_cls(cfg, name=str(idx), dtype=self.dtype))
        self.layers = stack

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        return_expert_outs = None,
        return_expert_gate = None,
        flags: dict = None
    ):
        """Apply every layer in order.

        Args:
            hidden_states: Input to the first layer.
            deterministic: Forwarded to every layer.
            output_attentions: Collect the second element of each layer's
                output when truthy.
            output_hidden_states: Collect the input of every layer plus the
                final output when truthy.
            return_dict: Return a ``FlaxBaseModelOutput`` when truthy,
                otherwise a plain tuple without ``None`` entries.
            return_expert_outs: Accepted but not used by this method.
            return_expert_gate: Accepted but not used by this method.
            flags: Optional dict handed only to the layer at ``moe_idx``.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Only the MoE layer understands the ``flags`` keyword.
            extra = {"flags": flags} if idx == self.config.moe_idx else {}
            layer_outputs = layer(
                hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
                **extra,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
            )

        # The collectors are None exactly when they were not requested, so
        # dropping None entries yields only the requested outputs.
        candidates = (hidden_states, all_hidden_states, all_attentions)
        return tuple(item for item in candidates if item is not None)
class FlaxViTMoEEncoder(nn.Module):
    """Encoder wrapper that delegates to a ``FlaxViTMoEBlockCollection``."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the layer collection under the attribute name ``layer``."""
        self.layer = FlaxViTMoEBlockCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        return_expert_outs = None,
        return_expert_gate = None,
        flags: dict = None
    ):
        """Forward the call to the layer collection.

        ``return_expert_outs`` and ``return_expert_gate`` are accepted but
        not forwarded; every other argument is passed through unchanged.
        """
        forwarded = {
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "flags": flags,
        }
        return self.layer(hidden_states, **forwarded)

class FlaxViTMoEModule(nn.Module):
    """ViT backbone (embeddings -> MoE encoder -> LayerNorm -> optional pooler)."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        """Create embeddings, encoder, final layer norm and optional pooler."""
        self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxViTMoEEncoder(self.config, dtype=self.dtype)
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = FlaxViTPooler(self.config, dtype=self.dtype)
        else:
            self.pooler = None

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        return_expert_outs = None,
        return_expert_gate = None,
        flags: dict = None
    ):
        """Encode ``pixel_values``.

        Args:
            pixel_values: Image batch. A rank-4 input whose axis 1 has
                ``config.num_channels`` entries is transposed with the
                permutation (0, 2, 3, 1) before embedding.
            deterministic: Forwarded to the embeddings and the encoder.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Return a ``FlaxBaseModelOutputWithPooling`` when
                truthy, otherwise a tuple.
            return_expert_outs: Accepted but not used by this method.
            return_expert_gate: Accepted but not used by this method.
            flags: Optional dict forwarded to the encoder.

        Returns:
            With ``return_dict``: the model-output object. Otherwise
            ``(hidden_states, [pooled,] *encoder_extras)`` where ``pooled``
            is present only when a pooler exists.
        """
        channels_on_axis_1 = (
            pixel_values.ndim == 4
            and pixel_values.shape[1] == self.config.num_channels
        )
        if channels_on_axis_1:
            pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        embedded = self.embeddings(pixel_values, deterministic=deterministic)
        encoder_out = self.encoder(
            embedded,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flags=flags,
        )
        sequence = self.layernorm(encoder_out[0])

        pooled = None
        if self.add_pooling_layer:
            pooled = self.pooler(sequence)

        if return_dict:
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=sequence,
                pooler_output=pooled,
                hidden_states=encoder_out.hidden_states,
                attentions=encoder_out.attentions,
            )

        # Tuple form: the pooled slot is omitted entirely when there is no pooler.
        head = (sequence,) if pooled is None else (sequence, pooled)
        return head + encoder_out[1:]
class FlaxViTMoEModel(FlaxViTPreTrainedModel):
    """``FlaxViTPreTrainedModel`` subclass whose ``module_class`` is ``FlaxViTMoEModule``."""
    module_class = FlaxViTMoEModule
class FlaxViTMoEForImageClassificationModule(nn.Module):
    """Image-classification head on top of ``FlaxViTMoEModule``.

    The backbone is created without a pooling layer; a single dense layer
    with ``config.num_labels`` outputs is applied to the first token of the
    backbone's final hidden states.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the pooler-less backbone and the linear classifier."""
        self.vit = FlaxViTMoEModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs ,
    ):
        """Classify a batch of images.

        Args:
            pixel_values: Image batch forwarded to the backbone.
            deterministic: Forwarded to the backbone.
            output_attentions: Forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: When ``None``, falls back to
                ``config.use_return_dict``.
            **kwargs: Only the ``"flags"`` entry is read; it is forwarded to
                the backbone (``None`` when absent). Other entries are
                ignored.

        Returns:
            A ``FlaxSequenceClassifierOutput`` when ``return_dict`` is truthy,
            otherwise ``(logits, *backbone_extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            flags = kwargs.get("flags", None),
        )

        hidden_states = outputs[0]
        logits = self.classifier(hidden_states[:, 0, :])  # class token

        if not return_dict:
            # The backbone has no pooler, so its tuple is
            # (last_hidden_state, *extras) with no pooled slot at index 1.
            # Slicing from 1 keeps every requested extra (hidden states
            # and/or attentions); slicing from 2 would drop the first one.
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxViTMoEForImageClassification(FlaxViTPreTrainedModel):
    """``FlaxViTPreTrainedModel`` subclass whose ``module_class`` is ``FlaxViTMoEForImageClassificationModule``."""
    module_class = FlaxViTMoEForImageClassificationModule
