# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Union

import torch
from megatron.core import parallel_state
from torch import nn
from torch.distributed._functional_collectives import all_reduce

from cosmos_predict1.autoregressive.modules.embedding import RotaryPositionEmbedding
from cosmos_predict1.autoregressive.modules.normalization import create_norm


class Attention(nn.Module):
    """
    Attenion layer with KV cache.
    """

    def __init__(
        self,
        n_heads: int,
        n_kv_heads: Union[int, None],
        dim: int,
        max_batch_size: int,
        max_seq_len: int,
        context_dim: Optional[int] = None,
        use_qk_normalization: bool = False,
        norm_type: str = "rmsnorm",
        norm_eps: float = 1e-5,
        causal_mask: Optional[bool] = True,
        head_dim: Optional[int] = None,
        fuse_qkv: bool = False,
        precision: str = "bfloat16",
        tensor_parallel_size: int = 1,
        attn_type: str = "self",
    ):
        """
        Initializes the GQA module.

        Args:
            n_heads (int): The number of attention heads.
            n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads.
            dim (int): The dimensionality of the input and output.
            max_batch_size (int): The maximum batch size.
            max_seq_len (int): The maximum sequence length.
            context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None.
            use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False.
            norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm".
            norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
            tp_group (int, optional): The tensor parallel group.
            causal_mask (bool, optional): Whether to use causal mask. Defaults to True.
            head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
            fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False.
            precision (str, optional): The precision of the module. Defaults to "bfloat16".
            tensor_parallel_size (int, optional): The tensor parallel size. Defaults to 1.
            attn_type (str, optional): The type of attention. Defaults to "self".
        """
        super().__init__()
        assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}"
        self.attn_type = attn_type
        self.tp_size = tensor_parallel_size
        context_dim = dim if context_dim is None else context_dim

        self.dim = dim
        self.context_dim = context_dim
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_kv_heads = self.n_kv_heads // self.tp_size
        self.n_local_heads = n_heads // self.tp_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads if head_dim is None else head_dim
        self.causal_mask = causal_mask
        self.fuse_qkv = fuse_qkv
        self.precision = precision

        if fuse_qkv:
            assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})"
            self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim
            self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False)
            # Register hook to load fused QKV weights
            self._register_load_state_dict_pre_hook(self.load_hook)
        else:
            self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False)
            self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
            self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False)

        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        if self.attn_type == "self":
            # Cache for key and value tensors
            self.init_kv_cache()

        # QK normalization layers
        if use_qk_normalization:
            assert n_heads % self.tp_size == 0, "n_heads must be divisible by tensor_model_parallel_size"
            assert self.n_kv_heads % self.tp_size == 0, "n_kv_heads must be divisible by tensor_model_parallel_size"
            self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
            self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)

        self.use_qk_normalization = use_qk_normalization

        self.to(dtype=getattr(torch, self.precision))

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def init_kv_cache(self, dtype=None):
        cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim)
        if dtype is None:
            dtype = getattr(torch, self.precision)
        if self.attn_type == "self":
            self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
            self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()

    def forward(
        self,
        x: torch.Tensor,
        rope: RotaryPositionEmbedding,
        input_pos: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ):
        """
        Forward pass of GQA.

        Args:
            x: The input tensor of shape (batch_size, seq_len, dim).
            rope: The rotary positional embedding module.
            input_pos: The starting position of the current sequence.
            mask: The attention mask tensor.
            context: The context tensor of shape (batch_size, context_len, dim).

        Returns:
            The output tensor after applying GQA.
        """
        bsz, seqlen, _ = x.shape

        # Use one single module to handle both self-attn and cross-attn
        context = x if context is None else context
        context_len = seqlen if context is None else context.shape[1]

        if self.fuse_qkv:
            q_size = self.n_local_heads * self.head_dim
            kv_size = self.n_local_kv_heads * self.head_dim
            xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
        else:
            # Compute query, key, and value projections
            xq, xk, xv = self.wq(x), self.wk(context), self.wv(context)

        # Reshape projections
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)

        # QK normalization
        if self.use_qk_normalization:
            xq = self.q_norm(xq)
            xk = self.k_norm(xk)

        # Apply rotary positional embeddings to queries and keys
        # Only apply RoPE to self-attention!
        if self.attn_type in ["self", "full"]:
            xq, xk = rope(xq, xk, input_pos, seqlen)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
        # xq: (bs, n_local_heads, seqlen, head_dim)
        # xk: (bs, n_kv_heads, cache_len + context_len, head_dim)
        # xv: (bs, n_kv_heads, cache_len + context_len, head_dim)
        if self.attn_type == "self":
            # Update cache with current key and value tensors
            assert input_pos is not None
            self.cache_k[:bsz, :, input_pos] = xk
            self.cache_v[:bsz, :, input_pos] = xv
            keys, values = (
                self.cache_k[:bsz, :, :],
                self.cache_v[:bsz, :, :],
            )
        else:
            keys, values = xk, xv

        # Repeat keys and values if necessary
        keys = keys.repeat_interleave(self.n_rep, dim=1)  # (bs, n_local_heads, cache_len + context_len, head_dim)
        values = values.repeat_interleave(self.n_rep, dim=1)  # (bs, n_local_heads, cache_len + context_len, head_dim)

        # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used,
        # since the masking is handled outside this attention module.
        # For cross-attention, it's always full-attn without causal mask
        is_causal = False
        output = scaled_dot_product_attention(
            xq,
            keys,
            values,
            head_dim=self.head_dim,
            mask=mask,
            is_causal=is_causal,
            dropout_p=0.0,
        )
        output = output.view(bsz, seqlen, -1)
        output = self.wo(output)
        if self.tp_size > 1:
            output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group())
        return output


def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    head_dim: int,
    mask: Optional[torch.Tensor] = None,
    is_causal: Optional[bool] = None,
    dropout_p: float = 0.0,
) -> torch.Tensor:
    """
    PyTorch's native implementation of Flash Attention 2.

    If `is_causal` is given, then the causal attention mask is applied accordingly:
    - If `is_causal` is True, the standard upper-left causal attention masking is applied.
    - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is
      provided (i.e., `mask is not None`).

    If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied
    based on the provided mask tensor:
    - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True,
    leading to the standard upper-left causal attention masking.
    - If an attention mask is given (i.e., `mask is not None`), the provided mask is used,
    and `is_causal` is set to False.

    Args:
        q (torch.Tensor): Query tensor
        k (torch.Tensor): Key tensor
        v (torch.Tensor): Value tensor
        head_dim (int): Dimension of each attention head
        mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
        is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None.
        dropout_p (float, optional): Dropout rate. Defaults to 0.0.

    Returns:
        torch.Tensor: Output tensor after applying scaled dot-product attention
    """
    scale = 1.0 / math.sqrt(head_dim)
    if is_causal is None:
        is_causal = mask is None
    y = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=mask,
        dropout_p=dropout_p,
        scale=scale,
        is_causal=is_causal,
    )
    return y.transpose(1, 2).contiguous()
