import torch
import dspattn
import nvtx
import argparse

# Command-line interface. The script exercises dspattn.spmm (see verify/profile
# below), so the description names the SpMM kernel.
parser = argparse.ArgumentParser(description="Verify/Profile the SpMM kernel")
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size')
parser.add_argument('--sequence_length', '-l', type=int, default=4096, help='sequence length')
parser.add_argument('--embedding', '-k', type=int, default=64, help='embedding dimension')
# An omitted --dtype previously left args.dtype as None, which fell through to
# bfloat16 below and was reported as "None" by verify(). Declaring the default
# keeps the same dtype selection and makes the reported name accurate.
parser.add_argument('--dtype', '-t', choices=['float', 'bfloat16'], default='bfloat16', help='data type of operand')
parser.add_argument('--mode', '-m', choices=['verify', 'profile'], default='profile', help='verify the correctness or do profiling')
args = parser.parse_args()

# Translate the CLI spelling into the torch dtype used for both operands.
dtype = torch.float32 if args.dtype == 'float' else torch.bfloat16

# Operands get a leading batch dimension only when batch_size > 1; otherwise
# they stay 2-D so that verify()/profile() take the torch.matmul path.
_batch_dims = (args.batch_size,) if args.batch_size > 1 else ()

# Left operand: square (sequence_length x sequence_length) random matrix.
dense_matrix = torch.randn(
    size=_batch_dims + (args.sequence_length, args.sequence_length), dtype=dtype, device='cuda')
# Right operand: (sequence_length x embedding) random matrix.
rhs_matrix = torch.randn(
    size=_batch_dims + (args.sequence_length, args.embedding), dtype=dtype, device='cuda')


def verify(dense_matrix, rhs_matrix):
    """Check ``dspattn.spmm`` against a dense PyTorch product and print a summary.

    ``dense_matrix`` is passed through ``dspattn.dense2sparse``, which returns
    three values. The first two are fed to ``dspattn.spmm``; the third
    (``uncompressed``) is multiplied with ``rhs_matrix`` by PyTorch to form the
    reference result.
    NOTE(review): presumably ``uncompressed`` is the dense form of the
    sparsified matrix -- confirm against the dspattn documentation.

    Args:
        dense_matrix: Left operand. A 2-D tensor selects ``torch.matmul`` for
            the reference; any other rank selects ``torch.bmm``.
        rhs_matrix: Right operand, passed unchanged to both products.

    Returns:
        None. Prints the problem size (read from the module-level ``args``),
        the number and fraction of output elements whose absolute error
        exceeds 0.01, and the maximum absolute error.
    """
    print("= Function Verification =")
    print("(B, M, N, K): (%d, %d, %d, %d), dtype: %s" % (args.batch_size, args.sequence_length, args.sequence_length, args.embedding, args.dtype))

    nonzeros, metadata, uncompressed = dspattn.dense2sparse(dense_matrix)

    # Reference product computed by PyTorch.
    if dense_matrix.dim() == 2:
        output_matrix_ref = torch.matmul(uncompressed, rhs_matrix)
    else:
        output_matrix_ref = torch.bmm(uncompressed, rhs_matrix)

    # Product computed by the kernel under test.
    output_matrix = dspattn.spmm(nonzeros, metadata, rhs_matrix)

    error = torch.abs(output_matrix - output_matrix_ref)

    # Count the out-of-tolerance elements once; the count is reused for both
    # the absolute number and the ratio instead of re-evaluating the mask.
    mismatch_count = int((error > 0.01).sum())

    print(
        "Totally %d (%.4f) mismatchs in the output matrix, the maximum error is %.2f" % (
            mismatch_count,
            mismatch_count / output_matrix.numel(),
            float(torch.max(error)),
        ))


def profile(dense_matrix, rhs_matrix):
    """Emit NVTX-annotated runs of the dense product and of ``dspattn.spmm``.

    The dense product is launched 10 times inside an NVTX range named
    "torch.matmul" (2-D ``dense_matrix``) or "torch.bmm" (any other rank).
    ``dense_matrix`` is then converted with ``dspattn.dense2sparse`` and
    ``dspattn.spmm`` is launched 10 times inside an NVTX range named "spmm".

    Args:
        dense_matrix: Left operand of the dense product; also the input to
            ``dspattn.dense2sparse``.
        rhs_matrix: Right operand for both the dense and the sparse product.

    Returns:
        None. The results of the products are discarded.
    """
    repeats = 10

    # Pick the dense baseline and its NVTX label from the operand rank.
    if dense_matrix.dim() == 2:
        dense_op, dense_label = torch.matmul, "torch.matmul"
    else:
        dense_op, dense_label = torch.bmm, "torch.bmm"

    for _ in range(repeats):
        with nvtx.annotate(dense_label):
            dense_result = dense_op(dense_matrix, rhs_matrix)

    # The conversion happens outside any NVTX range, as in the dense case
    # only the product itself is annotated.
    sparse_values, sparse_meta, _ = dspattn.dense2sparse(dense_matrix)

    for _ in range(repeats):
        with nvtx.annotate("spmm"):
            sparse_result = dspattn.spmm(sparse_values, sparse_meta, rhs_matrix)

# Run the routine selected by --mode: 'verify' checks correctness, any other
# value (i.e. 'profile') runs the NVTX-annotated profiling loops.
_selected_routine = verify if args.mode == 'verify' else profile
_selected_routine(dense_matrix, rhs_matrix)