import click
import numpy as np
import scipy.sparse as sparse
import aggfuse_cpu

from test import csr_dmm_cpu, csr_fuse_cpu


def random_sparse(n, k, dtype, density, seed=0):
    """Return a random ``n x k`` CSR matrix with the given dtype and density.

    The matrix comes from ``scipy.sparse.rand`` with ``random_state=seed``.
    Repeated calls with the same arguments therefore give the same matrix.
    """
    options = {
        "density": density,
        "format": "csr",
        "dtype": dtype,
        "random_state": seed,
    }
    return sparse.rand(n, k, **options)


def random_dense(k, n, dtype, seed=0):
    """Return a ``k x n`` array of standard-normal samples of ``dtype``.

    Every call builds a fresh ``numpy.random.default_rng(seed)`` generator.
    Results are therefore reproducible for a given seed.
    """
    shape = (k, n)
    generator = np.random.default_rng(seed)
    return generator.standard_normal(size=shape, dtype=dtype)


def _report(label, ref, got):
    """Print how closely ``got`` matches the reference result ``ref``.

    Prints whether the two arrays are ``np.allclose`` and the signed range
    (max, min) of ``ref - got``. The shapes are compared first because both
    ``np.allclose`` and ``-`` broadcast. Without that check, a wrongly shaped
    kernel result could look like a match.
    """
    if np.shape(ref) != np.shape(got):
        print(f"{label}: shape mismatch {np.shape(ref)} vs {np.shape(got)}")
        return
    diff = ref - got
    print(f"{label}: allclose={np.allclose(ref, got)}")
    print(f"{label}: diff max={diff.max()} min={diff.min()}")


# N and M must be >= 1 because the error report reduces over the N x M result,
# and max()/min() fail on empty arrays. D is a density and must lie in [0, 1].
@click.command()
@click.argument("n", type=click.IntRange(min=1))
@click.argument("k", type=click.IntRange(min=0))
@click.argument("m", type=click.IntRange(min=1))
@click.argument("d", type=click.FloatRange(0.0, 1.0))
@click.option(
    "--seed",
    type=int,
    default=0,
    show_default=True,
    help="Seed for the random sparse and dense inputs.",
)
def main(n, k, m, d, seed):
    """Check the CPU CSR kernels against a SciPy reference product.

    Builds a random N x K float32 CSR matrix with density D and a random
    K x M float32 dense matrix. It then compares csr_dmm_cpu and csr_fuse_cpu
    with the SciPy result.
    """
    s = random_sparse(n, k, dtype=np.float32, density=d, seed=seed)
    # Give the dense operand its own name instead of rebinding the
    # density argument `d` to an array.
    dense = random_dense(k, m, dtype=np.float32, seed=seed)
    # Weights for csr_fuse_cpu: column 0 is all ones, columns 1-2 are zero.
    # NOTE(review): the fused result is compared against s @ dense below, so
    # this presumably makes the fusion reduce to a plain SpMM. Confirm
    # against csr_fuse_cpu.
    w = np.zeros((n, 3), dtype=np.float32)
    w[:, 0] = 1

    x_ref = s @ dense
    x_reimple = csr_dmm_cpu(s, dense)
    x_fuse = csr_fuse_cpu(s, dense, w)

    for label, result in (
        ("reference (s @ dense)", x_ref),
        ("csr_dmm_cpu", x_reimple),
        ("csr_fuse_cpu", x_fuse),
    ):
        print(f"{label}:\n{result}\n")

    _report("csr_dmm_cpu", x_ref, x_reimple)
    _report("csr_fuse_cpu", x_ref, x_fuse)


if __name__ == "__main__":
    main()