import torch
import dspattn
import nvtx
import argparse


# Command-line interface of the SDDMM verification/profiling script.
parser = argparse.ArgumentParser(description="Verify/Profile the SDDMM kernel")
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size')
parser.add_argument('--sequence_length', '-l', type=int, default=4096, help='sequence length')
parser.add_argument('--embedding', '-k', type=int, default=64, help='embedding dimension')
# Without an explicit default, args.dtype is None when the flag is omitted:
# the rest of the script then silently takes its bfloat16 branches while the
# report line prints "dtype: None". Make that effective default explicit.
parser.add_argument(
    '--dtype', '-t', choices=['float', 'bfloat16'], default='bfloat16',
    help='data type of operand')
parser.add_argument(
    '--mode', '-m', choices=['verify', 'profile'], default='profile',
    help='verify the correctness or do profiling')
parser.add_argument(
    '--row_sparsity', '-sp', type=float, default=0.5,
    help='The sparsity provided by blocked-ELL sparsity')
args = parser.parse_args()

# Currently, we only support block size (128 x 128)
block_size = 128

# Operand precision: 'float' selects float32, any other value selects bfloat16.
dtype = torch.float32 if args.dtype == 'float' else torch.bfloat16

# Random operands on the GPU. Both share one shape; a leading batch dimension
# is only present when more than one batch is requested.
_operand_shape = (args.sequence_length, args.embedding)
if args.batch_size > 1:
    _operand_shape = (args.batch_size,) + _operand_shape

lhs_matrix = torch.randn(size=_operand_shape, dtype=dtype, device='cuda')
rhs_matrix = torch.randn(size=_operand_shape, dtype=dtype, device='cuda')

# Generate the indices
# Number of blocks along each dimension of the (sequence_length x sequence_length)
# output. Derived from block_size so the mask cannot drift from the declared
# block size; floor division keeps the arithmetic in integers.
total_n_blocks = args.sequence_length // block_size
total_m_blocks = args.sequence_length // block_size
# Blocks kept out of total_n_blocks for the requested row sparsity (truncated).
nnz_n_blocks = int(total_n_blocks * (1. - args.row_sparsity))

indices = dspattn.static_random_mask(args.batch_size, total_m_blocks, total_n_blocks, nnz_n_blocks)

def verify(lhs_matrix, rhs_matrix, indices):
    """Check ``dspattn.block_sddmm`` against a dense reference and print mismatch statistics.

    The reference is ``lhs_matrix @ rhs_matrix^T`` computed densely, passed through
    ``dspattn.dense2sparse`` and then pruned with ``dspattn.block_ell_prune`` /
    ``dspattn.meta_ell_prune`` using the same ``indices`` as the kernel under test.

    Args:
        lhs_matrix: 2-D tensor, or a batched tensor accepted by ``torch.bmm``.
        rhs_matrix: tensor with the same layout as ``lhs_matrix``; its last two
            dimensions are transposed before the product.
        indices: block mask handed unchanged to the ``dspattn`` routines.

    Returns:
        None. The problem size and the mismatch counts are printed to stdout.
    """
    print("= Function Verification =")
    print("(B, M, N, K, nnz_block): (%d, %d, %d, %d, %d/%d), dtype: %s" % (
        args.batch_size, args.sequence_length, args.sequence_length, args.embedding,
        nnz_n_blocks, total_n_blocks, args.dtype))

    # Generate the output
    nonzeros_sddmm, metadata_sddmm = dspattn.block_sddmm(lhs_matrix, rhs_matrix, indices)

    # Generate the reference
    if lhs_matrix.dim() == 2:
        dense_matrix = torch.matmul(lhs_matrix, torch.transpose(rhs_matrix, 0, 1))
    else:
        dense_matrix = torch.bmm(lhs_matrix, torch.transpose(rhs_matrix, 1, 2))
    # 50% pruning
    nonzeros_ref, metadata_ref, _ = dspattn.dense2sparse(dense_matrix)

    # block pruning
    # The last argument of meta_ell_prune depends on the operand precision
    # (16 for float32, 8 otherwise). It is chosen from the tensor that was
    # actually passed in rather than from the command-line globals.
    # NOTE(review): the meaning of the constants 64 / 16 / 8 is defined by dspattn — confirm.
    meta_prune_arg = 16 if lhs_matrix.dtype == torch.float32 else 8
    nonzeros_bsp, _ = dspattn.block_ell_prune(nonzeros_ref, indices, 64)
    metadata_bsp, _ = dspattn.meta_ell_prune(metadata_ref, indices, meta_prune_arg)

    error_nnz = torch.abs(nonzeros_sddmm - nonzeros_bsp)
    error_meta = torch.abs(metadata_sddmm - metadata_bsp)

    # Count each kind of mismatch once and reuse it for the count and the ratio.
    nnz_mismatches = int((error_nnz > 0.0).sum())
    meta_mismatches = int((error_meta > 0).sum())

    print(
        "Totally %d (%.4f) mismatchs in the nonzeros, the maximum error is %.2f" % (
            nnz_mismatches,
            nnz_mismatches / nonzeros_sddmm.numel(),
            float(torch.max(error_nnz)),
        ))
    # The ratio is taken over the compared (block-pruned) metadata, mirroring
    # the nonzeros ratio above; the un-pruned metadata_ref has a different size
    # and would under-report the mismatch rate.
    print(
        "Totally %d (%.4f) mismatchs in the meta data" % (
            meta_mismatches,
            meta_mismatches / metadata_sddmm.numel(),
        ))
    
# Run the correctness check only when it was explicitly requested.
# NOTE(review): no handler for the 'profile' mode is visible in this file — confirm.
_run_verification = (args.mode == 'verify')
if _run_verification:
    verify(lhs_matrix=lhs_matrix, rhs_matrix=rhs_matrix, indices=indices)


""" float block sddmm
lhs_matrix = torch.randn(size=(4096, 64), dtype=torch.float32, device='cuda')
rhs_matrix = torch.randn(size=(4096, 64), dtype=torch.float32, device='cuda')

indices = dspattn.static_random_mask(1, 32, 32, 8)

# print(indices)

# Generate the output
nonzeros_sddmm, metadata_sddmm = dspattn.block_sddmm(lhs_matrix, rhs_matrix, indices)

# Generate the reference
dense_matrix = torch.matmul(lhs_matrix, torch.transpose(rhs_matrix, 0, 1))
# 50% pruning
nonzeros_ref, metadata_ref, _ = dspattn.dense2sparse(dense_matrix)
# block pruning

nonzeros_bsp, _ = dspattn.block_ell_prune(nonzeros_ref, indices, 64)
metadata_bsp, _ = dspattn.meta_ell_prune(metadata_ref, indices, 16)

print(torch.max(torch.abs(nonzeros_sddmm - nonzeros_bsp)))
print(torch.max(torch.abs(metadata_sddmm - metadata_bsp)))
"""

"""
lhs_matrix = torch.randn(size=(4096, 64), dtype=torch.bfloat16, device='cuda')
rhs_matrix = torch.randn(size=(4096, 64), dtype=torch.bfloat16, device='cuda')

indices = dspattn.static_random_mask(1, 32, 32, 8)

# print(indices)

# Generate the output
nonzeros_sddmm, metadata_sddmm = dspattn.block_sddmm(lhs_matrix, rhs_matrix, indices)

# Generate the reference
dense_matrix = torch.matmul(lhs_matrix, torch.transpose(rhs_matrix, 0, 1))
# 50% pruning
nonzeros_ref, metadata_ref, _ = dspattn.dense2sparse(dense_matrix)
# block pruning

nonzeros_bsp, _ = dspattn.block_ell_prune(nonzeros_ref, indices, 64)
metadata_bsp, _ = dspattn.meta_ell_prune(metadata_ref, indices, 8)

print(nonzeros_sddmm)
print(nonzeros_bsp)

print(metadata_sddmm)
print(metadata_bsp)


print(torch.max(torch.abs(nonzeros_sddmm - nonzeros_bsp)))
print(torch.max(torch.abs(metadata_sddmm - metadata_bsp)))
"""