# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

# needed to register custom ops
import xformers  # noqa: F401
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor, SparseCSRTensor

# Marks a test that can only run when a CUDA device is available.
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
# Devices the parametrized tests run on: always CPU, plus the first GPU if present.
_devices = ["cpu"] + (["cuda:0"] if torch.cuda.is_available() else [])
# Sparse tensor implementations covered by the parametrized tests.
_tensor_types = [BlockSparseTensor, SparseCSRTensor]


def _create_blocksparse_tensor(
    device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
):
    """Build a random BlockSparseTensor.

    A random 0/1 block layout of shape ``(C, H // block_size, W // block_size)``
    is drawn. Then, in every layout slice, the first block-column and the first
    block-row are forced on, so that no block-row and no block-column is empty.

    Block values are drawn with ``torch.randn`` (default dtype) and then cast
    to ``dtype``. They have shape ``(Z, n_blocks, block_size, block_size)``,
    where ``n_blocks`` is the number of active blocks in the layout.

    Args:
        device: device to allocate on.
        block_size: edge length of each square block.
        Z, C, H, W: sizes used to build the values and the layout (see above).
        dtype: dtype the block values are cast to.

    Returns:
        ``BlockSparseTensor(values, layout)``.
    """
    n_block_rows = H // block_size
    n_block_cols = W // block_size
    layout = torch.randint(2, (C, n_block_rows, n_block_cols), device=device)
    # Every block-row gets its first block, every block-column its first block.
    layout[:, :, 0] = 1
    layout[:, 0, :] = 1

    n_blocks = layout.sum()
    raw_values = torch.randn(Z, n_blocks, block_size, block_size, device=device)
    return BlockSparseTensor(raw_values.to(dtype), layout)


def _create_csr_tensor(device, dtype, shape, sparsity, divisible_by=4):
    """Build a random SparseCSRTensor from a 3-D dense tensor.

    Dense values come from ``torch.rand`` (float32) and are then cast to
    ``dtype``. A single keep-pattern, drawn once with density about
    ``1 - sparsity``, is applied to every entry of the leading (batch) dim.
    The kept coordinates are truncated so their count is a multiple of
    ``divisible_by``.

    Args:
        device: device to allocate on.
        dtype: dtype of the dense values before sparsification.
        shape: 3-D dense shape; enforced by an ``assert``.
        sparsity: probability threshold; a position is kept when a uniform
            draw exceeds it.
        divisible_by: the number of kept coordinates is rounded down to a
            multiple of this value.

    Returns:
        ``SparseCSRTensor.from_dense`` of the masked dense tensor.
    """
    dense = torch.rand(shape, dtype=torch.float32, device=device).to(dtype)
    assert dense.ndim == 3

    # One 2-D pattern shared across the batch dimension.
    keep = torch.rand_like(dense[0], dtype=torch.float32) > sparsity
    coords = torch.nonzero(keep)

    # NOTE: nnz must be a multiple of 4 for sputnik (the default divisible_by).
    total = coords.shape[0]
    rows, cols = coords[: total - total % divisible_by].unbind(1)

    batch = torch.arange(dense.shape[0], device=dense.device).unsqueeze(1)
    masked = torch.zeros_like(dense)
    masked[batch, rows, cols] = dense[batch, rows, cols]
    return SparseCSRTensor.from_dense(masked)


def _create_tensor(tensor_type, device, dtype, shape, sparsity):
    """Create a random sparse tensor of the requested type.

    Args:
        tensor_type: ``BlockSparseTensor`` or ``SparseCSRTensor``.
        device: device to allocate on.
        dtype: dtype of the stored values.
        shape: dense shape. For ``BlockSparseTensor`` this must be the 4-tuple
            ``(Z, C, H, W)``, with ``H`` and ``W`` divisible by the block
            size (16). For ``SparseCSRTensor`` it must be 3-D.
        sparsity: only used for ``SparseCSRTensor``. ``BlockSparseTensor``
            draws a random 0/1 block layout instead.

    Returns:
        The newly created sparse tensor.

    Raises:
        ValueError: if ``tensor_type`` is not one of the supported types.
    """
    if tensor_type is BlockSparseTensor:
        block_size = 16
        # Forward the requested shape instead of silently relying on the
        # helper's defaults matching what the caller asked for.
        Z, C, H, W = shape
        return _create_blocksparse_tensor(
            device=device, dtype=dtype, block_size=block_size, Z=Z, C=C, H=H, W=W
        )
    if tensor_type is SparseCSRTensor:
        return _create_csr_tensor(
            device=device, dtype=dtype, shape=shape, sparsity=sparsity
        )
    # Previously fell through and returned None, failing later and obscurely.
    raise ValueError(f"Unsupported sparse tensor type: {tensor_type}")


def _seed():
    torch.random.manual_seed(42)
    torch.cuda.manual_seed_all(42)


def _get_dtype_atol(tensor_type, device: str):
    """Seed the RNGs, configure TF32, and return ``(dtype, atol)`` for a test.

    Side effects, which are not restored afterwards:
    - the RNGs are reseeded via ``_seed``;
    - the global ``torch.backends`` TF32 flags are overwritten.

    TF32 is enabled only for ``BlockSparseTensor`` on a CUDA device. The
    upstream GPU blocksparse (Triton op) uses TF32 by default for all internal
    computations, so the dense reference is allowed to use it as well, and the
    tolerance is 1e-1.

    Everywhere else TF32 is disabled, so PyTorch keeps its computations in
    float32; it would otherwise default to TF32 with recent CUDA on Ampere+
    GPUs. The tolerance is then 1e-5.
    """
    _seed()

    use_tf32 = tensor_type == BlockSparseTensor and "cuda" in device
    # TF32 has the precision of fp16 but the range of fp32.
    # See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
    torch.backends.cuda.matmul.allow_tf32 = use_tf32
    torch.backends.cudnn.allow_tf32 = use_tf32  # type: ignore

    atol = 1e-1 if use_tf32 else 1e-5
    return torch.float32, atol


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
    """Check elementwise ``func`` on a SparseCSRTensor against the dense op.

    Both operands are the same tensor, so they share a sparsity pattern.
    """
    # TODO: add for BlockSparseTensor as well
    # The other tests seed through _get_dtype_atol; seed here too so the random
    # inputs, and therefore any failure, are reproducible.
    _seed()

    N, H, W = 8, 64, 64
    sparsity = 0.5
    shape = (N, H, W)

    a_sparse = _create_tensor(
        SparseCSRTensor, device, dtype=torch.float32, shape=shape, sparsity=sparsity
    )
    a = a_sparse.to_dense()

    # The tensor is combined with itself (see docstring).
    b = a
    b_sparse = a_sparse

    res = func(a_sparse, b_sparse).to_dense()
    res_gt = func(a, b)

    assert torch.allclose(res, res_gt), f"{torch.max(torch.abs(res - res_gt))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_masked_matmul(tensor_type, device):
    """Compare masked_matmul under a sparse mask with the dense-mask reference.

    Both the forward output and the gradients w.r.t. the two dense operands
    must match within ``atol``.
    """
    N, C, H, W, L = 8, 2, 64, 64, 32
    sparsity = 0.7
    dtype, atol = _get_dtype_atol(tensor_type, device)

    shapes = [(N, C, H, W), (N, C, H, L), (N, C, W, L)]
    if tensor_type != BlockSparseTensor:
        # SparseCSRTensor is exercised with 3-D inputs: drop the leading N dim.
        shapes = [s[1:] for s in shapes]
    mask_shape, lhs_shape, rhs_shape = shapes

    mask_sparse = _create_tensor(
        tensor_type, device, dtype=torch.bool, shape=mask_shape, sparsity=sparsity
    )
    mask = mask_sparse.to_dense()

    # Independent leaf copies of each operand: the *_ref pair feeds the dense
    # reference and the other pair feeds the sparse path, so that their
    # gradients can be compared.
    lhs_ref = torch.randn(lhs_shape, device=device, dtype=dtype)
    rhs_ref = torch.randn(rhs_shape, device=device, dtype=dtype)
    lhs = lhs_ref.clone()
    rhs = rhs_ref.clone()
    for leaf in (lhs_ref, rhs_ref, lhs, rhs):
        leaf.requires_grad_(True)

    res_gt = masked_matmul(lhs_ref, rhs_ref.transpose(-2, -1), mask)
    res = masked_matmul(lhs, rhs.transpose(-2, -1), mask_sparse)

    # Fill masked-out positions with -inf before comparing with the reference;
    # presumably this matches what the dense masked_matmul produces there.
    res_dense = res.to_dense()
    res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res_dense, res_gt, atol=atol)

    # try to workaround non-contiguous issues with triton for now
    res_gt.backward(torch.ones_like(res_gt))
    res.values().backward(torch.ones_like(res.values()))

    assert torch.allclose(lhs_ref.grad, lhs.grad, atol=atol)
    assert torch.allclose(rhs_ref.grad, rhs.grad, atol=atol)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_bmm(tensor_type, device):
    """Check ``sparse @ dense`` against the dense matmul, forward and backward.

    The gradient w.r.t. the dense right-hand side must match within ``atol``.
    The dense left-hand gradient is compared only on the sparse operand's
    nonzero pattern.
    """
    N, C, H, W, L = 8, 2, 64, 64, 32
    dtype, atol = _get_dtype_atol(tensor_type, device)
    sparsity = 0.8

    lhs_shape = (N, C, H, W)
    rhs_shape = (N, C, W, L)
    if tensor_type != BlockSparseTensor:
        # SparseCSRTensor is exercised with 3-D inputs: drop the leading N dim.
        lhs_shape, rhs_shape = lhs_shape[1:], rhs_shape[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=lhs_shape, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    # Nonzero pattern of the sparse operand.
    mask = a != 0
    for lhs in (a_sparse, a):
        lhs.requires_grad_(True)

    b = torch.randn(rhs_shape, device=device, dtype=dtype)
    b2 = b.clone()
    for rhs in (b, b2):
        rhs.requires_grad_(True)

    res_gt = a @ b
    res = a_sparse @ b2

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}"

    res_gt.sum().backward()
    res.sum().backward()

    # Zero the dense gradient outside the pattern so that it is comparable
    # with the gradient held by the sparse tensor.
    a_grad = a.grad.detach().masked_fill(~mask, 0)

    assert torch.allclose(b.grad, b2.grad, atol=atol)
    sparse_a_grad = a_sparse.grad.to_dense()
    assert torch.allclose(
        a_grad, sparse_a_grad, atol=atol
    ), f"{torch.max(torch.abs(a_grad-sparse_a_grad))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(tensor_type, device):
    """Compare softmax over a sparse tensor with a dense softmax.

    In the dense reference, every position that is zero in the sparse
    operand's dense form is set to -inf. Since exp(-inf) == 0, those positions
    get zero probability and do not contribute to the normalization. Forward
    outputs, and input gradients restricted to the nonzero pattern, must match
    within ``atol``.
    """
    N, C, H, W = 8, 2, 64, 64
    # Also reseeds the RNGs and sets the global TF32 flags for this type/device.
    dtype, atol = _get_dtype_atol(tensor_type, device)

    sparsity = 0.8

    shape0 = (N, C, H, W)
    # SparseCSRTensor is exercised with 3-D inputs: drop the leading N dim.
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    # Nonzero pattern of the sparse operand.
    mask = a != 0

    # Exclude zero positions from the dense softmax.
    # NOTE(review): a row with no nonzero entries would be all -inf and yield
    # NaN in the dense softmax. This is improbable for the CSR case and
    # prevented for the block-sparse case by its forced first block-column.
    a[~mask] = float("-inf")

    a_sparse.requires_grad_(True)
    a.requires_grad_(True)

    res_gt = torch.softmax(a, dim=-1)
    res_sparse = torch.softmax(a_sparse, dim=-1)

    res = res_sparse.to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res- res_gt))}"

    # WARNING: gradients are modified in-place!
    # NOTE(review): presumably this is why each backward call gets its own
    # freshly allocated ones_like tensor, and why the sparse backward runs
    # first. Confirm before reordering these two calls.
    res_sparse.values().backward(torch.ones_like(res_sparse.values()))
    res_gt.backward(torch.ones_like(res_gt))

    # Dense gradient restricted to the nonzero pattern, for comparison with
    # the gradient held by the sparse tensor.
    a_grad = a.grad.clone()
    a_grad[~mask] = 0

    assert torch.allclose(
        a_grad, a_sparse.grad.to_dense(), atol=atol
    ), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_deepcopy(tensor_type, device):
    """``copy.deepcopy`` of a sparse tensor must give an equal but distinct tensor."""
    import copy

    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8

    shape0 = (N, C, H, W)
    # SparseCSRTensor is exercised with 3-D inputs: drop the leading N dim.
    if tensor_type != BlockSparseTensor:
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )

    b_sparse = copy.deepcopy(a_sparse)
    # A "copy" that returned the very same object would trivially pass the
    # equality check, so first require a new object of the same sparse type.
    assert b_sparse is not a_sparse
    assert isinstance(b_sparse, tensor_type)
    assert torch.equal(a_sparse, b_sparse)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_module_buffer(tensor_type, device):
    """Sparse tensors must work as ``nn.Module`` buffers.

    Covers ``register_buffer``, ``Module.to``, ``state_dict``, and
    ``load_state_dict``.
    """
    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8

    # SparseCSRTensor is exercised with 3-D inputs (no leading N dim).
    shape = (N, C, H, W) if tensor_type == BlockSparseTensor else (C, H, W)

    # Created in order: the registered buffer first, then the replacement.
    original, replacement = (
        _create_tensor(tensor_type, device, dtype=dtype, shape=shape, sparsity=sparsity)
        for _ in range(2)
    )

    module = torch.nn.Module()
    # register_buffer must accept the sparse tensor and store the same object.
    module.register_buffer("a_sparse", original)
    assert module.a_sparse is original

    module.to(device)
    assert module.a_sparse.device == torch.device(device)

    state = module.state_dict()
    assert "a_sparse" in state
    assert torch.equal(original.to(device), state["a_sparse"])

    # Reload the module's own state, then load a different tensor into it.
    module.load_state_dict(state)
    module.load_state_dict({"a_sparse": replacement})
    assert torch.equal(module.a_sparse, replacement.to(device))
