import click
import torch

from test import load_arxiv, load_reddit, to_sparse_gpu, csr_dmm_gpu, csr_fuse_gpu


def dense(iters, b, n):
    """Issue repeated dense matmuls on the GPU under NVTX annotation.

    Prints the configuration, creates a random ``(b, n)`` input and a random
    ``(n, n)`` weight (both generated on the host and then moved with
    ``.cuda()``), and multiplies them ``iters`` times. The products are
    discarded; the function exists only to generate work for a profiler.

    Args:
        iters: Number of ``torch.matmul`` calls to issue.
        b: Number of rows of the random input.
        n: Number of columns of the input and side length of the weight.
    """
    print("dense", b, n)

    input_shape = (b, n)
    weight_shape = (n, n)
    # Draw the input before the weight so the RNG is consumed in a fixed order.
    features = torch.randn(input_shape).cuda()
    weight = torch.randn(weight_shape).cuda()

    # emit_nvtx wraps each autograd op in an NVTX range so an external
    # profiler can attribute the launched kernels to it.
    nvtx_ranges = torch.autograd.profiler.emit_nvtx()
    with nvtx_ranges:
        for _step in range(iters):
            torch.matmul(features, weight)


def sparse(iters, a, n, fuse):
    """Issue repeated sparse-times-dense kernel calls under NVTX annotation.

    Prints the configuration, creates a random ``(b, n)`` dense operand on the
    GPU with ``b = a.size(0)``, and then calls either ``csr_fuse_gpu(a, x, w)``
    (when ``fuse`` is true) or ``csr_dmm_gpu(a, x)`` ``iters`` times. Results
    are discarded; the function exists only to generate work for a profiler.

    Args:
        iters: Number of kernel calls to issue.
        a: Sparse operand handed to the kernels; only ``a.size(0)`` is read
            here. Presumably a CSR matrix already resident on the GPU --
            TODO confirm against the ``test`` module.
        n: Number of columns of the random dense operand.
        fuse: Select the fused kernel instead of the plain one.
    """
    b = a.size(0)
    print("sparse", b, n, fuse)
    x = torch.randn((b, n)).cuda()
    # Only the fused kernel consumes ``w``; skip its allocation and
    # host-to-device copy on the plain path so they do not show up as stray
    # work in the profile.
    # NOTE(review): the width 3 is hard-coded; its meaning is defined by
    # csr_fuse_gpu -- confirm before changing.
    w = torch.randn((b, 3)).cuda() if fuse else None
    # emit_nvtx wraps each autograd op in an NVTX range for external profilers.
    with torch.autograd.profiler.emit_nvtx():
        for _ in range(iters):
            if fuse:
                csr_fuse_gpu(a, x, w)
            else:
                csr_dmm_gpu(a, x)



# Command-line entry point: load the chosen dataset from ROOT, then run the
# dense or the sparse benchmark with feature width N.
# Documented with comments rather than a docstring on purpose: click would
# surface a docstring as the command's --help text.
@click.command()
@click.argument("root", type=click.Path())
@click.argument("mode", type=click.Choice(["dense", "sparse"]))
@click.argument("n", type=int)
@click.option("--iters", type=int, default=10)
@click.option("--fuse", is_flag=True)
@click.option("--dataset", type=click.Choice(["reddit", "arxiv"]), default="arxiv")
def main(root, mode, n, iters, fuse, dataset):
    # Anything other than "arxiv" falls through to the reddit loader.
    loader = load_arxiv if dataset == "arxiv" else load_reddit
    data = loader(root)

    # The two modes are mutually exclusive; --fuse only affects sparse mode.
    if mode == "dense":
        dense(iters, data.num_nodes, n)
    elif mode == "sparse":
        adjacency = to_sparse_gpu(data)
        sparse(iters, adjacency, n, fuse)


# Invoke the click command only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
