import torch
from megablocks.layers import router
from torch import nn
from rotary_embedding_torch import RotaryEmbedding
from smoe.kernels.ops import flatten_and_sort, padded_block_indices
from smoe.parallel_experts import ParallelExperts
from flash_attn import flash_attn_func
from megablocks.layers.arguments import Arguments


class MoA(torch.nn.Module):
    """Mixture-of-attention layer with routed query and output projections.

    A learned router picks experts for every token. The query projection and
    the output projection are per-expert (``ParallelExperts``), while the key
    and value projections are ordinary ``nn.Linear`` layers shared by all
    experts. Attention itself is computed with ``flash_attn_func`` using a
    causal mask, and rotary position embeddings are applied to queries and
    keys.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        num_heads: Number of attention heads per selected expert.
        head_size: Dimensionality of each attention head.
        num_experts: Number of experts in the query/output projections.
        top_k: Value passed to the router as ``moe_top_k``; also the number of
            query head groups produced per token.
        attn_dropout: Attention dropout probability, applied inside
            ``flash_attn_func`` while the module is in training mode.
    """

    def __init__(self, hidden_size: int = 1024, num_heads: int = 4, head_size: int = 128,
                 num_experts: int = 8, top_k: int = 2, attn_dropout: float = 0.) -> None:
        super().__init__()

        # Layer dimensions.
        self.att_hidden_size = num_heads * head_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_experts = num_experts
        self.top_k = top_k
        self.head_size = head_size

        # Token router.
        self.router = router.LearnedRouter(
            Arguments(hidden_size=hidden_size, moe_num_experts=num_experts, moe_top_k=self.top_k))

        # Per-expert query/output projections; key/value projections are shared.
        self._q_proj = ParallelExperts(num_experts, hidden_size, self.att_hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.att_hidden_size)
        self.v_proj = nn.Linear(hidden_size, self.att_hidden_size)
        self._out_proj = ParallelExperts(num_experts, self.att_hidden_size, hidden_size)

        # Regularization. The module only stores the probability; the dropout
        # itself is performed by the attention kernel in ``forward``.
        self.attn_dropout = nn.Dropout(attn_dropout)

        # Rotary position embedding over half of each head's dimensions.
        self.rotary_embed = RotaryEmbedding(self.head_size // 2)
        # Replace the library's ``freqs`` attribute with a buffer holding the
        # same values (presumably so it is not treated as a trainable
        # parameter -- confirm against the installed rotary_embedding_torch).
        rope_freqs = self.rotary_embed.freqs.data
        del self.rotary_embed.freqs
        self.rotary_embed.register_buffer("freqs", rope_freqs)

    def q_proj(self, x, top_experts):
        """Apply the per-expert query projection.

        Args:
            x: Input tensor; flattened to ``(-1, x.size(-1))`` before the
                projection.
            top_experts: Expert assignments returned by the router.

        Returns:
            A tuple ``(out, sorted_expert_idxs, sorted_scattered_idxs,
            padded_block_idxs, expert_offsets)``. The four index tensors are
            returned so that ``out_proj`` can reuse them.
        """
        # Index bookkeeping carries no gradient.
        with torch.no_grad():
            sorted_expert_idxs, sorted_scattered_idxs = flatten_and_sort(top_experts)
            padded_block_idxs, expert_offsets = padded_block_indices(sorted_expert_idxs)
        x = x.view(-1, x.size(-1))
        out = self._q_proj(x, self.top_k, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets)
        return out, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets

    def out_proj(self, y, expert_p, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets):
        """Apply the per-expert output projection, weighted by ``expert_p``.

        Args:
            y: Attention output; flattened to ``(-1, att_hidden_size)``.
            expert_p: Router weights, forwarded to the experts as ``gates``.
            sorted_expert_idxs: Index tensor produced by ``q_proj``.
            sorted_scattered_idxs: Index tensor produced by ``q_proj``.
            padded_block_idxs: Index tensor produced by ``q_proj``.
            expert_offsets: Index tensor produced by ``q_proj``.

        Returns:
            The output of the expert projection.
        """
        y = y.view(-1, self.att_hidden_size)
        out = self._out_proj(
            y, 1, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets,
            gates=expert_p
        )
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run routed causal self-attention over ``x``.

        Args:
            x: Input of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality

        # The first router output (scores) is not needed here.
        _, expert_weights, top_experts = self.router(x)

        q, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets = self.q_proj(x, top_experts)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # NOTE: keys are computed from ``x`` only, so context_length == T today;
        # the offset/window arithmetic below only matters if cached keys are
        # ever prepended.
        context_length = k.size(1)
        q = q.view(B, T, self.top_k * self.num_heads, self.head_size)  # (B, T, k * nh, hs)
        k = k.view(B, context_length, self.num_heads, self.head_size)  # (B, T, nh, hs)
        v = v.view(B, context_length, self.num_heads, self.head_size)  # (B, T, nh, hs)

        # Tile the shared key/value heads so every query head has a partner.
        k = k.repeat(1, 1, self.top_k, 1)  # (B, T, k * nh, hs)
        v = v.repeat(1, 1, self.top_k, 1)  # (B, T, k * nh, hs)

        # Rotary embedding is applied with the sequence on dim -2, so move the
        # head dimension forward first and restore the layout afterwards.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        q = self.rotary_embed.rotate_queries_or_keys(q, seq_dim=-2, offset=context_length - T)
        k = self.rotary_embed.rotate_queries_or_keys(k, seq_dim=-2)
        q = q.permute(0, 2, 1, 3).contiguous()
        k = k.permute(0, 2, 1, 3)

        # Attention dropout must only be active in training mode.
        dropout_p = self.attn_dropout.p if self.training else 0.0
        left_window = context_length - T if context_length > T else -1
        y = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True, window_size=(left_window, -1))

        # Output projection.
        y = self.out_proj(y, expert_weights, sorted_expert_idxs, sorted_scattered_idxs, padded_block_idxs, expert_offsets)
        y = y.view(B, T, C)  # re-assemble all head outputs side by side
        return y
