from statistics import mean, stdev
import time

import click
from scipy.sparse.csr import csr_matrix
import torch
from torch_sparse import SparseTensor
import numpy as np
import scipy.sparse as sparse
import torch_sparse
import aggfuse_cpu
import aggfuse_gpu

from torch_geometric.data import Batch
from torch_geometric.datasets import (
    Planetoid,
    Reddit,
    ZINC,
    SuiteSparseMatrixCollection,
)
from torch_geometric.utils import to_scipy_sparse_matrix

from experiments.code.utils import code_data
from experiments.arxiv.configs import arxiv_data


def random_sparse(n, k, dtype, density, seed=0):
    """Return a reproducible random ``n`` x ``k`` sparse matrix in CSR format.

    Args:
        n: Number of rows.
        k: Number of columns.
        dtype: Value dtype forwarded to ``scipy.sparse.rand``.
        density: Fraction of entries that are stored.
        seed: Random state forwarded to SciPy; the same seed reproduces the
            same matrix.
    """
    options = {
        "density": density,
        "format": "csr",
        "dtype": dtype,
        "random_state": seed,
    }
    return sparse.rand(n, k, **options)


def random_dense(k, n, dtype, seed=0):
    """Return a reproducible ``(k, n)`` array of standard-normal samples.

    Args:
        k: Size of the first axis of the result.
        n: Size of the second axis of the result.
        dtype: Floating-point dtype passed to ``Generator.standard_normal``.
        seed: Seed for the generator; equal seeds give equal sample streams.
    """
    shape = (k, n)
    generator = np.random.default_rng(seed)
    return generator.standard_normal(size=shape, dtype=dtype)


def time_fn(f, warmups, runs):
    """Time repeated calls of ``f`` and return the per-call durations.

    Args:
        f: Zero-argument callable to benchmark.
        warmups: Number of untimed calls made before measuring.
        runs: Number of timed calls.

    Returns:
        A list of ``runs`` elapsed times in seconds, one per timed call.
    """
    for _ in range(warmups):
        f()

    times = []
    for _ in range(runs):
        # perf_counter is monotonic and uses the highest-resolution clock
        # available; time.time() can jump when the system clock is adjusted.
        start = time.perf_counter()
        f()
        times.append(time.perf_counter() - start)

    return times


def mm_cpu(x, theta):
    """Return the matrix product of ``x`` and ``theta`` computed by NumPy."""
    product = np.matmul(x, theta)
    return product


def csr_dmm_cpu(s, d):
    """Run the ``aggfuse_cpu.csr_sum`` kernel on a CSR matrix and a dense array.

    Args:
        s: CSR matrix exposing ``shape``, ``indptr``, ``indices`` and ``data``.
        d: Dense array; its second dimension sets the output width.

    Returns:
        A float32 array of shape ``(s.shape[0], d.shape[1])``. It is allocated
        as zeros and handed to the kernel, which is expected to fill it in
        place.
    """
    n_rows = s.shape[0]
    n_cols = s.shape[1]
    result = np.zeros((n_rows, d.shape[1]), dtype=np.float32)
    aggfuse_cpu.csr_sum(n_rows, n_cols, s.indptr, s.indices, s.data, d, result)
    return result


def csr_fuse_cpu(s, d, w):
    """Run the fused ``aggfuse_cpu.aggfuse_fp32`` kernel.

    Args:
        s: CSR matrix exposing ``shape``, ``indptr``, ``indices`` and ``data``.
        d: Dense array; its second dimension sets the output width.
        w: Weight array forwarded to the kernel unchanged.

    Returns:
        A float32 array of shape ``(s.shape[0], d.shape[1])``. It is allocated
        as zeros and handed to the kernel, which is expected to fill it in
        place.
    """
    n_rows = s.shape[0]
    n_cols = s.shape[1]
    result = np.zeros((n_rows, d.shape[1]), dtype=np.float32)
    aggfuse_cpu.aggfuse_fp32(
        n_rows, n_cols, s.indptr, s.indices, s.data, d, w, result
    )
    return result


def csr_fuse_mat_cpu(s, d, w):
    """Run the ``aggfuse_cpu.aggfuse_fp32_mat`` kernel.

    Args:
        s: CSR matrix exposing ``shape``, ``indptr``, ``indices`` and ``data``.
        d: Dense array; its second dimension sets the output width.
        w: Weight array forwarded to the kernel unchanged.

    Returns:
        A float32 array of shape ``(s.shape[0], d.shape[1])``. It is allocated
        as zeros and handed to the kernel, which is expected to fill it in
        place.
    """
    n_rows = s.shape[0]
    n_cols = s.shape[1]
    out_shape = (n_rows, d.shape[1])
    # Extra (3, rows, width) buffer handed to the kernel and discarded here;
    # presumably scratch space for the per-aggregator results (TODO confirm).
    scratch = np.zeros((3,) + out_shape, dtype=np.float32)
    result = np.zeros(out_shape, dtype=np.float32)
    aggfuse_cpu.aggfuse_fp32_mat(
        n_rows, n_cols, s.indptr, s.indices, s.data, d, w, scratch, result
    )
    return result


def naive_fuse_cpu(s, d, w):
    """Unfused CPU baseline: three separate aggregations, then a weighted sum.

    Runs the ``csr_sum``, ``csr_max`` and ``csr_min`` kernels one after the
    other and combines their outputs using the columns of ``w``.

    Args:
        s: CSR matrix exposing ``shape``, ``indptr``, ``indices`` and ``data``.
        d: Dense array; its second dimension sets the output width.
        w: Weight array whose columns 0, 1 and 2 scale the sum, max and min
            results respectively (presumably one row per output row).

    Returns:
        The weighted combination of the three aggregation results.
    """
    out_shape = (s.shape[0], d.shape[1])
    agg_sum, agg_max, agg_min = (
        np.zeros(out_shape, dtype=np.float32) for _ in range(3)
    )

    kernel_args = (s.shape[0], s.shape[1], s.indptr, s.indices, s.data, d)
    aggfuse_cpu.csr_sum(*kernel_args, agg_sum)
    aggfuse_cpu.csr_max(*kernel_args, agg_max)
    aggfuse_cpu.csr_min(*kernel_args, agg_min)

    # Append a trailing axis so each weight column broadcasts across features.
    # NOTE(review): naive_fuse_gpu pairs column 1 with min and column 2 with
    # max; confirm which ordering matches the fused kernels.
    weights = np.expand_dims(w, -1)
    weighted_sum = weights[:, 0] * agg_sum
    weighted_max = weights[:, 1] * agg_max
    weighted_min = weights[:, 2] * agg_min
    return weighted_sum + weighted_max + weighted_min


def mm_gpu(x, theta):
    """Multiply ``x`` by ``theta`` with torch, without autograd tracking.

    ``torch.cuda.synchronize()`` is called before returning so that a caller
    timing this function measures the finished computation.
    """
    with torch.no_grad():
        product = torch.matmul(x, theta)
        torch.cuda.synchronize()
    return product


def csr_dmm_gpu(s, d):
    """Sum-reduce sparse-dense product via ``torch_sparse.matmul``.

    Runs without autograd tracking and calls ``torch.cuda.synchronize()``
    before returning so that timings cover the finished computation.
    """
    with torch.no_grad():
        result = torch_sparse.matmul(s, d, reduce="sum")
        torch.cuda.synchronize()
    return result


def csr_fuse_gpu(s, d, w):
    """Run ``aggfuse_gpu.ts_fuse_fp32`` with its final flag set to ``False``.

    Args:
        s: Sparse tensor whose ``csr()`` yields ``(rowptr, col, value)``.
        d: Dense tensor forwarded to the kernel.
        w: Weight tensor forwarded to the kernel.

    Returns:
        The kernel's result, returned after ``torch.cuda.synchronize()``.
    """
    with torch.no_grad():
        row_ptr, col_idx, values = s.csr()
        # The trailing flag presumably selects the non-materialising variant
        # (csr_fuse_mat_gpu passes True); confirm against aggfuse_gpu.
        fused = aggfuse_gpu.ts_fuse_fp32(row_ptr, col_idx, values, d, w, False)
        torch.cuda.synchronize()
    return fused


def csr_fuse_mat_gpu(s, d, w):
    """Run ``aggfuse_gpu.ts_fuse_fp32`` with its final flag set to ``True``.

    Args:
        s: Sparse tensor whose ``csr()`` yields ``(rowptr, col, value)``.
        d: Dense tensor forwarded to the kernel.
        w: Weight tensor forwarded to the kernel.

    Returns:
        The kernel's result, returned after ``torch.cuda.synchronize()``.
    """
    with torch.no_grad():
        row_ptr, col_idx, values = s.csr()
        # The trailing flag presumably selects the materialising ("mat")
        # variant (csr_fuse_gpu passes False); confirm against aggfuse_gpu.
        fused = aggfuse_gpu.ts_fuse_fp32(row_ptr, col_idx, values, d, w, True)
        torch.cuda.synchronize()
    return fused


def naive_fuse_gpu(s, d, w):
    """Unfused GPU baseline: three ``torch_sparse.matmul`` calls, then a blend.

    Computes the sum, min and max reductions separately and combines them
    using the columns of ``w``.

    Args:
        s: Sparse tensor accepted by ``torch_sparse.matmul``.
        d: Dense tensor to aggregate.
        w: Weight tensor whose columns 0, 1 and 2 scale the sum, min and max
            results respectively (presumably one row per output row).

    Returns:
        The weighted combination, returned after ``torch.cuda.synchronize()``.
    """
    with torch.no_grad():
        agg_sum = torch_sparse.matmul(s, d, reduce="sum")
        agg_min = torch_sparse.matmul(s, d, reduce="min")
        agg_max = torch_sparse.matmul(s, d, reduce="max")

        # Append a trailing axis so each weight column broadcasts over features.
        # NOTE(review): naive_fuse_cpu pairs column 1 with max and column 2
        # with min; confirm which ordering matches the fused kernels.
        weights = w.unsqueeze(-1)
        weighted_sum = agg_sum * weights[:, 0]
        weighted_min = agg_min * weights[:, 1]
        weighted_max = agg_max * weights[:, 2]
        blended = weighted_sum + weighted_min + weighted_max
        torch.cuda.synchronize()
    return blended


def load_cora(root):
    """Return the first element of the Planetoid "Cora" dataset under ``root``."""
    dataset = Planetoid(root=root, name="Cora")
    return dataset[0]


def load_reddit(root):
    """Return the first element of the ``Reddit`` dataset under ``root``."""
    dataset = Reddit(root)
    return dataset[0]


def load_zinc(root):
    """Collate the first 10000 graphs of the ZINC subset into one ``Batch``."""
    zinc = ZINC(root, subset=True)
    graphs = [zinc[idx] for idx in range(10000)]
    return Batch.from_data_list(graphs)


def load_code(root):
    """Collate the first 10000 training graphs of the code dataset into a ``Batch``.

    The graphs are read from the ``dataset`` attribute of the ``"train"`` entry
    in the first element returned by ``code_data``.
    """
    loaders = code_data(root=root, batch_size=128)[0]
    train_set = loaders["train"].dataset
    graphs = [train_set[idx] for idx in range(10000)]
    return Batch.from_data_list(graphs)


def load_circuit(root):
    """Return the first element of the SuiteSparse "Freescale/memchip" dataset."""
    return SuiteSparseMatrixCollection(
        root=root, group="Freescale", name="memchip"
    )[0]


def load_arxiv(root):
    """Return the first element of whatever ``arxiv_data(root)`` produces."""
    return arxiv_data(root)[0]


def _mean_std(data):
    """Return ``(mean, sample standard deviation)`` of ``data``.

    ``statistics.stdev`` needs at least two samples, so a single measurement
    (for example ``--runs 1``) reports NaN for the spread instead of raising
    after the benchmark has already run.

    Args:
        data: Iterable of numeric samples.

    Returns:
        A ``(mean, stdev)`` tuple; ``stdev`` is NaN when fewer than two samples
        are given.

    Raises:
        statistics.StatisticsError: If ``data`` is empty.
    """
    samples = list(data)
    spread = stdev(samples) if len(samples) > 1 else float("nan")
    return mean(samples), spread


def to_sparse_cpu(data):
    """Build a float32 SciPy CSR adjacency matrix from ``data.edge_index``.

    Args:
        data: Object exposing ``num_nodes`` and a two-row ``edge_index`` tensor
            holding the row indices and the column indices of the edges.

    Returns:
        A ``(num_nodes, num_nodes)`` float32 CSR matrix with a weight of one
        per listed edge.
    """
    num_nodes = data.num_nodes
    src, dst = (index.numpy() for index in data.edge_index)
    unit_weights = np.ones(src.size, dtype=np.float32)
    adjacency = sparse.csr_matrix(
        (unit_weights, (src, dst)), (num_nodes, num_nodes)
    )
    return adjacency.astype(np.float32)


def to_dense_cpu(x):
    """Return ``x`` unchanged; CPU inputs need no conversion."""
    return x


def to_sparse_gpu(data):
    """Build a transposed ``SparseTensor`` adjacency on the CUDA device.

    Args:
        data: Object exposing ``num_nodes`` and a two-row ``edge_index`` tensor.

    Returns:
        A float32 ``SparseTensor`` of size ``(num_nodes, num_nodes)`` on
        ``"cuda"``, with a value of one per edge and with the second row of
        ``edge_index`` used as its row index (i.e. the transposed adjacency).
    """
    src, dst = data.edge_index
    num_nodes = data.num_nodes

    # Order edges by (dst, src) so the swapped row/col passed below is already
    # sorted, which is what is_sorted=True asserts.
    order = (dst * num_nodes + src).argsort()
    src, dst = src[order], dst[order]

    unit_weights = torch.ones(data.edge_index.shape[1])
    adj_t = SparseTensor(
        row=dst,
        col=src,
        value=unit_weights,
        sparse_sizes=(num_nodes, num_nodes),
        is_sorted=True,
    )

    # Compute the row pointer and CSR->CSC permutation up front; presumably
    # torch_sparse caches them so the timed kernels do not pay for them.
    adj_t.storage.rowptr()
    adj_t.storage.csr2csc()
    return adj_t.to(torch.float32).to("cuda")


def to_dense_gpu(x):
    """Wrap a NumPy array as a torch tensor and move it to the ``"cuda"`` device."""
    return torch.from_numpy(x).to("cuda")


@click.command()
@click.argument(
    "dataset", type=click.Choice(["reddit", "cora", "zinc", "circuit", "code", "arxiv"])
)
@click.argument("k", type=int)
@click.argument("device", type=click.Choice(["cpu", "gpu"]))
@click.option("--data_dir", type=click.Path(), default="~/datasets")
@click.option("--warmups", type=int, default=5)
@click.option("--runs", type=int, default=5)
def main(dataset, k, device, data_dir, warmups, runs):
    """Time the dense, sparse and fused aggregation kernels on one dataset.

    Loads DATASET, builds random float32 inputs with K feature columns and
    times each kernel on DEVICE. One CSV row is printed per kernel:
    dataset,device,runs,k,kernel,mean_seconds,stdev_seconds.
    """
    import os

    loaders = {
        "reddit": load_reddit,
        "cora": load_cora,
        "zinc": load_zinc,
        "circuit": load_circuit,
        "code": load_code,
        "arxiv": load_arxiv,
    }
    try:
        loader = loaders[dataset]
    except KeyError:
        raise ValueError(f"Unknown dataset: {dataset!r}") from None

    # click.Path() does not expand "~", so the default "~/datasets" would
    # otherwise reach the loaders as a literal directory name.
    data = loader(os.path.expanduser(data_dir))

    n = data.num_nodes
    d = random_dense(n, k, dtype=np.float32)
    w = random_dense(n, 3, np.float32)
    theta = random_dense(k, k, dtype=np.float32)

    # Each entry is (CSV label, zero-argument callable); the list order is the
    # order in which rows are printed.
    if device == "cpu":
        s = to_sparse_cpu(data)
        d = to_dense_cpu(d)
        w = to_dense_cpu(w)
        theta = to_dense_cpu(theta)
        benchmarks = [
            ("dmm", lambda: mm_cpu(d, theta)),
            ("csr_sum", lambda: csr_dmm_cpu(s, d)),
            ("fused", lambda: csr_fuse_cpu(s, d, w)),
            ("matfused", lambda: csr_fuse_mat_cpu(s, d, w)),
            ("naive", lambda: naive_fuse_cpu(s, d, w)),
        ]
    else:
        s = to_sparse_gpu(data)
        d = to_dense_gpu(d)
        w = to_dense_gpu(w)
        theta = to_dense_gpu(theta)
        benchmarks = [
            ("dmm", lambda: mm_gpu(d, theta)),
            ("csr_sum", lambda: csr_dmm_gpu(s, d)),
            ("fused", lambda: csr_fuse_gpu(s, d, w)),
            ("matfused", lambda: csr_fuse_mat_gpu(s, d, w)),
            ("naive", lambda: naive_fuse_gpu(s, d, w)),
        ]

    for label, fn in benchmarks:
        m, std = _mean_std(time_fn(fn, warmups, runs))
        print(f"{dataset},{device},{runs},{k},{label},{m},{std}")


# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
