# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from contextlib import nullcontext
from functools import partial
from typing import Any

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops.swiglu_op as xsw

# Forwarded as ``min_run_time`` to ``benchmark_main_helper`` for every benchmark
# in this file (presumably the minimum measurement time per timer, in seconds --
# confirm against the helper in ``utils``).
min_run_time = 0.5
# Device on which the benchmark inputs and SwiGLU modules are allocated.
device = torch.device("cuda")

# Each entry is (rows, in_features, hidden_features): the benchmark input is
# created with shape ``entry[:2]`` and the SwiGLU module is built with
# ``in_features=entry[1]`` and ``hidden_features=entry[2]``.
SHAPES = [
    # ViT-Giant: three row counts sharing the same feature sizes
    *((rows, 1536, 2736) for rows in (9456, 4440, 4728)),
    # A smaller hidden size as well
    (4728, 1536, 1024),
    # GPT-3 (small)
    (32768, 2048, 5632),
    # Chinchilla
    (32768, 8192, 22016),
]


# Operator under test: it is forwarded as ``op=`` to ``xsw.swiglu`` and its
# ``NAME`` attribute is used as the timer description. To benchmark a different
# implementation, swap the active assignment with one of the alternatives below.
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp


def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Every keyword maps a name to an iterable of candidate values. Each yielded
    dict maps every name to one chosen value; combinations come out in
    ``itertools.product`` order, i.e. the last keyword varies fastest.

    >>> list(product_dict(a=[1, 2], b=["x"]))
    [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]
    """
    names = tuple(kwargs)
    for combo in itertools.product(*kwargs.values()):
        yield {name: value for name, value in zip(names, combo)}


# Full benchmark grid: every shape crossed with every dtype mode, with and
# without bias. Each entry is a dict with the keys ``shape``, ``dtype`` and
# ``bias``, matching the parameters of the benchmark functions in this file.
CASES = [
    *product_dict(
        shape=SHAPES,
        dtype=(torch.bfloat16, torch.half, "autocast_half"),
        bias=(True, False),
    )
]

# Short tags placed at the start of each timer's ``sub_label``. All three are
# six characters wide (presumably so the result table columns line up).
# ``"autocast_half"`` is the string key for the mode where float32 tensors are
# run under float16 autocast.
DTYPE2STR = {
    torch.bfloat16: "b16   ",
    torch.half: "f16   ",
    "autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype, bias: bool):
    """Yield forward-pass timers for ``OP`` and for the eager SwiGLU op.

    Args:
        shape: Indexable with at least three entries; the input tensor is
            created with shape ``shape[:2]`` and the module uses
            ``in_features=shape[1]``, ``hidden_features=shape[2]``.
        dtype: A ``torch.dtype`` applied to both input and module, or the
            string ``"autocast_half"`` to keep everything in ``torch.float``
            and wrap the timed statement in a float16 CUDA autocast context.
        bias: Forwarded to ``xsw.SwiGLU`` and reflected in the sub-label.

    Yields:
        Two ``benchmark.Timer`` objects labelled ``"swiglu_fw"``: first the
        one for ``OP`` (described by ``OP.NAME``), then the ``"eager"`` one.
    """
    use_autocast = dtype == "autocast_half"
    # In autocast mode tensors stay in float32; the cast happens inside the op.
    tensor_dtype = torch.float if use_autocast else dtype

    # The input is drawn before the module is built, so the RNG stream is
    # consumed in a fixed order.
    x = torch.randn(shape[:2], device=device, dtype=tensor_dtype)
    module = xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
    module = module.to(device).to(tensor_dtype)

    bias_tag = "bias" if bias else "nobi"
    dtype_tag = DTYPE2STR.get(dtype, dtype)
    sub_label = f"{dtype_tag} B={shape[0]}, I={shape[1]}, H={shape[2]} {bias_tag}"

    params = module._ordered_params()

    stmt = "fn(x, *args)"
    if use_autocast:
        stmt = 'with torch.autocast("cuda", dtype=torch.half):\n    ' + stmt

    def make_timer(op, description):
        # Both timers share the input, parameters and statement; only the
        # operator bound into ``fn`` and the description differ.
        return benchmark.Timer(
            stmt=stmt,
            globals={"x": x, "args": params, "fn": partial(xsw.swiglu, op=op)},
            label="swiglu_fw",
            description=description,
            sub_label=sub_label,
        )

    yield make_timer(OP, OP.NAME)
    yield make_timer(xsw.SwiGLUEagerOp, "eager")


def benchmark_swiglu_bw(shape, dtype, bias: bool):
    """Yield backward-pass timers for ``OP`` and for the eager SwiGLU op.

    The forward pass is executed once per operator outside the timed region;
    each timer then measures only ``out.backward(grad, retain_graph=True)``.

    Args:
        shape: Indexable with at least three entries; the input tensor is
            created with shape ``shape[:2]`` and the module uses
            ``in_features=shape[1]``, ``hidden_features=shape[2]``.
        dtype: A ``torch.dtype`` applied to both input and module, or the
            string ``"autocast_half"`` to keep everything in ``torch.float``
            and run the forward pass under float16 CUDA autocast.
        bias: Forwarded to ``xsw.SwiGLU`` and reflected in the sub-label.

    Yields:
        Two ``benchmark.Timer`` objects labelled ``"swiglu_bw"``: first the
        one for ``OP`` (described by ``OP.NAME``), then the ``"eager"`` one.
    """
    if dtype == "autocast_half":
        inp_dtype, model_dtype = torch.float, torch.float
        # ``torch.cuda.amp.autocast`` is deprecated in favour of the
        # device-generic ``torch.autocast``; this also matches the
        # ``torch.autocast("cuda", ...)`` form used by the forward benchmark.
        cm: Any = partial(
            torch.autocast, device_type="cuda", dtype=torch.float16, enabled=True
        )
    else:
        inp_dtype, model_dtype = dtype, dtype
        cm = nullcontext

    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    x.requires_grad_()
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )

    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()

    def make_timer(out, grad, description):
        # ``retain_graph=True`` lets the timer call backward repeatedly on the
        # same forward graph.
        return benchmark.Timer(
            stmt="out.backward(grad, retain_graph=True)",
            globals={"out": out, "grad": grad},
            label="swiglu_bw",
            description=description,
            sub_label=sub_label,
        )

    with cm():
        out = xsw.swiglu(x, *params, op=OP)
    # The same upstream gradient is reused for both operators.
    grad = torch.zeros_like(out)

    yield make_timer(out, grad, OP.NAME)
    # Drop our reference to the first output before building the second graph.
    del out

    with cm():
        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)

    yield make_timer(out, grad, "eager")


# Script entry: run the forward benchmarks first, then the backward ones, over
# the same grid of cases and with the same minimum run time.
for _bench_fn in (benchmark_swiglu, benchmark_swiglu_bw):
    benchmark_main_helper(_bench_fn, CASES, min_run_time=min_run_time)
