import torch
import dspattn
import nvtx
import argparse

# Command-line interface: problem size, operand data type, run mode, and the
# optional mask / training switches.
parser = argparse.ArgumentParser(description="Verify/Profile the SDDMM kernel")
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size')
parser.add_argument('--sequence_length', '-l', type=int, default=4096, help='sequence length')
parser.add_argument('--embedding', '-k', type=int, default=64, help='embedding dimension')
# Explicit default: when --dtype was omitted, args.dtype used to be None, which
# the dtype selection below treated as bfloat16 while the verification banner
# printed "dtype: None". Defaulting to 'bfloat16' keeps the effective data type
# and makes the reported value match it.
parser.add_argument('--dtype', '-t', choices=['float', 'bfloat16'], default='bfloat16', help='data type of operand')
parser.add_argument('--mode', '-m', choices=['verify', 'profile'], default='profile', help='verify the correctness or do profiling')
parser.add_argument('--mask', '-x', action='store_true', help="add a float mask")
parser.add_argument('--train', '-r', action='store_true', help="check the training script")
args = parser.parse_args()

# Map the --dtype choice onto a torch dtype: 'float' selects float32 and every
# other value falls through to bfloat16.
dtype = torch.float32 if args.dtype == 'float' else torch.bfloat16

# Build the random operands on the GPU. A batch size above 1 adds a leading
# batch dimension; otherwise the operands are plain 2-D matrices.
batch_dims = (args.batch_size,) if args.batch_size > 1 else ()
operand_shape = batch_dims + (args.sequence_length, args.embedding)
lhs_matrix = torch.randn(size=operand_shape, dtype=dtype, device='cuda')
rhs_matrix = torch.randn(size=operand_shape, dtype=dtype, device='cuda')

mask = None
if args.mask:
    # Additive mask with a single row (per batch entry, when batched): each
    # position is drawn as 1 with probability 0.8 and then scaled to -1e16,
    # the remaining positions stay zero.
    # (A dense torch.randn mask of the same shape was tried here previously.)
    mask_shape = batch_dims + (1, args.sequence_length)
    prob = torch.ones(size=mask_shape, dtype=dtype, device='cuda') * 0.8
    mask = torch.bernoulli(prob) * -1e16


def verify(lhs_matrix, rhs_matrix, mask):
    """Compare ``dspattn.sddmm`` with a dense reference and print mismatch counts.

    The reference is the dense product of ``lhs_matrix`` with the transposed
    ``rhs_matrix`` (``torch.matmul`` for 2-D operands, ``torch.bmm`` otherwise),
    plus ``mask`` when one is given, passed through ``dspattn.dense2sparse``.

    When ``args.train`` is set, the SDDMM values are compared against the
    uncompressed reference and the metadata comparison is skipped (reported as
    zero mismatches). Otherwise both the nonzeros and the metadata are compared.

    Args:
        lhs_matrix: left operand tensor.
        rhs_matrix: right operand tensor; its last two dimensions are swapped
            before the dense product.
        mask: tensor added in place to the dense product, or ``None``.
    """
    def count_true(condition):
        # Number of elements for which `condition` holds.
        return torch.where(condition)[0].numel()

    mask_flag = 'false' if mask is None else 'true'
    print("= Function Verification =")
    problem = (args.batch_size, args.sequence_length, args.sequence_length, args.embedding, args.dtype, mask_flag)
    print("(B, M, N, K): (%d, %d, %d, %d), dtype: %s, mask: %s" % problem)

    # Kernel under test.
    nonzeros_sddmm, metadata_sddmm = dspattn.sddmm(lhs_matrix, rhs_matrix, mask, args.train)

    # Dense reference result.
    is_2d = lhs_matrix.dim() == 2
    if is_2d:
        dense_matrix = torch.matmul(lhs_matrix, rhs_matrix.transpose(0, 1))
    else:
        dense_matrix = torch.bmm(lhs_matrix, rhs_matrix.transpose(1, 2))
    if mask is not None:
        dense_matrix.add_(mask)
    nonzeros_ref, metadata_ref, uncompressed_ref = dspattn.dense2sparse(dense_matrix)

    if args.train:
        error_nnz = (nonzeros_sddmm - uncompressed_ref).abs()
        # Metadata is not compared in training mode; a zero tensor yields 0 mismatches.
        error_meta = torch.zeros(1)
    else:
        error_nnz = (nonzeros_sddmm - nonzeros_ref).abs()
        error_meta = (metadata_sddmm - metadata_ref).abs()

    bad_values = count_true(error_nnz > 1e-5)
    print(
        "Totally %d (%.4f) mismatchs in the nonzeros, the maximum error is %.2f"
        % (bad_values, bad_values / nonzeros_sddmm.numel(), error_nnz.max().item())
    )

    bad_meta = count_true(error_meta > 0)
    print(
        "Totally %d (%.4f) mismatchs in the meta data"
        % (bad_meta, bad_meta / metadata_ref.numel())
    )


def profile(lhs_matrix, rhs_matrix, mask):
    """Run the dense baseline and ``dspattn.sddmm`` ten times each under NVTX ranges.

    The dense baseline is ``torch.matmul`` for 2-D operands and ``torch.bmm``
    otherwise, each followed by an in-place add of ``mask`` when one is given.
    Every iteration is wrapped in an ``nvtx.annotate`` range named after the
    operation ("torch.matmul", "torch.bmm" or "sddmm").

    Args:
        lhs_matrix: left operand tensor.
        rhs_matrix: right operand tensor; the dense baseline uses it with its
            last two dimensions swapped, the SDDMM call receives it as is.
        mask: additive mask tensor, or ``None``.
    """
    repeats = 10

    if lhs_matrix.dim() == 2:
        dense_label, dense_op = "torch.matmul", torch.matmul
        rhs_matrix_t = torch.transpose(rhs_matrix, 0, 1)
    else:
        dense_label, dense_op = "torch.bmm", torch.bmm
        rhs_matrix_t = torch.transpose(rhs_matrix, 1, 2)

    # Dense baseline; the transpose is hoisted so only the product (and the
    # optional mask add) falls inside the annotated range.
    for _ in range(repeats):
        with nvtx.annotate(dense_label):
            dense_matrix = dense_op(lhs_matrix, rhs_matrix_t)
            if mask is not None:
                dense_matrix += mask

    # Kernel under test. The --train flag is forwarded exactly as verify()
    # does, so profiling exercises the same code path that was verified
    # instead of silently ignoring the flag.
    for _ in range(repeats):
        with nvtx.annotate("sddmm"):
            dspattn.sddmm(lhs_matrix, rhs_matrix, mask, args.train)


# Run the routine selected by --mode: 'verify' checks correctness, anything
# else (i.e. 'profile') runs the profiling loops.
selected_routine = verify if args.mode == 'verify' else profile
selected_routine(lhs_matrix, rhs_matrix, mask)