import itertools
import math
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange


# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
def _fwd_kernel(
    Q,  # query pointer; this program reads n rows of length d starting at qk_offset
    K,  # key pointer; same layout as Q
    V,  # value pointer; n rows of length e starting at v_offset
    Out,  # output pointer; same layout as V
    S,  # log lambda
    b: tl.constexpr,  # batch size (not read in the body)
    h: tl.constexpr,  # number of heads; used to recover the head index from program_id(0)
    n: tl.constexpr,  # sequence length; rows >= n are masked out
    d: tl.constexpr,  # q/k feature size (row stride of Q and K)
    e: tl.constexpr,  # v/out feature size (row stride of V and Out)
    BLOCK: tl.constexpr,  # tokens processed per loop iteration
    NUM_BLOCK: tl.constexpr,  # number of BLOCK-sized chunks covering the sequence
    BLOCK_MODEL: tl.constexpr,  # value/output channels handled by one program
):
    # Causal lightning-attention forward pass with a per-head exponential decay.
    #
    # Program (i, j) handles the flattened (batch, head) index i = program_id(0)
    # and output channels [j * BLOCK_MODEL, (j + 1) * BLOCK_MODEL).
    # It walks the sequence in BLOCK-token chunks and combines two terms:
    #   * intra: decay-masked (q @ k^T) @ v within the current chunk;
    #   * inter: q @ kv, where kv (d x BLOCK_MODEL, float32) is a running,
    #     decayed sum of k^T v over all previous chunks.
    # From the decay factors below, the weight of key position p for query
    # position t (p <= t) is exp(-s * (t - p)), where s is this head's value in S.
    # NOTE(review): the offsets assume Q/K/V/Out are contiguous with the
    # (batch, head) pairs laid out consecutively; the Python wrapper calls
    # .contiguous() before launching.

    ##### get offset
    off_bh = tl.program_id(0)
    off_h = off_bh % h  # head index within the flattened (batch, head) id
    off_e = tl.program_id(1)
    qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    # channel offset
    e_offset = off_e * BLOCK_MODEL

    ##### get block ptr
    # Q is read as [token, feature]; K is read transposed as [feature, token].
    Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]
    K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]
    V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
    O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]
    S_block_ptr = S + off_h  # one decay scalar per head

    ##### init diag decay(Lambda); q, k decay; kv
    s = tl.load(S_block_ptr)
    # q, k decay
    # These factors are computed once from the local offsets 0..BLOCK-1 and reused for every chunk.
    # For a query at local row i and a key at local row j of the previous chunk,
    # q_decay * k_trans_decay = exp(-s * (BLOCK + i - j)), i.e. their true distance.
    off_block = tl.arange(
        0, BLOCK
    )  # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent
    q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None])
    k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :]))
    # Decay applied to the carried kv state when moving forward by one full chunk.
    block_decay = tl.exp(-s.to(tl.float32) * BLOCK)
    # diag decay
    # exp(-s * (i - j)) on and below the diagonal; exp(-inf) = 0 above it,
    # which enforces causality inside a chunk.
    index = off_block[:, None] - off_block[None, :]
    s_index = s * index
    s_index = tl.where(index >= 0, -s_index, float("-inf"))
    diag_decay = tl.exp(s_index)
    kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)

    ##### compute
    for i in range(NUM_BLOCK):
        # load
        # off_block now holds the absolute token positions of this chunk;
        # positions >= n read as 0 so the final partial chunk is safe.
        q = tl.load(
            Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0
        ).to(tl.float32)
        k_trans = tl.load(
            K_trans_block_ptr + off_block[None, :] * d,
            mask=off_block[None, :] < n,
            other=0.0,
        ).to(tl.float32)
        v = tl.load(
            V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0
        ).to(tl.float32)

        # compute
        qk = tl.dot(q, k_trans) * diag_decay
        o_intra = tl.dot(qk, v)  # contribution of keys in the current chunk
        o_inter = tl.dot(q, kv) * q_decay  # contribution of all earlier chunks
        o = o_intra + o_inter

        # save and update
        tl.store(
            O_block_ptr + off_block[:, None] * e,
            o.to(O_block_ptr.dtype.element_ty),
            mask=off_block[:, None] < n,
        )
        # Age the carried state by one chunk, then fold in this chunk's keys/values.
        kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v)
        off_block += BLOCK


def lightning_attn2(q, k, v, s):
    """Launch the Triton lightning-attention forward kernel.

    Zero-pads the q/k feature size and the v feature size up to powers of two
    (padded channels contribute nothing to q @ k^T, and padded output channels
    are dropped). Then runs ``_fwd_kernel`` over a grid of (batch * heads,
    channel tiles) and returns the un-padded output.

    Args:
        q: queries, unpacked as (b, h, n, d).
        k: keys; assumed to share q's shape — TODO confirm at call sites.
        v: values whose last dimension is e (assumed (b, h, n, e)).
        s: decay tensor; the kernel reads one scalar per head from it.

    Returns:
        Tensor of shape (b, h, n, e) with q's dtype and device.
    """
    q, k, v, s = (t.contiguous() for t in (q, k, v, s))

    batch, heads, seq_len, qk_dim = q.shape
    v_dim = v.shape[-1]

    def _pad_last(t, target, current):
        # Zero-pad the trailing dimension up to `target`; identity when already there.
        return t if target == current else F.pad(t, (0, target - current))

    qk_dim_p2 = next_power_of_2(qk_dim)
    v_dim_p2 = next_power_of_2(v_dim)
    q_in = _pad_last(q, qk_dim_p2, qk_dim)
    k_in = _pad_last(k, qk_dim_p2, qk_dim)
    v_in = _pad_last(v, v_dim_p2, v_dim)

    out = torch.empty(
        (batch, heads, seq_len, v_dim_p2), dtype=q.dtype, device=q.device
    )

    token_block = 64
    # Each program handles at most 32 output channels; the grid's second axis
    # covers the remaining channel tiles.
    channel_block = min(triton.next_power_of_2(v_dim_p2), 32)
    launch_grid = (batch * heads, triton.cdiv(v_dim_p2, channel_block))

    _fwd_kernel[launch_grid](
        q_in,
        k_in,
        v_in,
        out,
        s,
        batch,
        heads,
        seq_len,
        qk_dim_p2,
        v_dim_p2,
        BLOCK=token_block,
        NUM_BLOCK=triton.cdiv(seq_len, token_block),
        BLOCK_MODEL=channel_block,
    )

    # Drop the padded value channels, if any.
    return out if v_dim_p2 == v_dim else out[..., :v_dim]


def is_support(dim):
    """Return whether a q/k or v feature size can be handled by the Triton kernel.

    ``lightning_attn2`` zero-pads every feature size up to the next power of two
    and the kernel multiplies the padded tiles with ``tl.dot``. On the usual
    backends, ``tl.dot`` needs each operand dimension to be at least 16. A size
    is therefore usable iff its padded size is >= 16, i.e. iff ``dim > 8``.

    The previous check, ``16 % dim``, was truthy only for sizes that do not
    divide 16. It wrongly rejected 16 itself, and it accepted sizes such as 3
    or 5, which pad to 4 or 8 and are too small for ``tl.dot``.

    Args:
        dim: feature size (last dimension of q/k or v).

    Returns:
        True if the size is supported, else False.
    """
    return dim > 8


def next_power_of_2(n):
    """Return the smallest power of two that is >= ``n``.

    Uses exact integer bit arithmetic. The previous ``2 ** ceil(log(n, 2))``
    form depended on ``math.log(n, 2)``, which is evaluated in floating point
    as ``log(n) / log(2)`` and is not guaranteed to be exact for powers of two.
    A result even marginally above the integer makes ``ceil`` overshoot and
    doubles the answer.

    Args:
        n: a positive size. Non-integers are rounded up first.

    Returns:
        The power of two as an int.

    Raises:
        ValueError: if ``n`` is not positive (the log-based form also raised
            ValueError here).
    """
    if n <= 0:
        raise ValueError(f"next_power_of_2 expects a positive value, got {n!r}")
    # For m >= 1, 1 << (m - 1).bit_length() is the least power of two >= m.
    return 1 << (math.ceil(n) - 1).bit_length()


def lightning_attn_func(q, k, v, s):
    """Lightning attention with per-head decay, computed by the Triton kernel.

    Args:
        q: queries, unpacked as (b, h, n, d).
        k: keys, sliced along the same last dimension as q.
        v: values whose last dimension is e.
        s: per-head decay tensor, forwarded unchanged to ``lightning_attn2``.

    Returns:
        Tensor of shape (b, h, n, e).

    Raises:
        AssertionError: if d or e is rejected by ``is_support``.
    """
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert is_support(d) and is_support(e), f"unsupported feature sizes d={d}, e={e}"

    # pad v's feature dim to power of 2; the padded channels are sliced off below.
    e_pad = next_power_of_2(e)
    need_pad = e_pad != e
    if need_pad:
        v = F.pad(v, (0, e_pad - e))

    if d > 128:
        # Split over the q/k feature dim into 64-wide slices plus a shorter tail.
        # q @ k^T is a sum over feature slices and the kernel is linear in that
        # product, so the per-slice outputs add up to the full result.
        # (The former `32 % d` / `16 % d` fallbacks were unreachable: for d > 128,
        # `64 % d` is always 64.)
        # NOTE(review): a tail of 1..8 features pads below 16 and may not be
        # accepted by tl.dot; the is_support(d) check does not cover this case.
        split = 64
        o = 0
        for start in range(0, d, split):
            end = min(start + split, d)
            o += lightning_attn2(q[..., start:end], k[..., start:end], v, s)
    else:
        o = lightning_attn2(q, k, v, s)

    if need_pad:
        o = o[:, :, :, :e]

    return o


# Verbose-logging switch read from the ``debug`` environment variable.
# The value is parsed as a boolean flag rather than passed to ``eval``:
# evaluating an environment variable executes arbitrary Python code.
# "True"/"False" (the previously documented values) keep their meaning.
debug = os.environ.get("debug", "False").strip().lower() in ("1", "true", "yes", "on")

# Chunk length (in tokens) of the blocked PyTorch prefill in
# MiniMaxText01LightningAttention.inference.
BLOCK = 256


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01
class MiniMaxText01RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension (T5LayerNorm style).

    No mean subtraction and no bias. The input is scaled by the reciprocal
    root-mean-square of its last dimension, computed in float32, and then by a
    learnable per-feature ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: size of the normalised (last) dimension.
            eps: added to the mean square before the reciprocal square root.
        """
        super().__init__()
        # Per-feature gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last dimension and apply the gain."""
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back first; the product with `weight` then follows PyTorch's
        # type-promotion rules (e.g. float32 weight x bfloat16 input -> float32).
        return self.weight * normalised.to(orig_dtype)


# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
def get_activation_fn(activation):
    """Return the activation callable registered under ``activation``.

    Supported names: "gelu", "relu", "elu", "sigmoid", "exp" (max-shifted
    exponential), "leak" (leaky ReLU), "1+elu", "2+elu", "silu"/"swish" and
    "sine". Any other name logs a message and returns the identity function.

    The original body called ``logger.info`` although no ``logger`` is defined
    in this module, so debug logging and the unknown-name fallback raised
    NameError. A module logger is now obtained locally.

    Args:
        activation: activation name (typically ``config.hidden_act``).

    Returns:
        A callable mapping a tensor to a tensor.
    """
    import logging

    logger = logging.getLogger(__name__)

    if debug:
        logger.info(f"activation: {activation}")

    def _shifted_exp(x):
        # Subtract the row max (treated as a constant for autograd) so exp()
        # cannot overflow.
        with torch.no_grad():
            x_max = torch.max(x, dim=-1, keepdim=True).values
        return torch.exp(x - x_max)

    def _one_plus_elu(x):
        return 1 + F.elu(x)

    def _two_plus_elu(x):
        return 2 + F.elu(x)

    registry = {
        "gelu": F.gelu,
        "relu": F.relu,
        "elu": F.elu,
        "sigmoid": F.sigmoid,
        "exp": _shifted_exp,
        "leak": F.leaky_relu,
        "1+elu": _one_plus_elu,
        "2+elu": _two_plus_elu,
        "silu": F.silu,
        "swish": F.silu,
        "sine": torch.sin,
    }
    fn = registry.get(activation)
    if fn is None:
        logger.info(f"activation: does not support {activation}, use Identity!!!")
        return lambda x: x
    return fn


# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py
class MiniMaxText01LightningAttention(nn.Module):
    """Lightning-attention layer: causal linear attention with per-head decay.

    Pure-PyTorch reference implementation. ``inference`` performs these steps:
      1. Projects the input to q/k/v with one fused linear layer and activation.
      2. Runs either a blocked prefill (no cache) or a token-by-token recurrent
         decode (with cache).
      3. Applies RMSNorm, a sigmoid output gate and an output projection.

    Can be built from a ``config`` object or, when ``config`` is None, from
    keyword arguments with the same field names.
    """

    def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs):
        """Create the projections, activation, norm and output gate.

        Args:
            config: object exposing ``hidden_size``, ``num_attention_heads``,
                ``hidden_act`` and optionally ``head_dim`` (defaults to
                ``hidden_size // num_attention_heads``). When None, these
                fields are taken from ``kwargs``.
            layer_idx: stored on the module; not otherwise read here.
        """
        super().__init__()
        if config is None:
            # Expose the keyword arguments as attributes of an ad-hoc class.
            config = type("Config", (), kwargs)

        bias = False
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        self.out_proj = nn.Linear(
            self.head_dim * self.num_heads, self.hidden_size, bias=bias
        )
        self.act = get_activation_fn(config.hidden_act)
        self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads)

        self.qkv_proj = nn.Linear(
            self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias
        )
        self.output_gate = nn.Linear(
            self.hidden_size, self.head_dim * self.num_heads, bias=bias
        )

        # for inference only
        self.offset = 0
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Dispatch to ``inference`` when the module is in eval mode.

        ``do_eval`` used to be referenced as a global that is never defined, so
        every eval-mode call raised NameError. It is now read from the
        ``do_eval`` environment variable, the same convention this file uses
        for ``debug``.

        Raises:
            NotImplementedError: in training mode or when ``do_eval`` is set.
                Only the inference path is implemented here; these cases
                previously returned None silently.
        """
        do_eval = os.environ.get("do_eval", "False").strip().lower() in (
            "1",
            "true",
            "yes",
            "on",
        )
        if (not self.training) and (not do_eval):
            return self.inference(
                hidden_states,
                attn_mask,
                output_attentions,
                past_key_value,
                use_cache,
                slope_rate,
            )
        raise NotImplementedError(
            "MiniMaxText01LightningAttention only implements the inference path "
            "(call .eval() and leave the do_eval environment variable unset)."
        )

    def inference(
        self,
        x,
        attn_mask: Optional[torch.Tensor] = None,  # (b, n)
        output_attentions: bool = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        slope_rate: Optional[torch.Tensor] = None,  # (h, 1, 1)
    ):
        """Run the layer in inference mode.

        Args:
            x: hidden states, unpacked as (b, n, d).
            attn_mask: optional (b, n) mask. Positions where the mask is 0 have
                their values zeroed (prefill only).
            output_attentions: unused; attention weights are always None.
            past_key_value: decode-time kv state returned by a previous call.
                When None, the blocked prefill path is used.
            use_cache: unused.
            slope_rate: per-head decay rates, (h, 1, 1) per the call sites.

        Returns:
            Tuple ``(output, None, kv)``:
              * ``output`` is the out-projected result for every input position.
              * ``kv`` is the (b, h, head_dim, head_dim) recurrent state after
                the last position.
        """
        # x: b n d
        b, n, d = x.shape
        # linear map
        qkv = self.act(self.qkv_proj(x))
        new_shape = qkv.size()[:-1] + (self.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if past_key_value is None:
            self.offset = q.shape[-2]
        else:
            self.offset += 1

        # for align with metaseq
        ratio = torch.exp(-slope_rate)  # per-token decay used by the decode path

        # only use for the first time
        if past_key_value is None:
            slope_rate = slope_rate.to(torch.float32)
            if attn_mask is not None:
                v = v.masked_fill(
                    (1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
                )
            NUM_BLOCK = (n + BLOCK - 1) // BLOCK
            b, h, n, d = q.shape
            e = v.shape[-1]
            # other
            # Positions 1..BLOCK. q_decay * k_decay for query i / key j of the
            # previous chunk gives exp(-slope * (BLOCK + i - j)).
            array = torch.arange(BLOCK).to(q) + 1
            q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
            k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
            index = array[:, None] - array[None, :]
            s_index = (
                slope_rate
                * index[
                    None,
                    None,
                ]
            )
            # Causal intra-chunk decay: exp(-slope * (i - j)) for j <= i, else 0.
            s_index = torch.where(index >= 0, -s_index, float("-inf"))
            diag_decay = torch.exp(s_index)

            # Running state, allocated directly in float32 on q's device.
            kv = torch.zeros(b, h, d, e, dtype=torch.float32, device=q.device)
            output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
            for i in range(NUM_BLOCK):
                si = i * BLOCK
                ei = min(si + BLOCK, n)
                m = ei - si
                qi = q[:, :, si:ei].contiguous()
                ki = k[:, :, si:ei].contiguous()
                vi = v[:, :, si:ei].contiguous()
                # Contribution of all previous chunks via the carried state.
                qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32)

                # diag
                qk = (
                    torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32)
                    * diag_decay[:, :, :m, :m]
                )
                qkv_diag = torch.matmul(qk, vi.to(torch.float32))
                block_decay = torch.exp(-slope_rate * m)
                output[:, :, si:ei] = qkv_none_diag + qkv_diag
                # Age the state by this chunk's length and add its keys/values.
                kv = block_decay * kv + torch.matmul(
                    (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi
                )

        else:
            kv = past_key_value
            output = []
            for i in range(n):
                # Recurrent update: decay the state by one step, add k_i^T v_i.
                kv = ratio * kv + torch.einsum(
                    "... n d, ... n e -> ... d e",
                    k[:, :, i : i + 1],
                    v[:, :, i : i + 1],
                )
                qkv = torch.einsum(
                    "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
                )
                output.append(qkv)
            output = torch.concat(output, dim=-2)
        # reshape
        output = rearrange(output, "b h n d -> b n (h d)")
        # normalize
        output = self.norm(output)
        # gate
        output = F.sigmoid(self.output_gate(x)) * output
        # outproj
        output = self.out_proj(output)

        attn_weights = None

        return output, attn_weights, kv


def _build_slope_tensor(n_attention_heads: int):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(
                n
            )  # In the paper, we only train models that have 2^a heads for some a. This function has
        else:  # some good properties that only occur when the input is a power of 2. To maintain that even
            closest_power_of_2 = 2 ** math.floor(
                math.log2(n)
            )  # when the number of heads is not a power of 2, we use this workaround.
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    # h, 1, 1
    slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
        n_attention_heads, 1, 1
    )

    return slopes


def test_lightning_attention_implementations(model_params):
    """Check that the Triton-kernel path reproduces the PyTorch reference.

    Builds a seeded ``MiniMaxText01LightningAttention`` and runs its
    ``inference`` prefill. It then recomputes the same layer with
    ``lightning_attn_func`` in place of the blocked PyTorch attention and
    compares the two outputs.

    Both paths run under ``torch.no_grad()``. Previously only the reference
    path did, so the library path recorded an autograd graph through the
    projections that was never used.

    Args:
        model_params: keyword arguments for ``MiniMaxText01LightningAttention``
            (must include ``hidden_size`` and ``num_attention_heads``).

    Raises:
        AssertionError: if the outputs differ beyond rtol=1e-3 / atol=1e-2.
    """
    # NOTE(review): `device` falls back to CPU, but lightning_attn_func launches
    # a Triton kernel and presumably needs a CUDA device — confirm.
    torch.manual_seed(42)

    batch_size = 2
    seq_len = 1024
    dtype = torch.bfloat16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hidden_states = torch.randn(
        batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device
    )

    attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)

    slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device)

    model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
    model_attn.eval()

    with torch.no_grad():
        # Reference: pure-PyTorch blocked prefill.
        model_output, _, _ = model_attn.inference(
            hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
        )

        # Library: same projections/norm/gate around the Triton attention.
        qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
        new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
        qkv = qkv.view(*new_shape)
        q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        lib_output = lightning_attn_func(q, k, v, slope_rate)
        lib_output = lib_output.transpose(1, 2).contiguous()
        lib_output = lib_output.view(batch_size, seq_len, -1)
        lib_output = model_attn.norm(lib_output)
        lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output
        lib_output = model_attn.out_proj(lib_output)

    torch.testing.assert_close(
        model_output,
        lib_output,
        rtol=1e-3,
        atol=1e-2,
        msg="Lightning attention implementations produce different results",
    )

    print("✅ Two implementations match")


def get_benchmark():
    """Build the Triton ``perf_report`` benchmark comparing both prefill paths.

    Sweeps batch size over 1..64 and sequence length over 256..4096. For each
    point it times two implementations:
      * the pure-PyTorch ``MiniMaxText01LightningAttention.inference``;
      * the same projections/norm/gate around ``lightning_attn_func``.

    Returns:
        The decorated benchmark; call ``.run(...)`` on it to execute.
    """
    batch_size_range = [2**i for i in range(0, 7)]  # max 64
    seq_length_range = [256, 512, 1024, 2048, 4096]  # max 4096
    configs = list(itertools.product(batch_size_range, seq_length_range))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["MiniMax-Text-01", "OpenNLPLab"],
            line_names=[
                "MiniMax-Text-01 Model Implementation",
                "OpenNLPLab Library Implementation",
            ],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="lightning-attention-prefill-performance",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, provider):
        """Return (median, p20, p80) latency in microseconds for one configuration."""
        dtype = torch.bfloat16
        device = torch.device("cuda")

        params = {
            "hidden_size": 6144,
            "num_attention_heads": 64,
            "head_dim": 96,
            "hidden_act": "gelu",
        }

        hidden_states = torch.randn(
            batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device
        )

        attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device)

        slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device)
        model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
        model_attn.eval()

        quantiles = [0.5, 0.2, 0.8]
        # Time inference only. With autograd enabled, each call would also
        # record a backward graph through the nn.Linear projections.
        with torch.no_grad():
            if provider == "MiniMax-Text-01":
                ms, min_ms, max_ms = triton.testing.do_bench(
                    lambda: model_attn.inference(
                        hidden_states, attn_mask=attention_mask, slope_rate=slope_rate
                    ),
                    quantiles=quantiles,
                )
            else:

                def run_lib():
                    qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
                    new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
                    qkv = qkv.view(*new_shape)
                    q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
                    q = q.transpose(1, 2)
                    k = k.transpose(1, 2)
                    v = v.transpose(1, 2)

                    lib_output = lightning_attn_func(q, k, v, slope_rate)
                    lib_output = lib_output.transpose(1, 2).contiguous()
                    lib_output = lib_output.view(batch_size, seq_len, -1)
                    lib_output = model_attn.norm(lib_output)
                    lib_output = (
                        torch.sigmoid(model_attn.output_gate(hidden_states))
                        * lib_output
                    )
                    return model_attn.out_proj(lib_output)

                ms, min_ms, max_ms = triton.testing.do_bench(
                    run_lib,
                    quantiles=quantiles,
                )

        # do_bench returns values in `quantiles` order (p50, p20, p80).
        # perf_report unpacks (y, y_min, y_max), so the lower latency quantile
        # must come second. The previous return order (ms, max, min) swapped
        # the min/max columns.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark


if __name__ == "__main__":
    # Script entry point: verify correctness once, then run the latency sweep.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/lightning_attention_prefill/",
        help="Path to save lightning attention prefill benchmark results",
    )
    cli_args = cli.parse_args()

    # Step 1 — correctness check. Model shape adapted from
    # https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
    minimax_params = dict(
        hidden_size=6144,
        num_attention_heads=64,
        head_dim=96,
        hidden_act="silu",
    )
    test_lightning_attention_implementations(minimax_params)

    # Step 2 — performance sweep; results are printed and saved to save_path.
    get_benchmark().run(print_data=True, save_path=cli_args.save_path)
