# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.reversible import ReversibleSequence

# (B, K) pairs swept by bench_revnet: B is reported as "Batch" and K as
# "Features" in the result keys, and K is the width of the Linear layers.
SHAPES = [
    (16384, 32),
    (2048, 256),
    (128, 4096),
]

# Number of times the (f, g) pair is applied in each benchmarked stack.
DEPTH = [
    4,
    32,
    256,
]


def bench_revnet(backward: bool) -> None:
    """Benchmark a plain residual stack against ``ReversibleSequence``.

    For each dtype (float16, float32), each ``(B, K)`` pair in ``SHAPES`` and
    each depth in ``DEPTH``, two steps are timed with
    ``triton.testing.do_bench``:

    * "residual": ``depth`` repetitions of ``y = y + f(y); y = y + g(y)``
    * "reversible": a single call of a ``ReversibleSequence`` built from
      ``depth`` ``(f, g)`` pairs

    ``f`` and ``g`` are ``torch.nn.Linear(K, K)`` layers created on the CUDA
    device. One table and one plot are emitted per dtype through
    ``pretty_print`` and ``pretty_plot``.

    Args:
        backward: when True, the inputs are created with
            ``requires_grad=True`` and every timed step also back-propagates
            through ``torch.norm`` of its output.
    """
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, K in SHAPES:
            for depth in DEPTH:
                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                # List repetition reuses the very same (f, g) pair at every
                # level, mirroring normal_step which applies the same f and g
                # `depth` times.
                revseq = ReversibleSequence(
                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
                )
                revseq = revseq.to(device=device, dtype=dtype)

                a = torch.rand(
                    1, B, K, device=device, dtype=dtype, requires_grad=backward
                )
                # NOTE(review): the reversible input is twice as wide as the
                # residual one; presumably ReversibleSequence splits the last
                # dimension into two K-wide halves - confirm.
                b = torch.rand(
                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
                )

                def normal_step():
                    y = a
                    for _ in range(depth):
                        y = y + f(y)
                        y = y + g(y)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def reversible_step():
                    y = revseq(b)
                    if backward:
                        torch.norm(y).backward()
                    return y

                # The key only depends on the sweep parameters, so build it
                # once for both test cases.
                key = f"Batch={B}, Features={K}, Depth={depth}"

                for testcase in [
                    TestCase(normal_step, f"residual - fw{bw}"),
                    TestCase(reversible_step, f"reversible - fw{bw}"),
                ]:
                    measurement = triton.testing.do_bench(testcase.function)
                    # Depending on the Triton release, do_bench returns either
                    # a single number or a sequence of statistics. Keep the
                    # first entry of a sequence (the historical behaviour) and
                    # accept a bare number as-is instead of failing on `[0]`.
                    if isinstance(measurement, (tuple, list)):
                        runtime = measurement[0]
                    else:
                        runtime = measurement

                    results.setdefault(key, {})[testcase.name] = f"{runtime:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        # NOTE(review): neither series name ("residual ...", "reversible ...")
        # contains "torch"; confirm this dash_key is the intended one.
        pretty_plot(
            results,
            title=f"RevNet-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )


# Run the forward-only benchmark first, then the forward + backward one.
for bw in (False, True):
    bench_revnet(backward=bw)
