# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

# needed to register custom ops
import xformers  # noqa: F401
import xformers.components.attention.core
from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import (
    _broadcast_batch,
    _create_random_sparsity,
    _sparse_bmm,
)

# Decorator for tests that can only run on a GPU.
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
# Devices for device-parametrized tests: CPU always, CUDA only when available.
_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def _baseline_matmul_with_sparse_mask(
    a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    assert a.ndim == b.ndim
    assert mask.ndim == a.ndim
    assert a.shape[-1] == b.shape[-2]
    assert a.shape[-2] == mask.shape[-2], f"{a.shape}, {mask.shape}"
    assert b.shape[-1] == mask.shape[-1], f"{b.shape}, {mask.shape}"
    assert a.shape[:-2] == b.shape[:-2], f"{a.shape}, {b.shape}"
    assert a.shape[:-2] == mask.shape[:-2], f"{a.shape}, {mask.shape}"
    idxs = mask.indices().unbind()
    b = b.transpose(-2, -1)

    # compute matmul for elements within the mask
    val = (a[idxs[:-2] + (idxs[-2], slice(None))] * b[idxs[:-2] + (idxs[-1], slice(None))]).sum(-1)  # type: ignore

    out_shape = a.shape[:-1] + (b.shape[-2],)
    res = torch.sparse_coo_tensor(torch.stack(idxs), val, out_shape)
    return res


def _baseline_matmul_with_dense_mask(
    a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    res = a @ b
    res[~mask] = float("-inf")
    return res


def _baseline_sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # need to use torch.sparse.mm to get gradients wrt sparse matrix a
    # TODO implement this in C++ / CUDA as this is slow!
    out = []
    for ai, bi in zip(a, b):
        out.append(torch.sparse.mm(ai, bi))
    return torch.stack(out, dim=0)


@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask(device, contiguous, is_sparse):
    """Compare the ``matmul_with_mask`` custom op against a pure-PyTorch reference.

    Covers dense boolean masks and sparse COO masks, with both contiguous and
    transposed-layout (non-contiguous) operands.
    """
    B, L, K = 8, 30, 32
    prob = 0.5
    lhs = torch.rand(B, L, K, device=device)
    rhs = torch.rand(B, K, L, device=device)
    if not contiguous:
        # Same values, but the last two dims are swapped in memory.
        lhs, rhs = (
            t.transpose(-2, -1).contiguous().transpose(-2, -1) for t in (lhs, rhs)
        )
    mask = torch.rand(B, L, L, device=device) > prob

    if is_sparse:
        mask = mask.to_sparse()
        reference = _baseline_matmul_with_sparse_mask
    else:
        reference = _baseline_matmul_with_dense_mask

    out = torch.ops.xformers.matmul_with_mask(lhs, rhs, mask)
    expected = reference(lhs, rhs, mask)

    if is_sparse:
        out, expected = out.to_dense(), expected.to_dense()

    assert out.dtype == expected.dtype
    assert torch.allclose(out, expected)


@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask_backward(device, contiguous, is_sparse):
    """Gradients of ``matmul_with_mask`` w.r.t. ``a`` and ``b`` must match the reference."""
    if device == "cuda" and not is_sparse:
        # Known bug in torch 1.8 with CUDA, see
        # https://github.com/pytorch/pytorch/issues/54975
        # Use pytest.skip so the case shows up as skipped instead of being
        # silently counted as a pass (which a bare ``return`` would do).
        pytest.skip("Broken CUDA / torch 1.8 combination (pytorch/pytorch#54975)")

    B, L, K = 8, 10, 16
    prob = 0.5
    a = torch.rand(B, L, K, device=device, requires_grad=True)
    b = torch.rand(B, K, L, device=device, requires_grad=True)
    if not contiguous:
        # Rebuild as leaf tensors whose last two dims are swapped in memory.
        a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
        b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
    mask = torch.rand(B, L, L, device=device) > prob

    fn = torch.ops.xformers.matmul_with_mask
    fn_gt = _baseline_matmul_with_dense_mask

    if is_sparse:
        mask = mask.to_sparse()
        fn_gt = _baseline_matmul_with_sparse_mask

    def compute_grads(f):
        # Backprop a plain sum so every unmasked output contributes equally.
        out = f(a, b, mask)
        if is_sparse:
            out = out.to_dense()
        out.sum().backward()

    compute_grads(fn)
    grad_a = a.grad.clone()
    grad_b = b.grad.clone()
    # Reset so the reference gradients are not accumulated on top.
    a.grad = None
    b.grad = None
    compute_grads(fn_gt)
    assert torch.allclose(grad_a, a.grad)
    assert torch.allclose(grad_b, b.grad)


@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik(device):
    """Masked matmul with a CSR ``SparseCS`` mask must match the COO-mask path."""
    B, L, M, K = 8, 30, 16, 32
    prob = 0.5
    lhs = torch.rand(B, L, K, device=device)
    # Shape (B, K, M): a transposed, non-contiguous view of a (B, M, K) buffer.
    rhs = torch.rand(B, M, K, device=device).transpose(-2, -1)
    dense_mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )

    csr_mask = xformers.components.attention.core.SparseCS(dense_mask, device)
    coo_mask = dense_mask.to_sparse()

    masked_matmul = xformers.components.attention.core._matmul_with_mask
    out_csr = masked_matmul(lhs, rhs, csr_mask).to_dense()
    out_coo = masked_matmul(lhs, rhs, coo_mask).to_dense()

    assert out_csr.dtype == out_coo.dtype
    assert torch.allclose(out_csr, out_coo)


@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_csr(L, M, K, prob):
    """``csr_sddmm`` must agree with the Sputnik SDDMM kernel on the same CSR mask."""
    # TODO add more checks for different nnz
    device = torch.device("cuda")
    B = 8
    lhs = torch.rand(B, L, K, device=device)
    rhs = torch.rand(B, M, K, device=device)
    dense_mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )

    sparse_mask = xformers.components.attention.core.SparseCS(dense_mask, device)
    # Both kernels take the same CSR description of the mask.
    csr_args = (
        sparse_mask.row_indices,
        sparse_mask.row_offsets,
        sparse_mask.column_indices,
    )

    out = torch.ops.xformers.csr_sddmm(lhs, rhs, *csr_args)
    expected = torch.ops.xformers.sddmm_sputnik(lhs, rhs, *csr_args)

    assert out.dtype == expected.dtype
    assert torch.allclose(out, expected, atol=1e-6)


@cuda_only
@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
def test_sddmm_csr_per_nnz(nnz):
    """Compare ``csr_sddmm`` with Sputnik SDDMM on masks with very few non-zeros.

    The mask is a single (L, M) boolean pattern. It holds the first
    ``nnz - 1`` elements in row-major order plus the bottom-right corner, so
    there are ``nnz`` non-zeros in total. The corner is always set, so
    ``nnz=0`` degenerates to a single non-zero.
    """
    device = torch.device("cuda")
    B = 8
    L, M, K = 1024, 1024, 32
    a = torch.rand(B, L, K, device=device)
    b = torch.rand(B, M, K, device=device)
    mask = torch.zeros(L, M, dtype=torch.bool, device=device)
    # Clamp at zero. With nnz=0, ``[: nnz - 1]`` would be ``[:-1]`` and mark
    # every element except the last, making the sparsest case fully dense.
    mask.view(-1)[: max(nnz - 1, 0)] = True
    mask[-1, -1] = True
    assert mask.sum().item() == max(nnz, 1)

    mask_csr = xformers.components.attention.core.SparseCS(mask, device)
    row_indices = mask_csr.row_indices
    row_offsets = mask_csr.row_offsets
    column_indices = mask_csr.column_indices

    fn = torch.ops.xformers.csr_sddmm
    fn_gt = torch.ops.xformers.sddmm_sputnik

    res = fn(a, b, row_indices, row_offsets, column_indices)
    res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices)

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt, atol=1e-6)


@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_coo(L, M, K, prob):
    """``coo_sddmm`` (COO row indices) must agree with Sputnik SDDMM (CSR row offsets)."""
    # TODO add more checks for different nnz
    device = torch.device("cuda")
    B = 8
    lhs = torch.rand(B, L, K, device=device)
    rhs = torch.rand(B, M, K, device=device)
    dense_mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )

    sparse_mask = xformers.components.attention.core.SparseCS(dense_mask, device)
    row_indices = sparse_mask.row_indices
    row_offsets = sparse_mask.row_offsets
    column_indices = sparse_mask.column_indices

    # Convert the CSR row offsets to COO row indices (first output of _csr_to_coo).
    coo_rows = _csr_to_coo(L, M, row_offsets, column_indices)[0]

    out = torch.ops.xformers.coo_sddmm(
        lhs, rhs, row_indices, coo_rows, column_indices
    )
    expected = torch.ops.xformers.sddmm_sputnik(
        lhs, rhs, row_indices, row_offsets, column_indices
    )

    assert out.dtype == expected.dtype
    assert torch.allclose(out, expected, atol=1e-6)


@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik_backward(device):
    """Gradients of masked matmul must agree between CSR and COO masks.

    Only one layout is exercised: a contiguous ``a`` and a transposed,
    non-contiguous ``b``.
    """
    B, L, M, K = 8, 10, 16, 32
    prob = 0.5
    a = torch.rand(B, L, K, device=device, requires_grad=True)
    b = torch.rand(B, M, K, device=device).transpose(-2, -1).requires_grad_(True)
    mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )

    mask_csr = xformers.components.attention.core.SparseCS(mask, device)

    fn = xformers.components.attention.core._matmul_with_mask

    mask = mask.to_sparse()

    # CSR path: backprop the sum of the stored values.
    out_csr = fn(a, b, mask_csr)
    out_csr.values.sum().backward()
    grad_a = a.grad.clone()
    grad_b = b.grad.clone()
    # Reset so the COO-path gradients are not accumulated on top.
    a.grad = None
    b.grad = None
    # TODO check why this fails:
    # fn(a[None], b[None], mask).coalesce().values().sum().backward()
    fn(a, b, mask).to_dense().sum().backward()
    assert torch.allclose(grad_a, a.grad, atol=1e-7)
    assert torch.allclose(grad_b, b.grad, atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik(device):
    """Sparse softmax over a CSR ``SparseCS`` must match the COO-tensor path."""
    B, L = 8, 30
    prob = 0.5
    dense = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)

    as_csr = xformers.components.attention.core.SparseCS(dense, device)
    as_coo = dense.to_sparse()

    softmax = xformers.components.attention.core._softmax
    out_csr = softmax(as_csr).to_dense()
    out_coo = softmax(as_coo).to_dense()

    assert out_csr.dtype == out_coo.dtype
    assert torch.allclose(out_csr, out_coo)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik_backward(device):
    """Gradients of sparse softmax must match between CSR and COO inputs."""
    B, L = 8, 30
    prob = 0.5
    dense = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)

    as_csr = xformers.components.attention.core.SparseCS(dense, device)
    as_coo = dense.to_sparse()
    softmax = xformers.components.attention.core._softmax

    # CSR path: the gradient lands on the SparseCS values tensor.
    as_csr.values.requires_grad_(True)
    softmax(as_csr).values.sum().backward()
    csr_grad = as_csr.values.grad.clone()

    # COO path: same reduction, with the gradient read back from the COO values.
    as_coo.requires_grad_(True)
    softmax(as_coo).coalesce().values().sum().backward()
    coo_grad = as_coo.grad.coalesce().values().reshape_as(csr_grad)

    assert torch.allclose(csr_grad, coo_grad, atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik(device):
    """``bmm`` with a CSR ``SparseCS`` left operand must match the COO path.

    Both results are dense tensors, so they are compared directly.
    """
    B, L, K = 8, 30, 32
    prob = 0.5

    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)

    b = torch.rand(B, L, K, device=device)

    a_csr = xformers.components.attention.core.SparseCS(a, device)

    fn = xformers.components.attention.core.bmm

    a = a.to_sparse()

    res = fn(a_csr, b)
    res_gt = fn(a, b)

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)


@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik_backward(device):
    """Gradients of ``bmm`` with a CSR left operand must match the COO path.

    Checks the gradient w.r.t. the sparse values and w.r.t. the dense ``b``.
    """
    B, M, L, K = 8, 16, 30, 32
    prob = 0.5

    dense_a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
    b = torch.rand(B, L, K, device=device).requires_grad_(True)

    a_csr = xformers.components.attention.core.SparseCS(dense_a, device)
    a_coo = dense_a.to_sparse().requires_grad_(True)
    a_csr.values.requires_grad_(True)

    bmm = xformers.components.attention.core.bmm

    # CSR path: gradients accumulate on the SparseCS values and on b.
    bmm(a_csr, b).sum().backward()
    csr_grad_a = a_csr.values.grad.clone()
    csr_grad_b = b.grad.clone()

    # COO path: clear b's gradient first so it is not accumulated twice.
    b.grad = None
    bmm(a_coo, b).sum().backward()

    assert torch.allclose(
        csr_grad_a, a_coo.grad.coalesce().values().reshape_as(csr_grad_a), atol=1e-7
    )
    assert torch.allclose(csr_grad_b, b.grad, atol=1e-7)


@cuda_only
def test_csr_transpose():
    """Transposing a ``SparseCS`` once swaps the last two dims; twice restores it."""
    B, L, K = 8, 30, 40
    prob = 0.5
    device = torch.device("cuda")

    dense = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
    csr = xformers.components.attention.core.SparseCS(dense, device)

    transposed = csr.transpose()
    round_trip = transposed.transpose()

    assert torch.allclose(transposed.to_dense(), dense.transpose(-2, -1))
    assert torch.allclose(round_trip.to_dense(), dense)


@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("prob", [0.95, 0.996])  # cover > 0.995
@pytest.mark.parametrize("N", [32, 64, 96])  # cover > 64
def test_sparse_bmm(device, contiguous, prob, N):
    """``_sparse_bmm`` must match the per-batch ``torch.sparse.mm`` reference.

    Entries of the left operand below ``prob`` are zeroed, so ``prob`` roughly
    sets its sparsity level.
    """
    B, M = 8, 64
    dense = torch.rand(B, M, N, device=device)
    dense[dense < prob] = 0
    lhs = dense.to_sparse()
    rhs = torch.rand(B, N, M, device=device)
    if not contiguous:
        # presumably leaves ``lhs`` uncoalesced (duplicate indices) — TODO confirm
        lhs = lhs + lhs
        rhs = rhs.transpose(-2, -1).contiguous().transpose(-2, -1)

    assert torch.allclose(_sparse_bmm(lhs, rhs), _baseline_sparse_bmm(lhs, rhs))


@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_sparse_bmm_backward(device, contiguous):
    """Gradients of ``_sparse_bmm`` must match the ``torch.sparse.mm`` reference."""
    if device == "cuda":
        # Known bug in torch 1.8 with CUDA, see
        # https://github.com/pytorch/pytorch/issues/54975
        # Use pytest.skip so the case shows up as skipped instead of being
        # silently counted as a pass (which a bare ``return`` would do).
        pytest.skip("Broken CUDA / torch 1.8 combination (pytorch/pytorch#54975)")

    B, L, K = 8, 10, 16
    prob = 0.5
    a = torch.rand(B, L, K, device=device)
    a[a < prob] = 0
    a = a.to_sparse()
    b = torch.rand(B, K, L, device=device, requires_grad=True)
    if not contiguous:
        a = a + a
        b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
    a.requires_grad_(True)

    def compute_grads(f):
        out = f(a, b)
        out.sum().backward()

    compute_grads(_sparse_bmm)
    grad_a = a.grad.clone().coalesce()
    grad_b = b.grad.clone()
    # Reset so the reference gradients are not accumulated on top.
    a.grad = None
    b.grad = None
    compute_grads(_baseline_sparse_bmm)
    new_grad_a = a.grad.coalesce()
    # Indices are integers: compare them exactly. Unlike allclose, torch.equal
    # also returns False (instead of raising) on a shape mismatch.
    assert torch.equal(grad_a.indices(), new_grad_a.indices())
    assert torch.allclose(grad_a.values(), new_grad_a.values())
    assert torch.allclose(grad_b, b.grad)


@pytest.mark.parametrize("device", _devices)
def test_sparse_coo_broadcast(device):
    """``_broadcast_batch`` must replicate a 2D sparse matrix along a new batch dim."""
    B, L, K = 8, 10, 16
    prob = 0.5
    dense = torch.rand(L, K, device=device)
    dense[dense < prob] = 0

    batched = _broadcast_batch(dense.to_sparse(), B)
    expected = dense.unsqueeze(0).expand(B, L, K)

    assert torch.allclose(batched.to_dense(), expected)
