import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import nvtx
from attention import MultiheadAttention as MultiheadAttention_

import argparse

# Profiling driver: builds a multihead self-attention module on the GPU, runs
# a few warmup iterations, then runs NVTX-annotated forward passes so they can
# be located in a profiler timeline.
parser = argparse.ArgumentParser(description='Sparse Multihead Attention')
# Hyper-parameter of the self-attention module
parser.add_argument('--embed_dim', type=int, default=512, help='The embedding dimension. We have head_dim * num_heads = embed_dim')
parser.add_argument('--num_heads', type=int, default=8, help='The number of attention heads')
# Hyper-parameter of the input size
parser.add_argument('--bs', type=int, default=4, help='batch size')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
# For sparse mask
# NOTE: --mask is only consumed by the commented-out masked path below.
parser.add_argument('--mask', type=str, default='./appr_mask_Feb05c.npy',
                    help='the .npy file holding the sparse mask')
args = parser.parse_args()


# Construct the multihead self-attention module and move it to the GPU
torch_attn = MultiheadAttention_(embed_dim=args.embed_dim, num_heads=args.num_heads, dropout=0).cuda()


# Put the module in evaluation mode for the forward-only runs below
torch_attn.eval()

# Initialize the inputs: a random float32 tensor of shape
# (seq_len, bs, embed_dim) allocated directly on the GPU
query = torch.randn(size=(args.seq_len, args.bs, args.embed_dim), dtype=torch.float32, device='cuda')

# Load the mask from directory
# mask = np.squeeze(np.load(args.mask))
# mask = torch.cuda.BoolTensor(mask, device='cuda')


# Warmup: the same tensor is passed as query, key and value (self-attention)
for _ in range(10):
    out = torch_attn(query, query, query, need_weights=False)

# CUDA kernel launches are asynchronous. Drain all warmup work before the
# annotated region so leftover warmup kernels do not overlap with (and get
# attributed to) the first profiled iteration.
torch.cuda.synchronize()

# profile:
# for i in range(5):
#     with nvtx.annotate("MultiheadAttention_ w/ mask"):
#         out = torch_attn(query, query, query, need_weights=False, attn_mask=mask)

for _ in range(5):
    with nvtx.annotate("MultiheadAttention_ w/o mask"):
        out = torch_attn(query, query, query, need_weights=False)

# Wait for the profiled kernels to finish before the script exits.
torch.cuda.synchronize()
