# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Optional

import pytest
import torch

from llmfoundry.models.layers.attention import (
    attn_bias_shape,
    build_attn_bias,
    check_alibi_support,
    flash_attn_fn,
    gen_slopes,
    is_flash_v2_installed,
    scaled_multihead_dot_product_attention,
)
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info


@pytest.mark.gpu
@pytest.mark.skipif(
    not is_flash_v2_installed(),
    reason='GQA natively only supported by Flash Attention after v2.',
)
@pytest.mark.parametrize('kv_n_heads', [1, 4, 8])
def test_gqa_kv_repetition(kv_n_heads: int):
    """Compare flash attention with and without explicit kv-head repetition.

    ``flash_attn_fn`` is called twice on identical inputs: once with
    ``should_repeat_kv_for_gqa=True`` and once with it set to ``False`` (i.e.
    leaving GQA to flash attention v2 itself). The outputs and the gradients
    with respect to query, key and value must agree.

    Args:
        kv_n_heads: Number of key/value heads; the query always uses 8 heads.
    """
    head_dim = 128
    num_heads = 8
    seq_len = 6
    batch_size = 2

    def _random_leaf(width: int) -> torch.Tensor:
        # Random bfloat16 CUDA tensor of shape (batch, seq, width) that
        # records gradients.
        leaf = torch.randn(batch_size, seq_len, width)
        leaf = leaf.to(torch.bfloat16).cuda()
        leaf.requires_grad = True
        return leaf

    def _detached_copy(source: torch.Tensor) -> torch.Tensor:
        # Independent copy with its own (empty) gradient slot.
        duplicate = source.detach().clone()
        duplicate.requires_grad = True
        return duplicate

    def _forward_backward(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        repeat_kv: bool,
    ) -> torch.Tensor:
        # One causal flash-attention pass followed by backward on the summed
        # output, so that the inputs' .grad fields are populated.
        padding_info = gen_flash_attn_padding_info(
            batch_size,
            seq_len,
            0,
            query.device,
            None,
            None,
        )
        attn_out, _, _ = flash_attn_fn(
            query=query,
            key=key,
            value=value,
            n_heads=num_heads,
            kv_n_heads=kv_n_heads,
            past_key_value=None,
            softmax_scale=1 / math.sqrt(head_dim),
            attn_bias=None,
            key_padding_mask=None,
            is_causal=True,
            dropout_p=0.0,
            training=False,
            needs_weights=False,
            flash_attn_padding_info=padding_info,
            should_repeat_kv_for_gqa=repeat_kv,
        )
        attn_out.sum().backward()
        return attn_out

    # Run 1: kv heads are repeated explicitly.
    query_a = _random_leaf(num_heads * head_dim)
    key_a = _random_leaf(kv_n_heads * head_dim)
    value_a = _random_leaf(kv_n_heads * head_dim)
    out_repeated = _forward_backward(query_a, key_a, value_a, repeat_kv=True)

    # Run 2: same data, kv heads are not repeated explicitly.
    query_b = _detached_copy(query_a)
    key_b = _detached_copy(key_a)
    value_b = _detached_copy(value_a)
    out_native = _forward_backward(query_b, key_b, value_b, repeat_kv=False)

    assert torch.allclose(out_repeated, out_native)
    grad_pairs = (
        (query_a, query_b),
        (key_a, key_b),
        (value_a, value_b),
    )
    for with_repeat, without_repeat in grad_pairs:
        assert torch.allclose(
            with_repeat.grad,  # type: ignore
            without_repeat.grad,  # type: ignore
        )


@pytest.mark.gpu
@pytest.mark.skipif(
    not is_flash_v2_installed(v2_version='v2.1.2'),
    reason=
    'Using sequence id with flash attention requires flash attention v2.1.2 or higher.',
)
def test_seq_id_masking_FA_v2():
    """Check sequence-id masking in flash attention against per-sequence runs.

    Every batch row packs three sequences of length 3, 2 and 1. The packed
    batch is run through ``flash_attn_fn`` with padding info built from those
    lengths. Each sequence is then sliced out and run on its own; the outputs
    and input gradients must match the corresponding slice of the packed run.
    """
    head_dim = 128
    num_heads = 4
    num_kv_heads = 4
    seq_len = 6
    batch_size = 2

    def _random_leaf(width: int) -> torch.Tensor:
        # Random bfloat16 CUDA tensor of shape (batch, seq, width) that
        # records gradients.
        leaf = torch.randn(batch_size, seq_len, width)
        leaf = leaf.to(torch.bfloat16).cuda()
        leaf.requires_grad = True
        return leaf

    def _attend(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        padding_info,
    ) -> torch.Tensor:
        # Causal flash attention; only the padding info differs between the
        # packed run and the per-sequence runs.
        attn_out, _, _ = flash_attn_fn(
            query=query,
            key=key,
            value=value,
            n_heads=num_heads,
            kv_n_heads=num_kv_heads,
            past_key_value=None,
            softmax_scale=1 / math.sqrt(head_dim),
            attn_bias=None,
            key_padding_mask=None,
            is_causal=True,
            dropout_p=0.0,
            training=False,
            needs_weights=False,
            flash_attn_padding_info=padding_info,
        )
        return attn_out

    packed_query = _random_leaf(num_heads * head_dim)
    packed_key = _random_leaf(num_kv_heads * head_dim)
    packed_value = _random_leaf(num_kv_heads * head_dim)

    # (start, stop) token offsets of the three sequences packed in each row.
    seq_ranges = [(0, 3), (3, 5), (5, 6)]
    # Per-row sequence lengths, zero-padded to the row length.
    packed_lengths = torch.tensor([
        [3, 2, 1, 0, 0, 0],
        [3, 2, 1, 0, 0, 0],
    ]).to(torch.int64).cuda()

    packed_padding_info = gen_flash_attn_padding_info(
        batch_size,
        seq_len,
        0,
        packed_query.device,
        packed_lengths,
        None,
    )
    packed_out = _attend(
        packed_query,
        packed_key,
        packed_value,
        packed_padding_info,
    )
    packed_out.sum().backward()

    for start, stop in seq_ranges:
        window = slice(start, stop)

        # Slice this sequence out of detached copies of the packed inputs.
        query_part = packed_query.detach().clone()[:, window, :]
        query_part.requires_grad = True
        key_part = packed_key.detach().clone()[:, window, :]
        key_part.requires_grad = True
        value_part = packed_value.detach().clone()[:, window, :]
        value_part.requires_grad = True

        part_padding_info = gen_flash_attn_padding_info(
            batch_size,
            stop - start,
            0,
            query_part.device,
            None,
            None,
        )
        part_out = _attend(
            query_part,
            key_part,
            value_part,
            part_padding_info,
        )
        part_out.sum().backward()

        assert torch.allclose(packed_out[:, window, :], part_out)
        grad_pairs = (
            (packed_query, query_part),
            (packed_key, key_part),
            (packed_value, value_part),
        )
        for packed, part in grad_pairs:
            assert torch.allclose(
                packed.grad[:, window, :],  # type: ignore
                part.grad,  # type: ignore
            )


@pytest.mark.gpu
@pytest.mark.skipif(
    not check_alibi_support('flash'),
    reason='ALiBi only supported by Flash Attention after v2.4.2.',
)
@pytest.mark.parametrize('n_heads', [1, 6, 8])
def test_alibi_bias(n_heads: int):
    """Check ALiBi in flash attention against the torch implementation.

    ``flash_attn_fn`` receives ALiBi slopes from ``gen_slopes``, while
    ``scaled_multihead_dot_product_attention`` receives an explicit bias built
    by ``build_attn_bias`` with ``alibi=True``. Both use ``alibi_bias_max=8``.
    Outputs and input gradients must be approximately equal.

    Args:
        n_heads: Number of attention heads (also used as the kv head count).
    """
    dtype = torch.bfloat16
    device = 'cuda'
    head_dim = 128
    seq_len = 8
    batch_size = 2
    alibi_bias_max = 8

    def _random_leaf() -> torch.Tensor:
        # Random (batch, seq, n_heads * head_dim) tensor that records grads.
        leaf = torch.randn(batch_size, seq_len, n_heads * head_dim)
        leaf = leaf.to(dtype=dtype, device=device)
        leaf.requires_grad = True
        return leaf

    def _detached_copy(source: torch.Tensor) -> torch.Tensor:
        # Independent copy with its own (empty) gradient slot.
        duplicate = source.detach().clone()
        duplicate.requires_grad = True
        return duplicate

    # --- Flash attention path: ALiBi is passed as 1D slopes. ---
    flash_query = _random_leaf()
    flash_key = _random_leaf()
    flash_value = _random_leaf()
    slopes = gen_slopes(
        n_heads=n_heads,
        alibi_bias_max=alibi_bias_max,
        device=torch.device(device),
        return_1d=True,
    )
    flash_padding_info = gen_flash_attn_padding_info(
        batch_size,
        seq_len,
        0,
        flash_query.device,
        None,
        None,
    )
    flash_out, _, _ = flash_attn_fn(
        query=flash_query,
        key=flash_key,
        value=flash_value,
        n_heads=n_heads,
        kv_n_heads=n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(head_dim),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        flash_attn_padding_info=flash_padding_info,
        should_repeat_kv_for_gqa=True,
        alibi_slopes=slopes,
    )
    flash_out.sum().backward()

    # --- Torch path: ALiBi is passed as an explicit attention bias. ---
    torch_query = _detached_copy(flash_query)
    torch_key = _detached_copy(flash_key)
    torch_value = _detached_copy(flash_value)

    is_causal = True
    bias_shape = attn_bias_shape(
        'torch',
        n_heads,
        seq_len,
        True,
        use_sequence_id=False,
        causal=is_causal,
    )
    torch_bias = build_attn_bias(
        'torch',
        torch.zeros(*bias_shape, device=device),
        n_heads,
        seq_len,
        causal=is_causal,
        alibi=True,
        alibi_bias_max=alibi_bias_max,
    )

    torch_out, _, _ = scaled_multihead_dot_product_attention(
        query=torch_query,
        key=torch_key,
        value=torch_value,
        n_heads=n_heads,
        kv_n_heads=n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(head_dim),
        attn_bias=torch_bias,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
    )
    torch_out.sum().backward()

    _assert_approx_equal(flash_out, torch_out)
    input_pairs = (
        (flash_query, torch_query),
        (flash_key, torch_key),
        (flash_value, torch_value),
    )
    for flash_in, torch_in in input_pairs:
        flash_grad, torch_grad = flash_in.grad, torch_in.grad
        assert (torch_grad is not None) and (flash_grad is not None)
        _assert_approx_equal(flash_grad, torch_grad)


@pytest.mark.gpu
@pytest.mark.skipif(
    not is_flash_v2_installed(v2_version='v2.6.2'),
    reason=
    'attn_logit_softcapping only supported by Flash Attention after v2.6.2.',
)
@pytest.mark.parametrize(
    'attn_logit_softcapping',
    [None, 0.1, 1.0, 10.0, 100.0],
)
def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]):
    """Check attention-logit softcapping: flash vs. torch implementation.

    The same inputs and the same ``attn_logit_softcapping`` value are passed
    to ``flash_attn_fn`` and ``scaled_multihead_dot_product_attention``.
    Outputs and input gradients must be approximately equal.

    Args:
        attn_logit_softcapping: Softcapping value forwarded to both attention
            functions; ``None`` is one of the parametrized cases.
    """
    dtype = torch.bfloat16
    device = 'cuda'
    head_dim = 128
    seq_len = 8
    batch_size = 2
    num_heads = 4

    def _random_leaf() -> torch.Tensor:
        # Random (batch, seq, num_heads * head_dim) tensor that records grads.
        leaf = torch.randn(batch_size, seq_len, num_heads * head_dim)
        leaf = leaf.to(dtype=dtype, device=device)
        leaf.requires_grad = True
        return leaf

    def _detached_copy(source: torch.Tensor) -> torch.Tensor:
        # Independent copy with its own (empty) gradient slot.
        duplicate = source.detach().clone()
        duplicate.requires_grad = True
        return duplicate

    # --- Flash attention path. ---
    flash_query = _random_leaf()
    flash_key = _random_leaf()
    flash_value = _random_leaf()
    flash_padding_info = gen_flash_attn_padding_info(
        batch_size,
        seq_len,
        0,
        flash_query.device,
        None,
        None,
    )
    flash_out, _, _ = flash_attn_fn(
        query=flash_query,
        key=flash_key,
        value=flash_value,
        n_heads=num_heads,
        kv_n_heads=num_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(head_dim),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        flash_attn_padding_info=flash_padding_info,
        should_repeat_kv_for_gqa=True,
        attn_logit_softcapping=attn_logit_softcapping,
    )
    flash_out.sum().backward()

    # --- Torch path on copies of the same data. ---
    torch_query = _detached_copy(flash_query)
    torch_key = _detached_copy(flash_key)
    torch_value = _detached_copy(flash_value)
    torch_out, _, _ = scaled_multihead_dot_product_attention(
        query=torch_query,
        key=torch_key,
        value=torch_value,
        n_heads=num_heads,
        kv_n_heads=num_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(head_dim),
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        attn_logit_softcapping=attn_logit_softcapping,
    )
    torch_out.sum().backward()

    _assert_approx_equal(flash_out, torch_out)
    input_pairs = (
        (flash_query, torch_query),
        (flash_key, torch_key),
        (flash_value, torch_value),
    )
    for flash_in, torch_in in input_pairs:
        flash_grad, torch_grad = flash_in.grad, torch_in.grad
        assert (torch_grad is not None) and (flash_grad is not None)
        _assert_approx_equal(flash_grad, torch_grad)


def _assert_approx_equal(
    value1: torch.Tensor,
    value2: torch.Tensor,
    atol: float = 1e-2,
    rtol: float = 1e-2,
):
    """Assert that two tensors are close in norm.

    The check passes when
    ``norm(value2 - value1) < atol + rtol * norm(value2)``, where ``norm`` is
    ``torch.norm`` with its default arguments.

    Args:
        value1: Tensor under test.
        value2: Reference tensor; its norm scales the relative tolerance.
        atol: Absolute tolerance added to the allowed difference.
        rtol: Relative tolerance, multiplied by the norm of ``value2``.

    Raises:
        AssertionError: If the shapes differ, or if the norm of the difference
            is not strictly below the allowed difference.
    """
    # Reject mismatched shapes up front: the subtraction below would otherwise
    # broadcast silently and could let a wrongly shaped tensor pass.
    assert value1.shape == value2.shape, (
        f'shape mismatch: {tuple(value1.shape)} vs {tuple(value2.shape)}'
    )

    # Do the arithmetic in at least float32 so that low-precision inputs
    # (e.g. bfloat16) do not add rounding error to the comparison itself.
    # Wider dtypes such as float64 are kept as they are.
    compute_dtype = torch.promote_types(
        torch.promote_types(value1.dtype, value2.dtype),
        torch.float32,
    )
    value1 = value1.to(compute_dtype)
    value2 = value2.to(compute_dtype)

    actual_difference = torch.norm(value2 - value1)
    allowed_difference = atol + rtol * torch.norm(value2)
    assert actual_difference < allowed_difference, f'{actual_difference=}, {allowed_difference=}'
