# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
from torch.utils import benchmark

from xformers.components.attention.utils import iterative_pinv

MIN_RUN_TIME = 1
SHAPES = [[8, 8], [256, 1024], [128, 256]]
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]


def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]):
    min_run_time = MIN_RUN_TIME
    prob = 0.9
    device = torch.device("cuda")
    results = []

    for B, M, K in zip(*SHAPES):
        a = torch.rand(B, M, M, device=device)
        a[a < prob] = 0
        a = torch.softmax(a, dim=-1)

        results.extend(
            [
                benchmark.Timer(
                    stmt=f"{inverse_fn.__name__}(a)",
                    globals={
                        "a": a,
                        f"{inverse_fn.__name__}": inverse_fn,
                    },
                    label=f"{inverse_fn.__name__}",
                    sub_label="dense",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time),
            ]
        )
        for prob in SPARSITIES:
            a = torch.rand(B, M, M, device=device)
            a[a < prob] = 0
            a = a.to_sparse()
            results.append(
                benchmark.Timer(
                    stmt=f"{inverse_fn.__name__}(a)",
                    globals={
                        "a": a,
                        f"{inverse_fn.__name__}": inverse_fn,
                    },
                    label=f"{inverse_fn.__name__}",
                    sub_label=f"sparsity: {prob:0.2f}",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

    compare = benchmark.Compare(results)
    compare.print()


def iterative_pinv_analysis(
    identity_tolerance: float = 1e-1,
    pinv_tolerance: float = 5e-1,
    max_iters: int = 30,
    plot: bool = True,
):

    for i in range(1, 10):
        B, M = 1, 2**i
        a = torch.rand(B, M, M)
        a = torch.softmax(a, dim=-1)

        for n_iter in range(1, max_iters + 1):
            result = iterative_pinv(a, n_iter=n_iter)
            expected = torch.linalg.pinv(a)

            result_identity = torch.matmul(a, result)
            identity = torch.eye(M)

            # Default is frobenius norm.
            identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1))
            inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1))

            if (identity_error < identity_tolerance).all() or n_iter == max_iters:
                print(
                    f"Size {M}, n_iters {n_iter}: \n\t \
                    Final Error from Identity: {identity_error.item()} \n\t \
                    Final Error from linalg.pinv {inverse_error.item()}"
                )
                break


if __name__ == "__main__":
    iterative_pinv_analysis()
    bench_inverse(iterative_pinv)
    bench_inverse(torch.linalg.pinv)
