import torch
import dspattn
import nvtx
import argparse
from torch._C import wait
import torch.nn.functional as F


# Command-line interface: problem size, operand dtype, and run mode.
parser = argparse.ArgumentParser(description="Verify/Profile the softmax kernel")
parser.add_argument('--batch_size', '-b', type=int, default=1, help='batch size')
parser.add_argument('--sequence_length', '-l', type=int, default=4096, help='sequence length')
parser.add_argument('--embedding', '-k', type=int, default=2048, help='embedding dimension')
# Explicit default so args.dtype is never None: without it an omitted flag
# silently selected bfloat16 while the verification banner printed "None".
parser.add_argument('--dtype', '-t', choices=['float', 'bfloat16'], default='bfloat16',
                    help='data type of operand')
parser.add_argument('--mode', '-m', choices=['verify', 'profile'], default='profile',
                    help='verify the correctness or do profiling')
args = parser.parse_args()

# Translate the CLI dtype name into the torch dtype used for the input tensor.
dtype = torch.float32 if args.dtype == 'float' else torch.bfloat16

# Random input of shape (batch_size, sequence_length, embedding) on the CUDA device.
dense_matrix = torch.randn(
    size=(args.batch_size, args.sequence_length, args.embedding),
    dtype=dtype,
    device='cuda',
)


def verify(dense_matrix):
    """Compare ``dspattn.softmax`` with ``F.softmax`` and print a mismatch summary.

    Both implementations are run on ``dense_matrix`` with a scale factor of
    0.5; the reference applies the scale by multiplying the input before a
    softmax over the last dimension. An element counts as a mismatch when
    the absolute difference between the two outputs exceeds 0.01.

    Args:
        dense_matrix: Input tensor handed to both softmax implementations.
    """
    print("= Function Verification =")
    print("(B, M, K): (%d, %d, %d), dtype: %s" % (
        args.batch_size, args.sequence_length, args.embedding, args.dtype))

    output_matrix = dspattn.softmax(dense_matrix, 0.5)
    output_matrix_ref = F.softmax(dense_matrix * 0.5, dim=-1)

    total = output_matrix.numel()
    if total == 0:
        # Nothing to compare; also avoids dividing by zero and calling
        # torch.max on an empty tensor below.
        print("Empty output matrix, nothing to compare")
        return

    # Take the difference in float32 so the error itself is not rounded to
    # the (possibly bfloat16) operand precision.
    error = torch.abs(output_matrix.float() - output_matrix_ref.float())

    # Count the mismatches once instead of evaluating the mask twice.
    num_mismatch = int(torch.count_nonzero(error > 0.01))

    # The maximum error is printed with enough digits to be meaningful
    # relative to the 0.01 mismatch threshold.
    print(
        "Totally %d (%.4f) mismatches in the output matrix, the maximum error is %.6f" % (
            num_mismatch,
            num_mismatch / total,
            float(torch.max(error)),
        ))

def profile(dense_matrix):
    """Run both softmax implementations repeatedly inside NVTX ranges.

    ``F.softmax`` (over the last dimension) is executed ten times under the
    range name "F.softmax", then ``dspattn.softmax`` with a scale of 1.0 is
    executed ten times under the range name "softmax".

    Args:
        dense_matrix: Input tensor handed to both softmax implementations.
    """
    repeats = 10

    # (NVTX range name, callable) pairs, in the order they are profiled.
    candidates = (
        ("F.softmax", lambda x: F.softmax(x, dim=-1)),
        ("softmax", lambda x: dspattn.softmax(x, 1.)),
    )

    for range_name, kernel in candidates:
        for _ in range(repeats):
            with nvtx.annotate(range_name):
                result = kernel(dense_matrix)

# Pick the entry point from the --mode flag; anything other than 'verify'
# falls through to profiling.
run_selected_mode = verify if args.mode == 'verify' else profile
run_selected_mode(dense_matrix)