import jax.numpy as jnp
from flax import linen as nn
from typing import Optional
import math
import pickle 
from dataclasses import dataclass
from typing import Optional, Callable
from flax.linen.initializers import zeros
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

@dataclass
class ModelConfig:
    """Hyper-parameters read by the Flax encoder / classifier modules in this file.

    The MoE-related fields are only read by the ``*MoE`` modules. With the default
    ``moe_idx=-1``, ``FlaxEncoderMoE`` builds no MoE block, because it compares
    ``moe_idx`` with block indices ``0 .. N_encoder - 1`` using ``==``.
    """
    encoder_vocab_size: int  # number of token ids accepted by the token embedding table
    d_embed: int  # model width (embedding / attention / residual-stream size)
    d_ff: int  # hidden width of a single feed-forward network / expert
    h: int  # number of attention heads; the attention module asserts d_embed % h == 0
    N_encoder: int  # number of encoder blocks
    max_seq_len: int  # length of the learned positional-embedding table
    dropout: float  # dropout rate passed to every Dropout layer in the encoder
    num_experts: int = 0  # must equal num_shared_experts + num_gated_experts (asserted by FlaxFeedForwardMoE)
    moe_idx: int = -1  # index of the encoder block that FlaxEncoderMoE builds as an MoE block
    num_shared_experts: int = 0  # always-on experts, fused into one FFN of hidden width d_ff * num_shared_experts
    num_gated_experts: int = 0  # experts chosen per token by the softmax gate
    topk: int = 0  # how many gated experts each token keeps (see FlaxFeedForwardMoE)
def print_model(flax_params, file=None):
    """Print one line per parameter leaf: its ``/``-joined path followed by its shape.

    Args:
        flax_params: nested dict / FrozenDict of parameters, as accepted by ``flatten_dict``.
        file: optional text stream to write to; ``None`` (the default) writes to stdout.
    """
    for path, value in flatten_dict(flax_params).items():
        # str() every key so non-string keys (e.g. integer indices) don't make join() raise.
        name = "/".join(map(str, path))
        # jnp.shape also handles scalar leaves that have no ``.shape`` attribute.
        # print(file=None) already falls back to sys.stdout, so no branch is needed.
        print(f"{name} {jnp.shape(value)}", file=file)
def load_combined_model(path):
    """Load a pickled model bundle and return its ``(flax_params, config)`` entries.

    The file must unpickle to a mapping with ``"flax_params"`` and ``"config"`` keys;
    other keys are ignored. The config entry is presumably a plain dict of
    ``ModelConfig`` fields; confirm this against whatever wrote the bundle.

    SECURITY: ``pickle.load`` can execute arbitrary code while unpickling. Only call
    this on files from a trusted source.
    """
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    return payload["flax_params"], payload["config"]
class FlaxMultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with Q/K/V and output projections.

    Attributes:
        h: number of attention heads; ``d_embed`` must be divisible by it.
        d_embed: model width; each head works on ``d_embed // h`` features.
        dropout_rate: dropout rate applied to the attention probabilities.
    """
    h: int
    d_embed: int
    dropout_rate: float = 0.0

    def setup(self):
        assert self.d_embed % self.h == 0, "d_embed must be divisible by h"
        self.d_k = self.d_embed // self.h
        self.WQ = nn.Dense(self.d_embed)
        self.WK = nn.Dense(self.d_embed)
        self.WV = nn.Dense(self.d_embed)
        self.out_proj = nn.Dense(self.d_embed)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x_query, x_key, x_value, mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        """Attend from ``x_query`` to ``x_key`` / ``x_value``.

        Args:
            x_query, x_key, x_value: ``[batch, seq_len, features]`` inputs. Key and value
                must share a sequence length; the query's may differ.
            mask: optional array broadcastable to the score shape ``[batch, h, q_len, k_len]``;
                positions where it is falsy are excluded from attention.
            deterministic: when True, attention dropout is disabled.

        Returns:
            ``[batch, q_len, d_embed]`` output of the final projection.
        """
        nbatch = x_query.shape[0]

        def split_heads(x):
            # [batch, seq_len, d_embed] -> [batch, h, seq_len, d_k]
            return x.reshape(nbatch, -1, self.h, self.d_k).transpose(0, 2, 1, 3)

        query = split_heads(self.WQ(x_query))
        key = split_heads(self.WK(x_key))
        value = split_heads(self.WV(x_value))

        # Scaled dot-product attention scores: [batch, h, q_len, k_len]
        scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative *finite* value rather than -inf. A row whose keys are
            # all masked then yields a uniform distribution instead of NaN (softmax over
            # all -inf computes (-inf) - (-inf) = NaN). Rows with at least one visible key
            # are unaffected, because exp(finfo.min - row_max) underflows to exactly 0.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
        p_attn = nn.softmax(scores, axis=-1)
        p_attn = self.dropout(p_attn, deterministic=deterministic)

        x = jnp.matmul(p_attn, value)  # [batch, h, q_len, d_k]
        x = x.transpose(0, 2, 1, 3).reshape(nbatch, -1, self.d_embed)  # merge heads
        return self.out_proj(x)
class FlaxResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(LayerNorm(x)))``.

    Attributes:
        dim: feature size. Kept for interface compatibility but not read here, because
            ``nn.LayerNorm`` infers its size from the input.
        dropout_rate: dropout rate applied to the sublayer's output before the residual add.
    """
    dim: int
    dropout_rate: float = 0.0

    def setup(self):
        self.norm = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self.dropout_rate)

    def __call__(self, x, sublayer: Callable[[jnp.ndarray], jnp.ndarray], deterministic: bool = True):
        branch = sublayer(self.norm(x))
        branch = self.dropout(branch, deterministic=deterministic)
        return x + branch
    
class FlaxFeedForward(nn.Module):
    """Position-wise feed-forward network: Dense(d_ff) -> ReLU -> dropout -> Dense(d_embed).

    Attributes:
        d_embed: output width (the model width).
        d_ff: hidden width.
        dropout: dropout *rate* applied to the hidden activations (a float, not a layer).
    """
    d_embed: int
    d_ff: int
    dropout: float

    def setup(self):
        self.linear1 = nn.Dense(self.d_ff)
        self.linear2 = nn.Dense(self.d_embed)
        self.drop = nn.Dropout(self.dropout)

    def __call__(self, x, deterministic=True):
        hidden = nn.relu(self.linear1(x))
        hidden = self.drop(hidden, deterministic=deterministic)
        return self.linear2(hidden)

class FlaxEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention and a dense FFN, each in a residual.

    ``config`` must provide ``h``, ``d_embed``, ``d_ff`` and ``dropout``.
    """
    config: object

    def setup(self):
        cfg = self.config
        self.attention = FlaxMultiHeadedAttention(h=cfg.h, d_embed=cfg.d_embed, dropout_rate=cfg.dropout)
        self.feed_forward = FlaxFeedForward(d_embed=cfg.d_embed, d_ff=cfg.d_ff, dropout=cfg.dropout)
        self.residual1 = FlaxResidualConnection(dim=cfg.d_embed, dropout_rate=cfg.dropout)
        self.residual2 = FlaxResidualConnection(dim=cfg.d_embed, dropout_rate=cfg.dropout)

    def __call__(self, x, mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        def attend(y):
            # Self-attention: the same normalised input serves as query, key and value.
            return self.attention(y, y, y, mask=mask, deterministic=deterministic)

        def feed(y):
            return self.feed_forward(y, deterministic=deterministic)

        x = self.residual1(x, attend, deterministic=deterministic)
        return self.residual2(x, feed, deterministic=deterministic)
class FlaxEncoder(nn.Module):
    """Token + learned positional embeddings, ``N_encoder`` encoder blocks, then a final LayerNorm.

    ``config`` must provide ``d_embed``, ``encoder_vocab_size``, ``max_seq_len``, ``dropout``
    and ``N_encoder``, plus the fields ``FlaxEncoderBlock`` reads.
    """
    config: object

    def setup(self):
        cfg = self.config
        self.d_embed = cfg.d_embed
        self.tok_embed = nn.Embed(num_embeddings=cfg.encoder_vocab_size, features=cfg.d_embed)
        # Zero-initialised learned positional table; sliced to the actual sequence length below.
        self.pos_embed = self.param("pos_embed", zeros, (1, cfg.max_seq_len, cfg.d_embed))
        self.encoder_blocks = [FlaxEncoderBlock(cfg) for _ in range(cfg.N_encoder)]
        self.dropout = nn.Dropout(rate=cfg.dropout)
        self.norm = nn.LayerNorm()

    def __call__(self, input_ids, mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        """Encode ``input_ids`` (``[batch, seq_len]``) into ``[batch, seq_len, d_embed]``."""
        embedded = self.tok_embed(input_ids)
        seq_len = embedded.shape[1]
        positions = self.pos_embed[:, :seq_len, :]  # [1, seq_len, d_embed], broadcast over batch
        hidden = self.dropout(embedded + positions, deterministic=deterministic)
        for layer in self.encoder_blocks:
            hidden = layer(hidden, mask=mask, deterministic=deterministic)
        return self.norm(hidden)
    
class FlaxTransformer(nn.Module):
    """Sequence classifier: ``FlaxEncoder`` followed by mean pooling and a Dense head.

    Attributes:
        config: encoder configuration (see ``FlaxEncoder``).
        num_classes: number of output logits.
    """
    config: object
    num_classes: int

    def setup(self):
        self.encoder = FlaxEncoder(self.config)
        self.classifier = nn.Dense(self.num_classes)

    def __call__(self, x, pad_mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        """Map token ids ``[batch, seq_len]`` to logits ``[batch, num_classes]``."""
        encoded = self.encoder(x, mask=pad_mask, deterministic=deterministic)  # [batch, seq_len, d_embed]
        # NOTE(review): the mean is unmasked, so padded positions are averaged in too;
        # confirm this matches the reference model the weights come from.
        pooled = encoded.mean(axis=1)
        return self.classifier(pooled)
    
class FlaxFeedForwardMoE(nn.Module):
    """Mixture-of-experts feed-forward layer: a shared FFN plus softmax-gated top-k experts.

    ``config`` must provide ``d_embed``, ``d_ff``, ``dropout``, ``num_experts``,
    ``num_shared_experts``, ``num_gated_experts`` and ``topk``.

    * All shared experts are fused into one ``FlaxFeedForward`` of hidden width
      ``d_ff * num_shared_experts``, applied to every token.
    * A linear gate gives a softmax over the gated experts. For each token, only the
      ``topk`` highest-scoring experts keep their softmax weight (the kept weights are
      not renormalised). The gated output is the weighted sum of those experts' outputs.
    * Every gated expert is evaluated on every token; the gate only selects which
      outputs contribute.
    """
    config: object

    def setup(self):
        assert self.config.num_experts == self.config.num_shared_experts + self.config.num_gated_experts, \
            "num_experts must be num_shared_experts + num_gated_experts"
        if self.config.num_shared_experts > 0:
            self.shared_expert = FlaxFeedForward(
                d_embed=self.config.d_embed,
                d_ff=self.config.d_ff * self.config.num_shared_experts,
                dropout=self.config.dropout,
            )
        else:
            self.shared_expert = None
        self.gated_experts = [
            FlaxFeedForward(
                d_embed=self.config.d_embed,
                d_ff=self.config.d_ff,
                dropout=self.config.dropout,
            )
            for _ in range(self.config.num_gated_experts)
        ]
        self.gate = nn.Dense(self.config.num_gated_experts)
        # NOTE(review): not used in __call__; kept so the module's attributes are unchanged.
        self.dropout_layer = nn.Dropout(self.config.dropout)

    def __call__(self, x, deterministic=True):
        """Apply the MoE FFN to ``x`` of shape ``[batch, seq_len, d_embed]``; returns the same shape."""
        B, T, _ = x.shape
        E = self.config.num_gated_experts
        K = self.config.topk

        # ----- Shared FFN (always on) -----
        if self.shared_expert is not None:
            shared_out = self.shared_expert(x, deterministic=deterministic)  # [B, T, D]
        else:
            shared_out = 0

        if E == 0:
            # Only shared experts are configured: there is nothing to route, and
            # jnp.stack below would fail on an empty list of expert outputs.
            return jnp.zeros_like(x) if self.shared_expert is None else shared_out

        # ----- Gating -----
        logits = self.gate(x)                 # [B, T, E]
        scores = nn.softmax(logits, axis=-1)  # softmax over experts
        # Indices of the K largest scores per token. NOTE: when K == 0 the slice `-0:`
        # equals `0:`, and when K > E it also covers everything, so all experts are kept.
        topk_idx = jnp.argsort(scores, axis=-1)[..., -K:]  # [B, T, K]
        topk_mask = jnp.zeros_like(scores).at[
            jnp.arange(B)[:, None, None], jnp.arange(T)[None, :, None], topk_idx
        ].set(1.0)
        gated_weights = scores * topk_mask  # [B, T, E]; unselected experts get weight 0

        # ----- Expert outputs -----
        expert_outputs = [expert(x, deterministic=deterministic) for expert in self.gated_experts]
        expert_stack = jnp.stack(expert_outputs, axis=-1)  # [B, T, D, E]
        # Weighted sum over the expert axis. Both operands must use the SAME label ("e"):
        # einsum labels are case-sensitive, and two distinct labels (e.g. "e" and "E")
        # would each be summed independently, giving sum(weights) * sum(all outputs)
        # instead of the weighted mixture.
        gated_out = jnp.einsum("bte,btde->btd", gated_weights, expert_stack)

        # ----- Final output -----
        return shared_out + gated_out
class FlaxEncoderBlockMoE(nn.Module):
    """Pre-norm encoder block whose feed-forward sublayer is a ``FlaxFeedForwardMoE``.

    ``config`` must provide ``h``, ``d_embed`` and ``dropout``, plus the MoE fields read by
    ``FlaxFeedForwardMoE``.
    """
    config: object

    def setup(self):
        cfg = self.config
        self.attention = FlaxMultiHeadedAttention(h=cfg.h, d_embed=cfg.d_embed, dropout_rate=cfg.dropout)
        self.feed_forward = FlaxFeedForwardMoE(config=cfg)
        self.residual1 = FlaxResidualConnection(dim=cfg.d_embed, dropout_rate=cfg.dropout)
        self.residual2 = FlaxResidualConnection(dim=cfg.d_embed, dropout_rate=cfg.dropout)

    def __call__(self, x, mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        def attend(y):
            # Self-attention: query, key and value are all the normalised input.
            return self.attention(y, y, y, mask=mask, deterministic=deterministic)

        def mixture(y):
            return self.feed_forward(y, deterministic=deterministic)

        x = self.residual1(x, attend, deterministic=deterministic)
        return self.residual2(x, mixture, deterministic=deterministic)
class FlaxEncoderMoE(nn.Module):
    """Encoder with the same layout as ``FlaxEncoder``, but block ``config.moe_idx`` uses an MoE FFN.

    Its block list is stored as ``blocks`` (not ``encoder_blocks``), so its parameter
    paths differ from ``FlaxEncoder``'s.
    """
    config: object

    def setup(self):
        cfg = self.config
        self.d_embed = cfg.d_embed
        self.tok_embed = nn.Embed(num_embeddings=cfg.encoder_vocab_size, features=cfg.d_embed)
        self.pos_embed = self.param("pos_embed", nn.initializers.zeros, (1, cfg.max_seq_len, cfg.d_embed))
        layers = []
        for idx in range(cfg.N_encoder):
            # At most one block (the one at moe_idx) gets the MoE FFN. A moe_idx outside
            # 0..N_encoder-1 (such as -1) yields an all-dense encoder.
            block_cls = FlaxEncoderBlockMoE if idx == cfg.moe_idx else FlaxEncoderBlock
            layers.append(block_cls(cfg))
        self.blocks = layers
        self.dropout = nn.Dropout(rate=cfg.dropout)
        self.norm = nn.LayerNorm()

    def __call__(self, input_ids, mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        """Encode ``input_ids`` (``[batch, seq_len]``) into ``[batch, seq_len, d_embed]``."""
        embedded = self.tok_embed(input_ids)
        seq_len = embedded.shape[1]
        hidden = self.dropout(embedded + self.pos_embed[:, :seq_len, :], deterministic=deterministic)
        for layer in self.blocks:
            hidden = layer(hidden, mask=mask, deterministic=deterministic)
        return self.norm(hidden)
class FlaxTransformerMoE(nn.Module):
    """Sequence classifier: ``FlaxEncoderMoE`` followed by mean pooling and a Dense head.

    Attributes:
        config: encoder configuration, including the MoE fields (see ``FlaxEncoderMoE``).
        num_classes: number of output logits.
    """
    config: object
    num_classes: int

    def setup(self):
        self.encoder = FlaxEncoderMoE(self.config)
        self.classifier = nn.Dense(self.num_classes)

    def __call__(self, x, pad_mask: Optional[jnp.ndarray] = None, deterministic: bool = True):
        """Map token ids ``[batch, seq_len]`` to logits ``[batch, num_classes]``."""
        encoded = self.encoder(x, mask=pad_mask, deterministic=deterministic)  # [batch, seq_len, d_embed]
        # NOTE(review): unmasked mean, so padded positions contribute to the pooled vector.
        pooled = encoded.mean(axis=1)
        return self.classifier(pooled)