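"""
Graph partitioning utilities built around our customized Spinner partitioner.

Dispatches to the GriNNder, Spinner, METIS or random partitioners and provides
a helper to permute a data object according to the resulting node ordering.
"""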

import sys
import time
from typing import Optional, Tuple, Union

from loguru import logger

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from utils.data import Data
from utils.data import get_data
from utils.compiler import load_grinnder_ext


# Load the GriNNder extension module once, at import time. `partition` uses
# its `fast_grinnder`, `spinner` and `random_partition` entry points.
grinnder_ext = load_grinnder_ext()

def partition(partition_method: str,
              adj_t: SparseTensor, num_parts: int,
              capacity: float = 1.10, beta: float = 1.00,
              max_iter: int = 300, progressive_window: int = 5,
              halting_eps: float = 0.0001, halting_window: int = 5,
              async_execution: bool = True, reuse_aware: bool = True, refine: bool = False,
              recursive: bool = False,
              log: bool = False, num_threads: int = 32,
              cluster: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
    """Partition the nodes of the graph ``adj_t`` into ``num_parts`` parts.

    Args:
        partition_method: one of ``'grinnder'``, ``'spinner'``, ``'metis'``
            or ``'random'``.
        adj_t: adjacency matrix; its CSR ``rowptr``/``col`` are handed to
            the chosen partitioner.
        num_parts: number of parts. ``num_parts <= 1`` skips partitioning
            and returns the identity ordering.
        capacity, beta, progressive_window, halting_eps, halting_window,
        async_execution, reuse_aware, num_threads: forwarded to the
            GriNNder / Spinner extension (only the arguments each one
            accepts).
        max_iter: NOTE(review): currently overridden to 50 for
            ``'grinnder'`` and ``'spinner'``, so the caller's value is
            ignored there; confirm whether this is intentional.
        refine: for ``'grinnder'``, refine the initial assignment given in
            ``cluster``. For ``'metis'``, return the raw (unsorted) cluster
            assignment instead of ``(perm, ptr)``.
        recursive: forwarded to METIS.
        log: log progress and elapsed time.
        cluster: initial per-node part assignment; required when
            ``partition_method == 'grinnder'`` and ``refine`` is True.

    Returns:
        ``(perm, ptr)``, where ``perm`` holds the node indices that sort the
        part assignment (nodes of the same part become contiguous) and
        ``ptr`` holds the part boundaries built by
        ``torch_sparse.ind2ptr`` (``[0, num_nodes]`` when
        ``num_parts <= 1``). Exception: for ``'metis'`` with ``refine=True``,
        the unsorted per-node cluster tensor is returned on its own.

    Raises:
        ValueError: ``'grinnder'`` with ``refine=True`` but no ``cluster``.
        AssertionError: the partitioner returned a tensor whose length is
            not the number of nodes.
        SystemExit: unknown ``partition_method`` (after logging an error).
    """
    if log:
        t = time.perf_counter()
        logger.info(f'Partitioning method: {partition_method}')
        logger.info(f'Partitioning of {num_parts} parts...')

    num_nodes = adj_t.size(0)

    if num_parts <= 1:
        # Trivial partition: identity ordering, a single part spanning all nodes.
        perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
    else:
        rowptr, col, _ = adj_t.csr()
        if partition_method == 'grinnder':
            max_iter = 50  # NOTE(review): overrides the caller's max_iter.
            if refine and cluster is None:
                raise ValueError(
                    "refine=True with partition_method='grinnder' requires an "
                    "initial `cluster` assignment")
            # The initial assignment is passed as int16 (`.short()`) when refining.
            init_cluster = cluster.short() if refine else None
            cluster = grinnder_ext.fast_grinnder(rowptr, col, num_parts, capacity, beta, max_iter,
                                                 progressive_window,
                                                 halting_eps, halting_window,
                                                 async_execution, reuse_aware, refine,
                                                 log, num_threads, init_cluster)
        elif partition_method == 'spinner':
            max_iter = 50  # NOTE(review): overrides the caller's max_iter.
            cluster = grinnder_ext.spinner(rowptr, col, num_parts, capacity, beta, max_iter,
                                           halting_eps, halting_window,
                                           async_execution, log, num_threads)
        elif partition_method == 'metis':
            partition_fn = torch.ops.torch_sparse.partition
            cluster = partition_fn(rowptr, col, None, num_parts, recursive)
            if refine:
                # Return the cluster assignment unsorted so it can seed a refinement pass.
                if log:
                    logger.info(f'Done! [{time.perf_counter() - t:.2f}s]')
                return cluster
        elif partition_method == 'random':
            cluster = grinnder_ext.random_partition(rowptr, col, num_parts, True)
        else:
            logger.error('Unknown partitioning method...')
            sys.exit(1)

        assert cluster.shape[0] == num_nodes, 'partition tensor does not match with the number of nodes'
        # Sorting by part id groups nodes of the same part; `perm` is the sort order.
        cluster, perm = cluster.sort()
        ptr = torch.ops.torch_sparse.ind2ptr(cluster.long(), num_parts)

    if log:
        logger.info(f'Done! [{time.perf_counter() - t:.2f}s]')

    return perm, ptr

def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
    r"""Permutes a :obj:`data` object according to a given permutation
    :obj:`perm`.

    Each tensor attribute whose first dimension equals ``data.num_nodes``
    is reindexed as ``value[perm]``. Each :class:`SparseTensor` attribute is
    permuted with its ``permute`` method. All other attributes, including
    0-dim tensors, are left untouched.

    The object is modified **in place** and also returned.

    Args:
        data: graph data object; iterating it yields ``(key, value)`` pairs.
        perm: node permutation, e.g. the ``perm`` produced by ``partition``.
        log: log progress and elapsed time.

    Returns:
        The same ``data`` object, with its node-level attributes permuted.

    Raises:
        NotImplementedError: a tensor attribute that is not node-level (the
            node check runs first) has a first dimension equal to
            ``data.num_edges``. Permuting edge-level attributes is not
            supported.
    """
    if log:
        t = time.perf_counter()
        logger.info('Permuting data...')

    # Snapshot the items first so that reassigning attributes cannot
    # interfere with the iteration over `data`.
    for key, value in list(data):
        # 0-dim tensors have no first dimension (`size(0)` would raise), so skip them.
        if isinstance(value, Tensor) and value.dim() > 0:
            if value.size(0) == data.num_nodes:
                data[key] = value[perm]
            elif value.size(0) == data.num_edges:
                raise NotImplementedError(
                    f'Permuting edge-level attribute {key!r} is not supported')
        elif isinstance(value, SparseTensor):
            data[key] = value.permute(perm)

    if log:
        logger.info(f'Done! [{time.perf_counter() - t:.2f}s]')

    return data