from typing import Tuple
import os
import time
import copy

import numpy as np
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from utils.data import Data

# partition_fn = torch.ops.torch_sparse.mt_partition
partition_fn = torch.ops.torch_sparse.partition


def metis(adj_t: SparseTensor, num_parts: int, refine: bool = False,
          recursive: bool = False, log: bool = True) -> Tuple[Tensor, Tensor]:
    r"""Computes the METIS partition of a given sparse adjacency matrix
    :obj:`adj_t`, returning its "clustered" permutation :obj:`perm` and
    corresponding cluster slices :obj:`ptr`."""

    if log:
        t = time.perf_counter()
        print(f'[GriNNder] >>> METIS partitioning of {num_parts} parts...',
              end=' ', flush=True)

    num_nodes = adj_t.size(0)

    if num_parts <= 1:
        perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
    else:
        rowptr, col, _ = adj_t.csr()
        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
        if refine:
            return cluster # return cluster while not sorted if refine is True
        cluster, perm = cluster.sort()
        ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    if log:
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    return perm, ptr


def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
    r"""Permutes a :obj:`data` object according to a given permutation
    :obj:`perm`."""

    if log:
        t = time.perf_counter()
        print('[GriNNder] >>> Permuting data...', end=' ', flush=True)

    # data = copy.copy(data)
    for key, value in data:
        if isinstance(value, Tensor) and value.size(0) == data.num_nodes:
            data[key] = value[perm]
        elif isinstance(value, Tensor) and value.size(0) == data.num_edges:
            raise NotImplementedError
        elif isinstance(value, SparseTensor):
            data[key] = value.permute(perm)

    if log:
        print(f'Done! [{time.perf_counter() - t:.2f}s]')

    return data
