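"""
Random node partitioning of a sparse graph adjacency matrix, plus a helper
that permutes the node-level attributes of a graph data object accordingly.
"""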
from typing import Tuple
import time
import copy

import numpy as np
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from utils.data import Data
from utils.data import get_data
from utils.compiler import load_grinnder_ext

# Load the compiled GriNNder extension module once, when this module is imported.
grinnder_ext = load_grinnder_ext()

# Native random-partitioning kernel. `random_partition` calls it as
# partition_fn(rowptr, col, num_parts, log) on the CSR form of the adjacency
# matrix and expects one part id per node in return.
partition_fn = grinnder_ext.random_partition

def random_partition(adj_t: SparseTensor, num_parts: int,
                     log: bool = False) -> Tuple[Tensor, Tensor]:
    r"""Randomly partition the nodes of a sparse adjacency matrix.

    The node-to-part assignment is computed by the compiled extension
    kernel ``partition_fn`` (``grinnder_ext.random_partition``) on the CSR
    form of ``adj_t``. Nodes are then sorted by part id so that every part
    occupies a contiguous range of the returned permutation.

    Args:
        adj_t: Sparse adjacency matrix; ``adj_t.size(0)`` is used as the
            number of nodes.
        num_parts: Requested number of parts. Any value ``<= 1`` skips the
            kernel and returns a single part holding all nodes in their
            original order.
        log: If ``True``, print progress and elapsed time. The flag is also
            forwarded to the partition kernel.

    Returns:
        A tuple ``(perm, ptr)``. ``perm`` lists node ids grouped by part and
        ``ptr`` holds the part boundaries, so the nodes of part ``i`` are
        ``perm[ptr[i]:ptr[i + 1]]``.

    Raises:
        RuntimeError: If the kernel returns a number of part assignments
            that differs from the number of nodes.
    """
    if log:
        t = time.perf_counter()
        print(f'[GriNNder] >>> Random partitioning into {num_parts} parts...')

    num_nodes = adj_t.size(0)

    if num_parts <= 1:
        perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
    else:
        rowptr, col, _ = adj_t.csr()
        cluster = partition_fn(rowptr, col, num_parts, log)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if cluster.shape[0] != num_nodes:
            raise RuntimeError(
                'partition tensor does not match with the number of nodes: '
                f'got {cluster.shape[0]} assignments for {num_nodes} nodes')
        # A stable sort keeps nodes of the same part in their original relative
        # order, so `perm` does not depend on the sort implementation.
        cluster, perm = cluster.sort(stable=True)
        # Convert the sorted part ids into CSR-style offsets (length num_parts + 1).
        # Assumes the kernel returns ids in [0, num_parts) — TODO confirm.
        ptr = torch.ops.torch_sparse.ind2ptr(cluster.long(), num_parts)

    if log:
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    return perm, ptr

def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
    r"""Permute the node-level attributes of :obj:`data` by :obj:`perm`.

    A shallow copy of ``data`` is made (``copy.copy``) and the permuted
    attributes are assigned on that copy:

    * tensors whose first dimension equals ``data.num_nodes`` are reindexed
      as ``value[perm]``;
    * ``SparseTensor`` attributes are permuted with ``SparseTensor.permute``;
    * 0-dim tensors and tensors of any other leading size are left as-is.

    Args:
        data: Graph data object to permute.
        perm: Node permutation, e.g. the ``perm`` returned by
            ``random_partition``.
        log: If ``True``, print progress and elapsed time.

    Returns:
        The shallow copy of ``data`` with permuted attributes.

    Raises:
        NotImplementedError: If a tensor attribute's first dimension matches
            ``data.num_edges`` (and not ``data.num_nodes``); permuting
            edge-level attributes is not supported.
    """
    if log:
        t = time.perf_counter()
        print('[GriNNder] >>> Permuting data...', end=' ', flush=True)

    data = copy.copy(data)
    # Read once: permuting node rows does not change the node count.
    num_nodes = data.num_nodes
    for key, value in data:
        if isinstance(value, Tensor):
            # Scalar tensors have no leading dimension; `size(0)` would raise.
            if value.dim() == 0:
                continue
            if value.size(0) == num_nodes:
                data[key] = value[perm]
            elif value.size(0) == data.num_edges:
                raise NotImplementedError(
                    f"Permuting edge-level attribute '{key}' is not supported")
        elif isinstance(value, SparseTensor):
            data[key] = value.permute(perm)

    if log:
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    return data
