import torch
import torch.nn.functional as F

from typing import Tuple, Optional
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile the head axis (dim 2) of ``x`` ``n_rep`` times.

    The heads are repeated as a whole block with ``Tensor.repeat``, i.e.
    ``h1 h2 ... hn h1 h2 ... hn``. This is NOT the same ordering as
    ``torch.repeat_interleave(x, dim=2, repeats=n_rep)``, which would yield
    ``h1 h1 ... h2 h2 ...``.

    Args:
        x: tensor whose dim 2 is the head axis; callers in this file pass
            [bs, len, n_kv_head, head_dim].
        n_rep: number of copies of the head block to lay out along dim 2.

    Returns:
        ``x`` itself (no copy) when ``n_rep`` is 1, otherwise a new tensor
        whose dim 2 is ``n_rep`` times larger.
    """
    if n_rep != 1:
        x = x.repeat(1, 1, n_rep, 1)
    return x


def reshape_for_boardcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes on axis 1 and on the last axis and has size 1 on
    every other axis of ``x``.

    Args:
        freqs_cis: 2-D tensor of per-position values.
        x: tensor that the returned view has to broadcast with.

    Returns:
        A view of ``freqs_cis`` with the same number of dims as ``x``.

    Raises:
        AssertionError: if ``freqs_cis`` does not have the expected shape.
    """
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Only axis 1 and the trailing axis keep their extent.
    kept_axes = (1, x.ndim - 1)
    target_shape = [
        size if axis in kept_axes else 1
        for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*target_shape)


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the table of complex rotary factors.

    Args:
        dim: size used to derive the frequencies; ``dim // 2`` frequencies
            are produced.
        end: number of positions (rows) in the table.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, i]``
        is the unit-magnitude number with angle
        ``p / theta ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # Outer product: one angle per (position, frequency) pair.
    angles = torch.outer(positions, inv_freq).float()
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    Consecutive pairs along the last axis of ``xq`` / ``xk`` are viewed as
    complex numbers (computed in float32), multiplied by ``freqs_cis`` and
    converted back to real pairs.

    Args:
        xq: query tensor; its last axis must have even size.
        xk: key tensor; its last axis must have even size.
        freqs_cis: complex factors of shape
            ``(xq.shape[1], xq.shape[-1] // 2)``; the same reshaped factors
            are applied to both ``xq`` and ``xk``.

    Returns:
        ``(xq_out, xk_out)`` cast back to the dtypes of ``xq`` and ``xk``.
    """
    q_pairs = xq.float().reshape(*xq.shape[:-1], -1, 2)
    q_complex = torch.view_as_complex(q_pairs)
    k_pairs = xk.float().reshape(*xk.shape[:-1], -1, 2)
    k_complex = torch.view_as_complex(k_pairs)

    # The broadcast shape is derived from the query only and reused for keys.
    rotation = reshape_for_boardcast(freqs_cis, q_complex)

    # view_as_real appends a size-2 axis; flatten merges it back into the
    # last axis of the original layout.
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(q_complex.ndim - 1)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(k_complex.ndim - 1)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float = 0.0,
    causal: bool = True,
):
    """Compute attention with ``flash_attn_func``.

    Args:
        q (torch.Tensor): query tensor, expected as [bs, len, n_q_head, head_dim]
        k (torch.Tensor): key tensor,   expected as [bs, len, n_kv_head, head_dim]
        v (torch.Tensor): value tensor, expected as [bs, len, n_kv_head, head_dim]
        dropout (float): dropout rate, forwarded as ``dropout_p``
        causal (bool): whether to use causal attention or not

    Returns:
        Whatever ``flash_attn_func`` returns for these arguments.

    Raises:
        NotImplementedError: if Flash Attention 2 is not available.
    """
    flash_ready = is_flash_attn_2_available()
    if flash_ready is not True:
        raise NotImplementedError("Flash Attention 2 is not installed.")

    return flash_attn_func(q, k, v, dropout_p=dropout, causal=causal)


def _sdpa_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float = 0.0,
    causal: bool = True,
):
    """Compute attention with ``F.scaled_dot_product_attention``.

    Args:
        q (torch.Tensor): query tensor, expected as [bs, n_q_head, len, head_dim]
        k (torch.Tensor): key tensor,   expected as [bs, n_kv_head, len, head_dim]
        v (torch.Tensor): value tensor, expected as [bs, n_kv_head, len, head_dim]
        dropout (float): dropout rate, forwarded as ``dropout_p``
        causal (bool): forwarded as ``is_causal``

    Returns:
        The tensor returned by ``F.scaled_dot_product_attention``.
    """
    return F.scaled_dot_product_attention(
        q,
        k,
        v,
        dropout_p=dropout,
        is_causal=causal,
    )


def _vanilla_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float = 0.0,
    mask: Optional[torch.Tensor] = None,
):
    """Reference (eager) scaled dot-product attention.

    Args:
        q (torch.Tensor): query tensor, shape as [bs, n_head, q_len, head_dim]
        k (torch.Tensor): key tensor,   shape as [bs, n_head, kv_len, head_dim]
        v (torch.Tensor): value tensor, shape as [bs, n_head, kv_len, head_dim]
        dropout (float): dropout probability applied to the attention
            weights. It is applied whenever it is non-zero, the same way the
            ``dropout_p`` argument of the other backends is; pass 0.0 at
            evaluation time.
        mask (Optional[torch.Tensor]): additive mask broadcastable to the
            score matrix [bs, n_head, q_len, kv_len] (``-inf`` for blocked
            positions). ``None`` means no masking.

    Returns:
        torch.Tensor: attention output, [bs, n_head, q_len, head_dim].
    """
    scale = k.shape[-1] ** 0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    if mask is not None:
        # Out-of-place add so the mask may broadcast freely and no tensor is
        # mutated behind the caller's back.
        scores = scores + mask
    # Softmax in float32 for numerical stability, then back to q's dtype.
    probs = F.softmax(scores.float(), dim=-1).type_as(q)
    if dropout != 0.0:
        # Previously this argument was silently ignored on the vanilla path.
        probs = F.dropout(probs, p=dropout)
    return torch.matmul(probs, v)


def attention_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: float = 0.0,
    causal: bool = True,
    attn_impl: str = "vanilla",
    freqs_cis: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
):
    """Run attention with the selected backend.

    Key/value heads are expanded to the number of query heads with
    ``repeat_kv``, which tiles the head block as a whole. For 8 query heads
    and 4 kv heads that gives the "Matryoshka" pairing rather than the usual
    grouped-query pairing:

        Group Query Attention is:
        Q1 Q2 Q3 Q4 Q5 Q6 Q7 Q8
        K1 K1 K2 K2 K3 K3 K4 K4

        Matryoshka Attention is:
        Q1 Q2 Q3 Q4 Q5 Q6 Q7 Q8
        K1 K2 K3 K4 K1 K2 K3 K4

    Args:
        q: query tensor, [bsz, seq, n_q_head, head_dim].
        k: key tensor,   [bsz, seq, n_kv_head, head_dim].
        v: value tensor, [bsz, seq, n_kv_head, head_dim].
        dropout: dropout rate forwarded to the backend.
        causal: whether attention is causal.
        attn_impl: ``"flash_attn"``, ``"sdpa"``, or anything else for the
            vanilla implementation.
        freqs_cis: optional rotary factors; when given, rotary embedding is
            applied to ``q`` and ``k`` first.
        mask: optional additive mask. Only the vanilla backend uses it; the
            flash and sdpa backends rely on ``causal`` alone. When it is
            ``None`` and ``causal`` is true, the vanilla backend gets a
            causal mask built here.

    Returns:
        Attention output laid out as [bsz, seq, n_q_head, head_dim].

    Raises:
        ValueError: if the number of query heads is not a positive multiple
            of the number of kv heads.
    """
    if freqs_cis is not None:
        q, k = apply_rotary_emb(q, k, freqs_cis)

    n_q_head, n_kv_head = q.shape[2], k.shape[2]
    if n_kv_head == 0 or n_q_head == 0 or n_q_head % n_kv_head != 0:
        raise ValueError(
            f"number of query heads ({n_q_head}) must be a positive multiple "
            f"of the number of key/value heads ({n_kv_head})"
        )
    n_rep = n_q_head // n_kv_head
    k = repeat_kv(k, n_rep)
    v = repeat_kv(v, n_rep)

    if attn_impl == "flash_attn":
        # flash-attn consumes [bsz, seq, n_head, head_dim] directly.
        return _flash_attn_forward(q, k, v, dropout=dropout, causal=causal)

    # [bsz, seq, n_head, head_dim] -> [bsz, n_head, seq, head_dim]
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    if attn_impl == "sdpa":
        output = _sdpa_attn_forward(q, k, v, dropout=dropout, causal=causal)
    else:
        if mask is None and causal:
            # Additive causal mask built in place of the undefined
            # create_causal_mask: 0 on and below the diagonal, -inf above it.
            seq_len = q.shape[2]
            positions = torch.arange(seq_len, device=q.device)
            is_future = positions[None, :] > positions[:, None]
            mask = torch.zeros(
                seq_len, seq_len, dtype=q.dtype, device=q.device
            ).masked_fill(is_future, float("-inf"))
        output = _vanilla_attn_forward(q, k, v, dropout=dropout, mask=mask)

    # Back to [bsz, seq, n_head, head_dim] so every backend returns one layout.
    return output.transpose(1, 2).contiguous()
    
    




