import dataclasses
from typing import Callable, Optional, Union
from megablocks.layers import common
from megablocks.layers import moe
from megablocks.layers import dmlp_registry
from megablocks.layers import mpu 
from megablocks.layers import router
import megablocks.ops as ops
import numpy as np
import stk
import torch
import copy
from megablock_moa.mb_linear import GroupedMapReduce
from flash_attn import flash_attn_func
from functools import partial
from torch import nn
from rotary_embedding_torch import RotaryEmbedding
from megablocks.layers.arguments import Arguments as MBArguments

from benchmark_moe import test_strategy
import pickle

from simplemoa import MoA


@dataclasses.dataclass
class Arguments:
    """Hyper-parameters for the mixture-of-attention (MoA) layers in this file.

    ``dMoA`` translates these values into a MegaBlocks ``Arguments`` object:
    ``num_heads * head_size`` becomes the per-expert projection width,
    ``moa_num_experts`` / ``moa_top_k`` become the MoE expert count / top-k.
    """
    # Model arguments.
    hidden_size: int = 1024      # width of the layer input and output
    num_heads: int = 4           # attention heads owned by each expert
    head_size: int = 64          # width of a single attention head
    moa_num_experts: int = 16    # number of attention experts
    moa_top_k: int = 4           # experts selected per token
    attn_dropout: float = 0.     # probability handed to nn.Dropout
    # Annotated so that it is a genuine dataclass field: without the
    # annotation it was only a class attribute and could not be overridden
    # through the constructor (``Arguments(init_method=...)`` raised TypeError).
    init_method: Callable = partial(torch.nn.init.normal_, mean=0.0, std=0.1)


def promote_scalar(x):
    """Return ``x`` as a one-element 1-D view when it has no dimensions.

    Inputs that already have at least one dimension are returned unchanged
    (the very same object, not a copy).
    """
    shape = x.size()
    if shape:
        # Non-empty shape: already at least 1-D.
        return x
    return x.view(1)


class ParallelDroplessMapReduce(moe.ParallelMLP):
    """Expert layer split into a separate "map" and "reduce" half.

    ``map`` sends every token through ``GroupedMapReduce.map`` for each of its
    ``top_k`` selected experts and returns one output per (token, expert)
    pair.  ``reduce`` takes one input per (token, expert) pair, sends it
    through ``GroupedMapReduce.reduce`` and combines the ``top_k`` results of
    each token using the router weights.  The caller is free to run arbitrary
    computation (here: attention) between the two halves.
    """

    def __init__(self, args: MBArguments):
        super().__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.linear = GroupedMapReduce(args)

        # Calculate the number of bits needed to represent the column indices
        # in the intermediate sparse matrix.
        max_column_index = (
            (self.ffn_hidden_size * self.num_experts) // self.blocking)
        # Clamp to 1 before taking the logarithm: when the product above is
        # smaller than one block the quotient is 0, log2(0) is -inf and
        # int(-inf) raises OverflowError.
        self.transpose_sort_end_bit = max(
            int(np.ceil(np.log2(max(max_column_index, 1)))), 1)

    def map(self, x, indices, bin_ids, bins, tokens_per_expert):
        """Apply the "map" projection of every selected expert to ``x``.

        Args:
            x: 3-D activations.  The leading two axes are only used to
                restore the output shape; the last axis is the feature axis.
            indices, bin_ids, bins, tokens_per_expert: routing metadata as
                produced by ``indices_and_bins`` for the flattened top-k
                expert assignment.

        Returns:
            Tensor viewed as ``(x.shape[0], x.shape[1], top_k,
            ffn_hidden_size)``.
        """
        # The 3-way unpack doubles as a rank check on ``x``.
        sl, bs, _ = x.shape
        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            None,  # no weighting on the way out of the map stage
            bins,
            -1,  # unused
            self.top_k,
            map=True)
        return out.view(sl, bs, self.top_k, self.ffn_hidden_size)

    def reduce(self, x, indices, bin_ids, bins, tokens_per_expert, expert_weights):
        """Apply the "reduce" projection and merge the top-k expert outputs.

        Args:
            x: 4-D activations, one feature vector per (token, expert slot).
            indices, bin_ids, bins, tokens_per_expert: the same routing
                metadata that was used for the matching ``map`` call.
            expert_weights: flattened router weights handed to
                ``ops.scatter`` to weight each expert's contribution.

        Returns:
            Tensor viewed as ``(x.shape[0], x.shape[1], hidden_size)``.
        """
        # The 4-way unpack doubles as a rank check on ``x``.
        sl, bs, _, _ = x.shape

        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k,
            map=False)
        return out.view(sl, bs, self.hidden_size)

    def grouped_permute_and_compute(
            self,
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            expert_capactiy,  # unused
            top_k,
            map=True):
        """Gather tokens by expert, run one projection, scatter them back.

        Args:
            x: activations; flattened to 2-D over all but the last axis.
            tokens_per_expert: per-expert token counts for the grouped linear.
            indices, bin_ids, bins: routing metadata for gather / scatter.
            expert_weights: weights passed to ``ops.scatter`` (``None`` for
                the map stage).
            expert_capactiy: unused; kept (including its spelling) so the
                signature stays compatible with existing callers.
            top_k: number of experts per token.
            map: ``True`` runs ``self.linear.map``; ``False`` runs
                ``self.linear.reduce``.

        Returns:
            The output of ``ops.scatter``.
        """
        # Route the tokens for MoE computation.
        # In the map stage every token is replicated ``top_k`` times by the
        # gather; in the reduce stage the input already holds one row per
        # (token, expert) pair, so the gather uses a factor of 1.
        x = x.view(-1, x.shape[-1])
        x = ops.gather(
            x,
            indices,
            bin_ids,
            bins,
            top_k if map else 1)

        # Perform the expert computation.
        if map:
            x = self.linear.map(x, tokens_per_expert)
        else:
            x = self.linear.reduce(x, tokens_per_expert)

        # Un-route the data for the MoE output.
        # Mirror image of the gather: only the reduce stage collapses the
        # ``top_k`` rows of a token back into one.
        return ops.scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            top_k if not map else 1,
            self.args.quantize_scatter_num_bits)


class dMoA(torch.nn.Module):
    """Dropless mixture-of-attention layer built on MegaBlocks routing.

    Queries are produced by the ``top_k`` experts the router selects for each
    token (``ParallelDroplessMapReduce.map``); keys and values come from two
    dense projections shared by all experts.  After causal flash attention
    the per-expert outputs are projected back and merged with the router
    weights (``ParallelDroplessMapReduce.reduce``).
    """

    def __init__(self, args: Arguments):
        super().__init__()

        # Width of the query/key/value projection of a single expert.
        self.att_hidden = args.num_heads * args.head_size

        # MegaBlocks configuration derived from the MoA arguments.
        self.args = MBArguments(
            hidden_size=args.hidden_size,
            ffn_hidden_size=self.att_hidden,
            moe_num_experts=args.moa_num_experts,
            moe_capacity_factor=1,
            moe_top_k=args.moa_top_k,
            init_method=args.init_method,
            fp16=False,
            bf16=False,
            bias=False
        )
        # Token router.
        self.router = router.LearnedRouter(self.args)

        # Expert computation helper.
        self.experts = ParallelDroplessMapReduce(self.args)

        # Key / value projections shared by every expert.
        self.k_proj = nn.Linear(args.hidden_size, self.att_hidden)
        self.v_proj = nn.Linear(args.hidden_size, self.att_hidden)
        # regularization
        # NOTE(review): this dropout module is created but never applied in
        # forward() -- confirm whether that is intentional.
        self.attn_dropout = nn.Dropout(args.attn_dropout)

        self.num_heads = args.num_heads
        self.top_k = args.moa_top_k
        self.hidden_size = args.hidden_size
        self.head_size = args.head_size

        # Re-register the rotary frequencies as a buffer instead of the
        # attribute RotaryEmbedding created them as.
        self.rotary_embed = RotaryEmbedding(self.head_size // 2)
        rope_freqs = self.rotary_embed.freqs.data
        del self.rotary_embed.freqs
        self.rotary_embed.register_buffer("freqs", rope_freqs)

        # Routing produced by the most recent map() call, consumed by reduce().
        self._routing = None

    def get_aux_loss_and_clear(self):
        """Forward to ``self.router.get_aux_loss_and_clear()``."""
        return self.router.get_aux_loss_and_clear()

    def _route(self, x):
        """Run the router on ``x`` and derive the gather/scatter metadata.

        Returns:
            ``(indices, bin_ids, bins, tokens_per_expert, expert_weights)``
            where ``expert_weights`` is the flattened router weight tensor.
        """
        scores, expert_weights, top_experts = self.router(x)
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        # The routing metadata is index bookkeeping only; no gradient needed.
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = self.experts.indices_and_bins(top_experts)
        return indices, bin_ids, bins, tokens_per_expert, expert_weights

    def map(self, x):
        """Route ``x`` and return the per-expert "map" projection.

        The routing decision is remembered so that a following ``reduce``
        call can merge the expert outputs with matching weights.  (The
        previous implementation passed the router outputs straight to
        ``experts.map``, which does not match that method's signature and
        therefore always raised ``TypeError``.)

        Args:
            x: 3-D activations with ``hidden_size`` features on the last axis.

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[1], top_k, att_hidden)``.
        """
        # NOTE: If we're going to cast the activations to lower precision
        # do it before we permute the tokens to save bandwidth.
        x = common.cast_if_autocast_enabled(x)

        # Compute the expert scores and assignments.
        routing = self._route(x)
        self._routing = routing
        indices, bin_ids, bins, tokens_per_expert, _ = routing

        # Compute the experts.
        return self.experts.map(x, indices, bin_ids, bins, tokens_per_expert)

    def reduce(self, x):
        """Merge per-expert activations using the routing of the last ``map``.

        Args:
            x: 4-D activations, one feature vector per (token, expert slot).

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[1], hidden_size)``.

        Raises:
            RuntimeError: if no routing from a preceding ``map`` call is
                available.
        """
        if self._routing is None:
            raise RuntimeError(
                "dMoA.reduce() requires a preceding dMoA.map() call")
        # Consume the cached routing so stale tensors are not kept alive.
        routing, self._routing = self._routing, None
        indices, bin_ids, bins, tokens_per_expert, expert_weights = routing
        return self.experts.reduce(
            x, indices, bin_ids, bins, tokens_per_expert, expert_weights)

    def forward(self, x):
        """Apply mixture-of-attention to ``x``.

        Args:
            x: activations of shape ``(B, T, C)``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        x = common.cast_if_autocast_enabled(x)
        indices, bin_ids, bins, tokens_per_expert, expert_weights = self._route(x)

        # Queries come from the selected experts, keys/values are shared.
        q = self.experts.map(x, indices, bin_ids, bins, tokens_per_expert)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Keys are computed from the same ``x`` as the queries, so
        # context_length equals T here; the offset / window expressions below
        # only matter if keys ever cover more positions than the queries.
        context_length = k.size(1)

        q = q.view(B, T, self.top_k * self.num_heads, self.head_size)  # (B, T, k * nh, hs)
        k = k.view(B, context_length, self.num_heads, self.head_size)  # (B, T, nh, hs)
        v = v.view(B, context_length, self.num_heads, self.head_size)  # (B, T, nh, hs)

        # Tile the shared key/value heads once per expert slot so that query
        # head ``slot * nh + h`` attends against key/value head ``h``.
        k = k.repeat(1, 1, self.top_k, 1)  # (B, T, k * nh, hs)
        v = v.repeat(1, 1, self.top_k, 1)  # (B, T, k * nh, hs)

        # Rotary embedding is applied with the sequence axis second-to-last,
        # so move heads in front of the sequence axis and back afterwards.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        q = self.rotary_embed.rotate_queries_or_keys(q, seq_dim=-2, offset=context_length - T)
        k = self.rotary_embed.rotate_queries_or_keys(k, seq_dim=-2)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        y = flash_attn_func(q, k, v, causal=True, window_size=(context_length - T if context_length > T else -1, -1))

        # output projection
        y = self.experts.reduce(y.reshape(B, T, self.top_k, self.att_hidden).type_as(x),
                                indices, bin_ids, bins, tokens_per_expert, expert_weights)

        y = y.view(B, T, C)  # re-assemble all head outputs side by side
        return y

class BasicMoA(torch.nn.Module):
    """Dense causal multi-head attention used as the non-MoE baseline.

    Uses the same rotary-embedding and flash-attention pipeline as the
    mixture-of-attention layers, but with ordinary dense query / key / value
    and output projections.
    """

    def __init__(self, hidden_size: int = 1024, num_heads: int = 4, head_size: int = 128, attn_dropout: float = 0.):
        super().__init__()
        self.att_hidden_size = num_heads * head_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = head_size

        # Dense projections (created in q, k, v, out order).
        self.q_proj = nn.Linear(hidden_size, self.att_hidden_size)
        self.k_proj = nn.Linear(hidden_size, self.att_hidden_size)
        self.v_proj = nn.Linear(hidden_size, self.att_hidden_size)
        self.out_proj = nn.Linear(self.att_hidden_size, hidden_size)

        # Regularization module; forward() does not apply it.
        self.attn_dropout = nn.Dropout(attn_dropout)

        # Swap the rotary frequencies over to a registered buffer.
        rotary = RotaryEmbedding(self.head_size // 2)
        self.rotary_embed = rotary
        freqs = rotary.freqs.data
        del rotary.freqs
        rotary.register_buffer("freqs", freqs)

    def forward(self, x):
        """Run causal self-attention over ``x`` of shape ``(B, T, C)``.

        Returns a tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()  # batch, sequence length, embedding width

        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        context_length = keys.size(1)
        head_shape = (self.num_heads, self.head_size)
        queries = queries.view(B, T, *head_shape)             # (B, T, nh, hs)
        keys = keys.view(B, context_length, *head_shape)      # (B, T, nh, hs)
        values = values.view(B, context_length, *head_shape)  # (B, T, nh, hs)

        # Rotary embedding is applied with the sequence axis second-to-last:
        # swap the sequence and head axes, rotate, and swap back.
        queries = self.rotary_embed.rotate_queries_or_keys(
            queries.transpose(1, 2), seq_dim=-2, offset=context_length - T)
        keys = self.rotary_embed.rotate_queries_or_keys(
            keys.transpose(1, 2), seq_dim=-2)
        queries = queries.transpose(1, 2).contiguous()
        keys = keys.transpose(1, 2)

        if context_length > T:
            left_window = context_length - T
        else:
            left_window = -1
        attended = flash_attn_func(queries, keys, values, causal=True,
                                   window_size=(left_window, -1))

        # Output projection over the concatenated heads, then restore (B, T, C).
        projected = self.out_proj(attended.view(-1, self.att_hidden_size))
        return projected.view(B, T, C)


def init_megablocks(xdim, heads_per_expert, head_size, E, k, dtype):
    """Build the MegaBlocks-backed ``dMoA`` layer on the GPU in ``dtype``.

    ``xdim`` is the model width, ``E`` the number of experts and ``k`` the
    number of experts selected per token.
    """
    config = Arguments(
        hidden_size=xdim,
        num_heads=heads_per_expert,
        head_size=head_size,
        moa_num_experts=E,
        moa_top_k=k,
    )
    layer = dMoA(config)
    layer = layer.cuda()
    return layer.to(dtype)

def init_simple(xdim, heads_per_expert, head_size, E, k, dtype):
    """Build the reference ``simplemoa.MoA`` layer on the GPU in ``dtype``.

    The values go through ``Arguments`` first so that the unspecified
    ``attn_dropout`` takes the same default as the MegaBlocks variant.
    """
    config = Arguments(
        hidden_size=xdim,
        num_heads=heads_per_expert,
        head_size=head_size,
        moa_num_experts=E,
        moa_top_k=k,
    )
    moa_kwargs = {
        "hidden_size": config.hidden_size,
        "num_heads": config.num_heads,
        "head_size": config.head_size,
        "attn_dropout": config.attn_dropout,
        "num_experts": config.moa_num_experts,
        "top_k": config.moa_top_k,
    }
    layer = MoA(**moa_kwargs)
    return layer.cuda().to(dtype)

def init_small_base(xdim, heads_per_expert, head_size, E, k, dtype):
    """Build a dense ``BasicMoA`` baseline with ``k * heads_per_expert`` heads.

    ``E`` is accepted only so the signature matches the other ``init_*``
    factories; it is not used.
    """
    active_heads = k * heads_per_expert
    layer = BasicMoA(
        hidden_size=xdim,
        num_heads=active_heads,
        head_size=head_size,
        attn_dropout=0.,
    )
    layer = layer.cuda()
    return layer.to(dtype)

def init_large_base(xdim, heads_per_expert, head_size, E, k, dtype):
    """Build a dense ``BasicMoA`` baseline with ``E * heads_per_expert`` heads.

    ``k`` is accepted only so the signature matches the other ``init_*``
    factories; it is not used.
    """
    total_heads = E * heads_per_expert
    layer = BasicMoA(
        hidden_size=xdim,
        num_heads=total_heads,
        head_size=head_size,
        attn_dropout=0.,
    )
    layer = layer.cuda()
    return layer.to(dtype)

def test_params(init_strat, B, L, xdim, heads_per_expert, head_size, E, k, dtype):
    """Benchmark one layer configuration with ``test_strategy``.

    Creates a random ``(B, L, xdim)`` input and a matching output gradient on
    the GPU, builds the layer through ``init_strat`` and hands forward /
    backward callables to ``test_strategy``, whose result is returned.
    """
    # Draw the input first, then the upstream gradient, then build the model.
    inputs = torch.randn(B, L, xdim, dtype=dtype).cuda()
    grad_output = torch.randn_like(inputs)
    inputs.requires_grad_(True)
    model = init_strat(xdim, heads_per_expert, head_size, E, k, dtype)

    def run_forward():
        return model(inputs)

    def run_backward(output):
        return output.backward(grad_output)

    return test_strategy(fwd=run_forward, bwd=run_backward)


if __name__ == '__main__':
    # Benchmark configuration.
    dtype = torch.bfloat16
    B = 16
    L = 2048
    sparse_factor = 8  # experts per selected expert: E = k * sparse_factor
    xdim = 4096
    nheads = 32
    head_dim = xdim // nheads # 128

    # Dense baselines, then a sweep over top-k with the number of active
    # heads (k * heads_per_expert) held at ``nheads``.
    results = {}
    dense_small = test_params(init_small_base, B, L, xdim, nheads, head_dim, 8, 1, dtype)
    dense_large = test_params(init_large_base, B, L, xdim, nheads, head_dim, 8, 1, dtype)
    for k in [1, 2, 4, 8, 16]:
        E = k * sparse_factor
        heads_per_expert = nheads // k
        print(heads_per_expert, head_dim, E, k, k* heads_per_expert, E * heads_per_expert)
        results[k] = {
            'simple': test_params(init_simple, B, L, xdim, heads_per_expert, head_dim, E, k, dtype),
            'megablocks': test_params(init_megablocks, B, L, xdim, heads_per_expert, head_dim, E, k, dtype),
        }
    # Close the output file deterministically instead of leaking the handle.
    with open('moa_k_benchmarks.pkl', 'wb') as results_file:
        pickle.dump((dense_small, dense_large, results), results_file)

    # Memory snapshots for both MoA implementations.  ``heads_per_expert``,
    # ``E`` and ``k`` deliberately keep the values of the last sweep
    # iteration above.
    for label, init_strat in [
            ("simple", init_simple),
            ("megablocks", init_megablocks)
        ]:
        X = torch.randn(B, L, xdim, dtype=dtype).cuda()
        DY = torch.randn_like(X)
        X.requires_grad_(True)
        # ``head_dim`` is the head width defined above; the previous
        # ``head_size`` was never defined in this scope and raised NameError.
        strat = init_strat(xdim, heads_per_expert, head_dim, E, k, dtype)
        torch.cuda.memory._record_memory_history()
        Y = strat(X)
        Y.backward(DY)
        torch.cuda.memory._dump_snapshot("moa_%s_memory.pkl" % label)


