import torch
import dspattn
import nvtx
import argparse


# Command-line interface. The script exercises dspattn.block_spmm (a sparse x
# dense product), so the description names that kernel. `--dtype` defaults to
# 'bfloat16', which is the dtype the script already fell back to when the flag
# was omitted; making it explicit keeps the reported dtype accurate.
parser = argparse.ArgumentParser(description="Verify/Profile the block SpMM kernel")
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size')
parser.add_argument('--sequence_length', '-l', type=int, default=4096, help='sequence length')
parser.add_argument('--embedding', '-k', type=int, default=64, help='embedding dimension')
parser.add_argument('--dtype', '-t', choices=['float', 'bfloat16'], default='bfloat16', help='data type of operand')
parser.add_argument('--mode', '-m', choices=['verify', 'profile'], default='profile', help='verify the correctness or do profiling')
parser.add_argument('--row_sparsity', '-sp', type=float, default=0.5, help='The sparsity provided by blocked-ELL sparsity')
args = parser.parse_args()

# Reject values that cannot produce a meaningful problem instead of failing
# later inside the kernels: a row sparsity outside [0, 1) keeps zero, a
# negative number, or more than all of the block columns.
if args.batch_size < 1:
    parser.error('--batch_size must be >= 1, got %d' % args.batch_size)
if not 0.0 <= args.row_sparsity < 1.0:
    parser.error('--row_sparsity must be in [0, 1), got %s' % args.row_sparsity)

# Only a block size of 128 x 128 is supported at the moment.
block_size = 128

# Translate the CLI dtype choice into a torch dtype: 'float' selects float32,
# and any other value (including an omitted flag) selects bfloat16.
dtype = torch.float32 if args.dtype == 'float' else torch.bfloat16

# Random CUDA operands. With batch_size > 1 both tensors get a leading batch
# dimension; otherwise they are plain 2-D matrices:
#   dense_matrix: ([B,] L, L)    rhs_matrix: ([B,] L, K)
# The dense matrix is drawn first so the random stream matches across runs.
_batch_dims = (args.batch_size,) if args.batch_size > 1 else ()
_seq_len = args.sequence_length
dense_matrix = torch.randn(size=_batch_dims + (_seq_len, _seq_len), dtype=dtype, device='cuda')
rhs_matrix = torch.randn(size=_batch_dims + (_seq_len, args.embedding), dtype=dtype, device='cuda')

# Generate the blocked-ELL block indices.
# The L x L matrix is tiled into block_size x block_size blocks, and
# nnz_n_blocks of the total_n_blocks block columns are kept.
# NOTE(review): presumably static_random_mask picks nnz_n_blocks random block
# columns per block row; its semantics are not visible here.
if args.sequence_length % block_size != 0:
    # Truncating division would silently build a block grid that no longer
    # covers the operand tensors.
    parser.error('--sequence_length (%d) must be a multiple of the block size (%d)'
                 % (args.sequence_length, block_size))
total_n_blocks = args.sequence_length // block_size
total_m_blocks = args.sequence_length // block_size
nnz_n_blocks = int(total_n_blocks * (1. - args.row_sparsity))
if not 1 <= nnz_n_blocks <= total_n_blocks:
    parser.error('--row_sparsity %s keeps %d of %d block columns; need between 1 and %d'
                 % (args.row_sparsity, nnz_n_blocks, total_n_blocks, total_n_blocks))

indices = dspattn.static_random_mask(args.batch_size, total_m_blocks, total_n_blocks, nnz_n_blocks)

def _prune_operands(dense_matrix, indices):
    """Compress `dense_matrix` and prune it to the blocked-ELL pattern `indices`.

    Shared by ``verify`` and ``profile`` so both always run the same pipeline.

    Returns:
        (nonzeros_bsp, metadata_bsp, uncompressed_bsp): the pruned compressed
        values and metadata passed to ``dspattn.block_spmm``, and the pruned
        uncompressed matrix used for the dense PyTorch reference.
    """
    # Step 1: prune the dense matrix with 50% sparsity
    nonzeros, metadata, uncompressed = dspattn.dense2sparse(dense_matrix)
    # Step 2.1: prune the nonzeros. The block width is halved, presumably
    # because the compressed matrix stores half of each row's columns.
    nonzeros_bsp, _ = dspattn.block_ell_prune(nonzeros, indices, block_size // 2)
    # Step 2.2: prune the metadata. Its block width depends on the operand's
    # element type: block_size // 8 for float32, block_size // 16 otherwise.
    if dense_matrix.dtype == torch.float32:
        meta_block_width = block_size // 8
    else:
        meta_block_width = block_size // 16
    metadata_bsp, _ = dspattn.meta_ell_prune(metadata, indices, meta_block_width)
    # Step 2.3: prune the uncompressed matrix; the second output is the one
    # used as the reference operand.
    _, uncompressed_bsp = dspattn.block_ell_prune(uncompressed, indices, block_size)
    return nonzeros_bsp, metadata_bsp, uncompressed_bsp


def verify(dense_matrix, rhs_matrix, indices, tolerance=0.01):
    """Compare ``dspattn.block_spmm`` against a dense PyTorch reference and print a summary.

    Args:
        dense_matrix: dense left operand; compressed and pruned before use.
        rhs_matrix: dense right operand.
        indices: blocked-ELL block indices from ``dspattn.static_random_mask``.
        tolerance: an output element counts as a mismatch when its absolute
            difference from the reference exceeds this value.
    """
    print("= Function Verification =")
    # An omitted --dtype falls back to bfloat16, so report that rather than None.
    print("(B, M, N, K, nnz_block): (%d, %d, %d, %d, %d/%d), dtype: %s" % (
        args.batch_size, args.sequence_length, args.sequence_length, args.embedding,
        nnz_n_blocks, total_n_blocks, args.dtype or 'bfloat16'))

    nonzeros_bsp, metadata_bsp, uncompressed_bsp = _prune_operands(dense_matrix, indices)

    output = dspattn.block_spmm(nonzeros_bsp, metadata_bsp, rhs_matrix, indices)
    if args.batch_size > 1:
        output_ref = torch.bmm(uncompressed_bsp, rhs_matrix)
    else:
        output_ref = torch.matmul(uncompressed_bsp, rhs_matrix)

    error = torch.abs(output - output_ref)
    # Count mismatches once and reuse the count for the absolute and relative figures.
    n_mismatch = int((error > tolerance).sum())

    print(
        "Totally %d (%.4f) mismatchs in the output matrix, the maximum error is %.2f" % (
            n_mismatch,
            n_mismatch / output.numel(),
            float(torch.max(error)),
        ))


def profile(dense_matrix, rhs_matrix, indices, iterations=10):
    """Run the dense reference and ``dspattn.block_spmm`` inside NVTX ranges.

    Each product is launched `iterations` times in its own ``nvtx.annotate``
    range, so an external profiler can attribute and time them. Results are
    discarded.
    """
    nonzeros_bsp, metadata_bsp, uncompressed_bsp = _prune_operands(dense_matrix, indices)

    if args.batch_size > 1:
        ref_label, ref_op = "torch.bmm", torch.bmm
    else:
        ref_label, ref_op = "torch.matmul", torch.matmul
    for _ in range(iterations):
        with nvtx.annotate(ref_label):
            ref_op(uncompressed_bsp, rhs_matrix)

    # The label reports the overall sparsity: 50% from the compression step
    # plus row_sparsity of the remaining half, removed by the blocked-ELL pruning.
    spmm_label = "block spmm (%.2f)" % (0.5 + args.row_sparsity * 0.5)
    for _ in range(iterations):
        with nvtx.annotate(spmm_label):
            dspattn.block_spmm(nonzeros_bsp, metadata_bsp, rhs_matrix, indices)




# Dispatch on --mode: 'verify' checks correctness; any other mode profiles.
_run_mode = verify if args.mode == 'verify' else profile
_run_mode(dense_matrix, rhs_matrix, indices)

