import torch
import stk
from utils import get_time


# Default problem sizes for the block-sparse SDD benchmark in this file.
blksz = 256                 # edge length of one square sparse block
bs = 8 * 1024               # rows of the dense activation matrix
num_exp = 64                # number of experts (column groups of the weight)
topk = 8                    # experts selected per block row
hidden_size = 10 * 256      # shared inner dimension of the matmul
ffn_hidden_size = 3 * 256   # weight columns owned by each expert


def test_sddmm_throughput(blksz, bs, num_exp, topk, hidden_size, ffn_hidden_size):
    """Time ``stk.ops.sdd`` on a mixture-of-experts style block-sparse topology.

    The sparse output has shape ``[bs, num_exp * ffn_hidden_size]`` and is
    tiled into ``blksz x blksz`` blocks.  Every block row picks ``topk``
    experts at random (top-k of a random score matrix) and owns all
    ``ffn_hidden_size // blksz`` column blocks of each picked expert.

    Args:
        blksz: Edge length of one square sparse block.
        bs: Number of rows of the dense activation; must be a multiple of
            ``blksz``.
        num_exp: Number of experts.
        topk: Experts selected per block row.
        hidden_size: Inner (reduction) dimension of the matmul.
        ffn_hidden_size: Output columns per expert; must be a multiple of
            ``blksz``.

    Raises:
        ValueError: If ``bs`` or ``ffn_hidden_size`` is not a multiple of
            ``blksz``.

    Prints one line with the measured time and throughput; returns nothing.
    """
    if bs % blksz != 0:
        raise ValueError(f'bs ({bs}) must be a multiple of blksz ({blksz})')
    if ffn_hidden_size % blksz != 0:
        raise ValueError(
            f'ffn_hidden_size ({ffn_hidden_size}) must be a multiple of blksz ({blksz})')

    block_rows = bs // blksz
    blocks_per_expert = ffn_hidden_size // blksz
    blocks_per_row = topk * blocks_per_expert
    num_blocks = block_rows * blocks_per_row

    data = torch.ones([num_blocks, blksz, blksz], device='cuda', dtype=torch.float16)
    expert_score = torch.randn([block_rows, num_exp], device='cuda')
    _, expert_index = expert_score.topk(k=topk, dim=-1)

    # Row block of every nonzero block: block row r owns blocks_per_row
    # consecutive entries.
    row_idxs = torch.arange(block_rows, device='cuda').repeat_interleave(blocks_per_row).short()

    # Column block of every nonzero block: expert e covers the column blocks
    # [e * blocks_per_expert, (e + 1) * blocks_per_expert).
    within_expert = torch.arange(blocks_per_expert, device='cuda')
    col_idxs = (expert_index.unsqueeze(-1) * blocks_per_expert + within_expert).flatten().short()

    # offsets[r] is the index of the first block of block row r, so the last
    # entry must equal the total number of blocks stored in ``data``.
    offsets = (torch.arange(block_rows + 1, device='cuda') * blocks_per_row).int()

    shape = [bs, num_exp * ffn_hidden_size]

    stk.matrix._validate_matrix(shape, data, row_idxs, col_idxs, offsets)

    # stk.Matrix takes the row indices before the column indices.
    topo = stk.Matrix(
        shape,
        data,
        row_idxs,
        col_idxs,
        offsets
    )

    actv = torch.randn([bs, hidden_size], device='cuda', dtype=torch.float16)
    weight = torch.randn([hidden_size, ffn_hidden_size * num_exp], device='cuda', dtype=torch.float16)

    # Warm-up call so one-time setup cost is not part of the measurement;
    # wait for it to finish before the timer starts.
    stk.ops.sdd(actv, weight, topo)
    torch.cuda.synchronize()

    # NOTE(review): assumes get_time returns seconds per call -- confirm in utils.
    t = get_time(lambda: stk.ops.sdd(actv, weight, topo))
    # 2 * M * K * N_nonzero floating-point operations, divided by the time.
    flops = bs * topk * hidden_size * ffn_hidden_size * 2 / t

    print(f'SDDMM Scale blksz {blksz} bs {bs} experts {topk}/{num_exp} hs {hidden_size}/{ffn_hidden_size}. '
          f'Time {t * 1e3:.3f} ms, {flops * 1e-12:.3f} TFLOPs')


# Sweep block size and batch size.  Expert counts and hidden dimensions come
# from the module-level defaults rather than repeated literals, and the loop
# variables use their own names so those defaults are not rebound.
for sweep_blksz in (128, 256):
    for sweep_bs in (1024, 8192):
        test_sddmm_throughput(sweep_blksz, sweep_bs, num_exp, topk, hidden_size, ffn_hidden_size)
