# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import random

import torch
from torch.utils import benchmark

from fairseq.modules.multihead_attention import MultiheadAttention

# Candidate batch sizes. NOTE(review): not referenced elsewhere in this file;
# benchmark_multihead_attention defaults to batch_size=20 instead.
BATCH = [20, 41, 97]
# Default sequence length, embedding dimension and number of heads used by
# benchmark_multihead_attention.
SEQ = 64
EMB = 48
HEADS = 4
# NOTE(review): not referenced elsewhere in this file; the benchmarked
# MultiheadAttention modules are constructed with dropout=0.0.
DROP = 0.1
# NOTE(review): not referenced elsewhere in this file.
DEVICE = torch.device("cuda")
# Mask dtypes swept by run_benchmarks (every combination is benchmarked).
ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool]


def _reset_seeds():
    torch.manual_seed(0)
    random.seed(0)


def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)


def benchmark_multihead_attention(
    label="",
    attn_dtype=torch.uint8,
    key_padding_dtype=torch.uint8,
    add_bias_kv=False,
    add_zero_attn=False,
    static_kv=False,
    batch_size=20,
    embedding=EMB,
    seq_len=SEQ,
    num_heads=HEADS,
    device=None,
):
    """Time fairseq's MultiheadAttention with and without the xformers backend.

    Builds two ``MultiheadAttention`` modules with the same hyperparameters:
    one with ``xformers_att_config=None`` and one configured with the xformers
    ``scaled_dot_product`` attention. It then times a forward pass and a
    forward + backward pass of each with ``torch.utils.benchmark`` and prints
    the results as a ``benchmark.Compare`` table.

    Args:
        label: Description shown for this configuration in the table.
        attn_dtype: dtype of the random ``(seq_len, seq_len)`` attention mask.
        key_padding_dtype: dtype of the random ``(batch_size, seq_len)``
            key-padding mask.
        add_bias_kv: Forwarded to both ``MultiheadAttention`` constructors.
        add_zero_attn: Forwarded to both ``MultiheadAttention`` constructors.
        static_kv: Forwarded to every forward call.
        batch_size: Batch dimension of the query/key/value inputs.
        embedding: Embedding dimension of the inputs and of the modules.
        seq_len: Sequence length of the inputs.
        num_heads: Number of attention heads.
        device: Optional device to run on. ``None`` (the default) leaves the
            inputs, masks and modules on the device they are created on.
    """
    # Seed before drawing anything so the masks and inputs are reproducible
    # from run to run (previously the seed was reset only after they were drawn).
    _reset_seeds()

    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len)
    key_padding_mask = _get_mask(
        to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len
    )
    if device is not None:
        attn_mask = attn_mask.to(device)
        key_padding_mask = key_padding_mask.to(device)

    # Create the inputs directly on the target device so they stay leaf tensors.
    q = torch.rand(seq_len, batch_size, embedding, device=device, requires_grad=True)
    k = torch.rand(seq_len, batch_size, embedding, device=device, requires_grad=True)
    v = torch.rand(seq_len, batch_size, embedding, device=device, requires_grad=True)

    def _build_mha(att_config):
        # Reseed so both modules are constructed from the same RNG state,
        # rather than the second inheriting whatever the first consumed.
        _reset_seeds()
        mha = MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            xformers_att_config=att_config,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        return mha if device is None else mha.to(device)

    original_mha = _build_mha(None)
    xformers_mha = _build_mha(xformers_att_config)

    def _forward(mha, q, k, v, key_padding_mask, attn_mask, static_kv):
        # One attention call; the module returns a 2-tuple whose first
        # element is the attention output.
        output, _ = mha(
            query=q,
            key=k,
            value=v,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            static_kv=static_kv,
        )
        return output

    def _forward_backward(mha, *args):
        # Reduce the output to a scalar so backward() needs no grad argument.
        loss = torch.norm(_forward(mha, *args))
        loss.backward()

    # These function names are reused as the sub-labels in the comparison table.
    def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        _forward(original_mha, q, k, v, key_padding_mask, attn_mask, static_kv)

    def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv):
        _forward(xformers_mha, q, k, v, key_padding_mask, attn_mask, static_kv)

    def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        _forward_backward(
            original_mha, q, k, v, key_padding_mask, attn_mask, static_kv
        )

    def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv):
        _forward_backward(
            xformers_mha, q, k, v, key_padding_mask, attn_mask, static_kv
        )

    fns = [
        original_bench_fw,
        xformers_bench_fw,
        original_bench_fw_bw,
        xformers_bench_fw_bw,
    ]

    results = []
    for fn in fns:
        timer = benchmark.Timer(
            stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "key_padding_mask": key_padding_mask,
                "attn_mask": attn_mask,
                "static_kv": static_kv,
                "fn": fn,
            },
            label="multihead fw + bw",
            sub_label=fn.__name__,
            description=label,
        )
        results.append(timer.blocked_autorange(min_run_time=1))

    compare = benchmark.Compare(results)
    compare.print()


def run_benchmarks():
    for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product(
        ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False]
    ):
        label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \
            add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}"
        benchmark_multihead_attention(
            label=label,
            attn_dtype=attn_dtype,
            key_padding_dtype=key_padding_dtype,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )


# Run the sweep only when executed as a script, not on import.
if __name__ == "__main__":
    run_benchmarks()
