"""
Our code is built on a customized TorchTitan implementation and the official MegaBlocks library.
The core implementation is in `ParallelDroplessMoEMLP.forward` (lines 86-91 of the original file).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from typing import List

import stk

import megablocks.ops as ops


from megablocks import Arguments
from megablocks.layers import sharedexpert_registry
from megablocks.layers.dmoe import common, dMoE, ParallelDroplessMLP
from megablocks.layers.router import LearnedRouter


from torchtitan.protocols.train_spec import ModelProtocol
from torchtitan.experiments.llama4 import TransformerModelArgs, AttentionKVCache
from torchtitan.experiments.llama4 import Attention, FeedForward, precompute_freqs_cis



class ParallelDroplessMoEMLP(ParallelDroplessMLP):
    """Dropless expert MLP whose routing is driven by the experts' own activations.

    For every token, the first ``router_neurons`` rows of each expert's
    ``w1``/``v1`` are evaluated as ``silu(x @ w1.T) * (x @ v1.T)``. The L2 norm
    of each expert's partial activation is that expert's routing logit, and the
    top-k experts then run the full sparse MLP via MegaBlocks kernels. The
    partial activations of *all* experts are also projected through the
    matching rows of ``w2`` and added to the output.

    Optional attributes read from ``args`` (defaults reproduce the original
    hard-coded values):
        router_neurons (int): neurons per expert used for routing (default 8).
        aux_loss_coef (float): load-balancing loss coefficient (default 0.01).
    """

    DEFAULT_ROUTER_NEURONS = 8
    DEFAULT_AUX_LOSS_COEF = 0.01

    def __init__(self, args):
        super().__init__(args)
        self.args = args

        self.aux_loss_coef = getattr(args, "aux_loss_coef", self.DEFAULT_AUX_LOSS_COEF)
        self.router_neurons = getattr(args, "router_neurons", self.DEFAULT_ROUTER_NEURONS)
        # Routing neurons are taken from inside each expert's block of
        # ffn_hidden_size rows; a larger count would read another expert's rows
        # (or index past the end of the weight for the last expert).
        if not 0 < self.router_neurons <= args.ffn_hidden_size:
            raise ValueError(
                f"router_neurons must be in [1, ffn_hidden_size={args.ffn_hidden_size}], "
                f"got {self.router_neurons}"
            )
        # UoE has no shared expert
        self.shared_expert = None

    def sparse_forward_once(self, x, expert_weights, top_experts):
        """Run the selected experts on their routed tokens with MegaBlocks kernels.

        Args:
            x: token activations; flattened here to ``[num_tokens, hs]``
                (``forward`` already passes ``[sl * bs, hs]``).
            expert_weights: ``[sl * bs, top_k]`` combine weights.
            top_experts: ``[sl * bs, top_k]`` selected expert ids.

        Returns:
            The weighted per-token expert outputs produced by
            ``ops.padded_scatter``.
        """
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()

        # Routing metadata is index bookkeeping only; no gradients needed.
        with torch.no_grad():
            indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])

        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            self.top_k,
        )

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
        )
        return x

    def forward(self, x):
        """Route tokens by expert activation norms and combine expert outputs.

        Args:
            x: ``[sl, bs, hs]`` activations (unpacked below as ``slen, bs, _``).

        Returns:
            Tuple ``(out, bl_loss)``. ``out`` has the shape of ``x``, and
            ``bl_loss`` is the scalar load-balancing loss.
        """
        args = self.args

        in_shape = x.shape
        slen, bs, _ = in_shape

        x = x.view(-1, x.shape[-1])

        # Row indices of the first `router_neurons` neurons of every expert,
        # laid out expert-major: [e * ffn_hidden_size + j for e, j].
        # NOTE(review): assumes w1/v1/w2 hold moe_num_experts * ffn_hidden_size
        # unsharded rows, expert by expert (no expert/model parallelism) — confirm.
        # Rebuilt on each call (not cached as a buffer) so it always lives on
        # the weights' current device.
        device = self.mlp.w1.device
        expert_offsets = torch.arange(args.moe_num_experts, device=device)[:, None] * args.ffn_hidden_size
        neuron_offsets = torch.arange(self.router_neurons, device=device)[None, :]
        top_neuron = (expert_offsets + neuron_offsets).view(-1)

        w1_sparse, v1_sparse = self.mlp.w1[top_neuron], self.mlp.v1[top_neuron]
        w2_sparse = self.mlp.w2[top_neuron]
        # [T, E * router_neurons] partial GLU activations of every expert.
        expert_acts = F.silu(torch.mm(x, w1_sparse.T)) * torch.mm(x, v1_sparse.T)
        # [T, hs] contribution of the routing neurons of all experts.
        shared_out = torch.mm(expert_acts, w2_sparse)

        expert_acts = expert_acts.view(x.shape[0], args.moe_num_experts, -1)
        # for post-norm setting: the routing logit is the L2 norm of each
        # expert's partial activation, giving a [T, E] tensor.
        logits = torch.norm(expert_acts, p=2, dim=-1)

        expert_weights, top_experts = torch.topk(logits, k=args.moe_top_k, dim=-1)
        expert_weights = expert_weights.softmax(-1, dtype=torch.float32)

        out = self.sparse_forward_once(x, expert_weights, top_experts)
        bl_loss = self.load_balancing_loss_func(top_experts, logits, bs, slen)

        # Out-of-place add: avoid mutating the custom kernel's output in place.
        out = out + shared_out

        return out.view(in_shape), bl_loss

    def load_balancing_loss_func(self, top_experts, scores, bs, slen):
        """Per-sequence load-balancing loss ``aux_loss_coef * mean_b(sum_e f_e * p_e)``.

        Args:
            top_experts: ``[slen * bs, top_k]`` selected expert ids; rows are
                ordered ``s * bs + b``.
            scores: ``[slen * bs, num_experts]`` routing logits (softmaxed here).
            bs: batch size.
            slen: sequence length.

        Returns:
            Scalar tensor with the weighted balance loss.
        """
        num_experts = self.args.moe_num_experts

        # [T, E]: how many of each token's top-k picks landed on each expert.
        one_hot_topk = F.one_hot(top_experts, num_classes=num_experts).sum(dim=1).float()
        one_hot_topk = one_hot_topk.view(slen, bs, num_experts)
        # f_e: routed fraction per sequence, scaled so perfectly uniform routing gives 1.
        fi = one_hot_topk.sum(dim=0) * num_experts / (self.top_k * slen)  # (bs, num_experts)

        scores_normalized = scores.softmax(dim=-1, dtype=torch.float32)

        # p_e: mean routing probability per sequence.
        pi = scores_normalized.view(slen, bs, num_experts).sum(dim=0) / slen
        bl_loss = self.aux_loss_coef * (fi * pi).sum(dim=-1).mean()

        return bl_loss


class dAoE(dMoE):
    """``dMoE`` variant whose experts route tokens themselves.

    The learned router built by ``dMoE`` is deleted; routing happens inside the
    experts module returned by ``_init_experts_mlp``. The layer also carries the
    buffers for TorchTitan-style auxiliary-loss-free load balancing.

    NOTE(review): nothing in this file updates ``tokens_per_expert``, and
    ``expert_bias`` is not read when routing, so the backward-hook bias update
    is currently a no-op — confirm whether that is intended.
    """

    def __init__(self, args):
        super().__init__(args)
        self.args = args
        # Experts compute their own routing logits, so the learned router is unused.
        del self.router

        # titan implementation of auxiliary-loss-free load balancing
        self.load_balance_coeff = args.load_balance_coeff

        # Registered even when load_balance_coeff is None, so initialization and
        # checkpointing code can treat every layer the same way.
        for buffer_name in ("expert_bias", "tokens_per_expert"):
            self.register_buffer(
                buffer_name,
                torch.zeros(args.moe_num_experts, dtype=torch.float32),
                persistent=True,
            )

        # NOTE: forward hooks, forward pre-hooks and backward pre-hooks would
        #       conflict with activation checkpointing, so a full backward hook is used.
        coeff = self.load_balance_coeff
        if coeff is not None and coeff > 0:
            print(f'registed backward hook for moe expert bias: {self.load_balance_coeff}')
            self.register_full_backward_hook(self._update_expert_bias)

    def _init_experts_mlp(self, args: Arguments):
        """Build the self-routing experts module used by this layer."""
        return ParallelDroplessMoEMLP(args)

    def forward(self, x: torch.Tensor):
        """Apply the experts to ``x`` laid out as ``[bs, sl, hs]``.

        The experts expect sequence-first input, so ``x`` is transposed to
        ``[sl, bs, hs]`` and the result is transposed back.

        Returns:
            Tuple ``(out, bl_loss)`` with ``out`` in the same layout as ``x``.
        """
        x = common.cast_if_autocast_enabled(x)

        seq_first = x.transpose(0, 1).contiguous()
        out, bl_loss = self.experts(seq_first)
        batch_first = out.view(seq_first.shape).transpose(0, 1).contiguous()

        return batch_first, bl_loss

    def _update_expert_bias(self, *_):
        """Backward hook: nudge ``expert_bias`` toward balanced expert load.

        Under-loaded experts (below the mean token count) move up by
        ``load_balance_coeff`` and over-loaded ones move down. The step is
        re-centred to zero mean, and the token counters are then cleared.
        """
        load = self.tokens_per_expert
        step = self.load_balance_coeff * torch.sign(load.mean() - load)
        centred_step = step - step.mean()
        self.expert_bias.add_(centred_step)
        self.tokens_per_expert.zero_()


class AoEBlock(nn.Module):
    """Transformer block: pre-norm attention followed by a pre-norm ``dAoE`` layer.

    Args:
        layer_id: Index of this block; scales the std of the expert init.
        model_args: Model configuration arguments.
    """

    def __init__(self, layer_id: int, model_args: TransformerModelArgs):
        super().__init__()
        self.moe_enabled = True

        self.n_heads = model_args.n_heads

        attn_use_rope = True
        fixed_attn_block_size = None
        self.attention = Attention(model_args, attn_use_rope, fixed_attn_block_size)

        # TODO: MegaBlocks requires ffn_hidden_size to be divisible by 128, so
        # fall back to 512 when the config does not specify ffn_hidden_size.
        hidden_dim = 512

        args = Arguments(
            # Model arguments.
            hidden_size=model_args.dim,
            ffn_hidden_size=model_args.ffn_hidden_size if model_args.ffn_hidden_size is not None else hidden_dim,
            bias=False,
            return_bias=False,
            activation_fn=F.silu,   # swiglu
            # MoE arguments.
            # If expert_capacity is set to zero, set the number of tokens
            # per expert to the maximum we need to avoid dropping tokens.
            moe_num_experts=model_args.num_experts,
            moe_top_k=model_args.top_k,
            moe_capacity_factor=0,
            # Parallelism arguments: MegaBlocks defaults.
            # Compute arguments.
            mlp_type='glu',
            mlp_impl='sparse',
            # Initialization arguments. ~ FSDP
            fp16=False,
            bf16=False,
            init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.02 / (2 * (layer_id + 1)) ** 0.5),
            # shared expert arguments
            shared_expert=False,  # enable using shared expert
        )
        # Not a MegaBlocks Arguments field; consumed by dAoE.
        args.load_balance_coeff = model_args.load_balance_coeff

        self.moe = dAoE(args)

        self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)

        self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_cache,
    ):
        """
        Perform a forward pass through the block.

        Args:
            x (torch.Tensor): Input tensor.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            kv_cache: Attention KV cache for this layer, or None; passed through
                to ``Attention``.

        Returns:
            tuple: ``(out, bl_loss)``. ``out`` is the output after the residual
            attention and MoE sub-layers; ``bl_loss`` is the MoE load-balancing loss.

        """
        h = x + self.attention(self.attention_norm(x), freqs_cis, kv_cache)

        out, bl_loss = self.moe(self.ffn_norm(h))
        out = h + out

        return out, bl_loss

    def init_weights(self, buffer_device: torch.device):
        """(Re)initialize norms, attention, and the MoE load-balancing buffers.

        Args:
            buffer_device: Accepted for interface compatibility; buffers are
                reset in place on their current device.
        """
        for norm in (self.attention_norm, self.ffn_norm):
            norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        # Expert weights are initialized by MegaBlocks through ``init_method``
        # when the layer is constructed; they are not re-drawn here.
        # The dAoE buffers are only zero-filled at construction, so reset them
        # explicitly; otherwise a module materialized via ``to_empty`` would
        # keep uninitialized memory in them.
        with torch.no_grad():
            self.moe.expert_bias.zero_()
            self.moe.tokens_per_expert.zero_()


class TransformerAoE2Shared(nn.Module, ModelProtocol):
    """
    Decoder-only Transformer whose feed-forward sub-layers are ``AoEBlock`` MoE layers.

    Args:
        model_args (TransformerModelArgs): Model configuration arguments.

    Attributes:
        model_args (TransformerModelArgs): Model configuration arguments.
        vocab_size (int): Vocabulary size.
        n_layers (int): Number of layers in the model.
        eos_id: End-of-sequence token id copied from ``model_args``.
        tok_embeddings (nn.Embedding): Token embeddings.
        layers (torch.nn.ModuleDict): ``AoEBlock`` layers keyed by ``str(layer_id)``.
        norm (nn.RMSNorm): Normalization applied to the final hidden states.
        output (nn.Linear): Projection from hidden states to vocabulary logits.
        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.

    """

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()

        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers
        self.eos_id = model_args.eos_id

        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)

        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = AoEBlock(layer_id, model_args)
        self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.init_weights()

    def init_weights(
        self,
        buffer_device: torch.device | None = None,
    ):
        """Recompute ``freqs_cis`` and reinitialize every submodule that is present.

        Args:
            buffer_device: Device for the recomputed ``freqs_cis``; defaults to
                the current device of ``freqs_cis``.
        """
        buffer_device = buffer_device or self.freqs_cis.device
        with torch.device(buffer_device):
            self.freqs_cis = self._precompute_freqs_cis()
        if self.tok_embeddings is not None:
            nn.init.normal_(self.tok_embeddings.weight)
        for layer in self.layers.values():
            if layer is not None:
                layer.init_weights(buffer_device=buffer_device)
        if self.norm is not None:
            self.norm.reset_parameters()
        # Output projection: truncated normal with std dim**-0.5, cut at +/-3 std.
        final_out_std = self.model_args.dim**-0.5
        cutoff_factor = 3
        if self.output is not None:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=final_out_std,
                a=-cutoff_factor * final_out_std,
                b=cutoff_factor * final_out_std,
            )

    def _precompute_freqs_cis(self) -> torch.Tensor:
        """Return rotary frequencies for head dim ``dim // n_heads`` up to ``max_seq_len``."""
        return precompute_freqs_cis(
            self.model_args.dim // self.model_args.n_heads,
            # Need to compute until at least the max token limit for generation
            # TODO: explain in docs/composability.md why we removed the 2x
            # relaxing in our CP enablement PR
            self.model_args.max_seq_len,
            self.model_args.rope_theta,
        )

    def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None, kv_caches: List[AttentionKVCache] | None = None):
        """Run embeddings, all present layers, final norm, and output projection.

        Args:
            tokens: Token ids, or hidden states when ``tok_embeddings`` is None
                (e.g. a later pipeline-parallel stage).
            input_batch: Unused; accepted for interface compatibility.
            kv_caches: Optional per-layer KV caches, indexed by layer position.

        Returns:
            Tuple ``(output, bl_loss)``. ``output`` is the logits, or hidden
            states when ``output`` is None. ``bl_loss`` is the sum of per-layer
            load-balancing losses divided by ``n_layers``.
        """
        # Submodules set to None are skipped (passthrough), which allows easy
        # configuration of pipeline-parallel stages.
        h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

        all_bl_loss = 0
        for layer_idx, layer in enumerate(self.layers.values()):
            kv_cache = None if kv_caches is None else kv_caches[layer_idx]
            h, bl_loss = layer(h, self.freqs_cis, kv_cache)
            all_bl_loss += bl_loss

        if self.norm is not None:
            h = self.norm(h)
        output = self.output(h) if self.output is not None else h
        return output, all_bl_loss / self.n_layers

    @classmethod
    def from_model_args(cls, model_args: TransformerModelArgs) -> "TransformerAoE2Shared":
        """
        Initialize a model from a TransformerModelArgs object.

        Args:
            model_args (TransformerModelArgs): Model configuration arguments.

        Returns:
            TransformerAoE2Shared: The constructed model.

        """
        return cls(model_args)


        




