"""
Our implementation of Spinner.
: We follow the same implementation as the original Spinner paper.
: Note that the dynamic cases are not implemented because we do not need them.
: However, we change the objective function for GriNNder.
"""

from typing import Tuple
import time
import copy

import numpy as np
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from utils.data import Data
from utils.data import get_data
from utils.compiler import load_grinnder_ext


# Load GriNNder's compiled extension modules (runs once, at import time).
grinnder_ext = load_grinnder_ext()

# Native Spinner kernel from the extension; called by `mt_spinner`.
partition_fn = grinnder_ext.spinner

"""
Procedure of Spinner
1. Each vertex computes a score for each label based on the partition loads and the labels of its neighbors.
   If another partition has a higher score than its current one, the vertex wants to migrate there.
2. Interested vertices try to migrate according to the ratio between the remaining capacity of the target
   partition and the vertices trying to migrate to it. Vertices that succeed in migrating update
   the partition loads and communicate their new labels to their neighbors.
"""

def mt_spinner(adj_t: SparseTensor, num_parts: int,
               capacity: float = 1.10, beta: float = 1.00,
               max_iter: int = 300,
               halting_eps: float = 0.001, halting_window: int = 5,
               async_execution: bool = True, log: bool = False,
               num_threads: int = 32) -> Tuple[Tensor, Tensor]:
    r"""Computes the Spinner partition of a given sparse adjacency matrix
    using the native (multi-threaded) extension kernel.

    Args:
        adj_t: sparse adjacency matrix. Only its CSR ``rowptr``/``col`` are
            handed to the kernel; edge values are ignored.
        num_parts: number of partitions. If ``num_parts <= 1`` the kernel is
            not called and all nodes form a single partition.
        capacity, beta, max_iter, halting_eps, halting_window,
        async_execution, num_threads: forwarded unchanged to the native
            ``spinner`` kernel (presumably with the Spinner paper's meaning;
            confirm against the extension source).
        log: if True, print progress and the elapsed time. The flag is also
            forwarded to the kernel.

    Returns:
        ``(perm, ptr)``. ``perm`` orders the nodes by partition id, so the nodes
        of one partition are contiguous. ``ptr`` holds CSR-style partition
        offsets built with ``ind2ptr``: partition ``i`` is
        ``perm[ptr[i]:ptr[i + 1]]``. In the trivial case ``ptr == [0, num_nodes]``.
    """
    if log:
        start = time.perf_counter()
        print(f'[GriNNder] >>> Multi-threaded Spinner partitioning of {num_parts} parts...')

    n = adj_t.size(0)

    if num_parts > 1:
        rowptr, col, _ = adj_t.csr()
        part_of = partition_fn(rowptr, col, num_parts, capacity, beta, max_iter,
                               halting_eps, halting_window,
                               async_execution, log, num_threads)
        assert part_of.shape[0] == n, 'partition tensor does not match with the number of nodes'
        # Sorting the per-node partition ids yields both the grouping order
        # (perm) and the sorted ids from which the partition offsets are built.
        sorted_parts, perm = part_of.sort()
        ptr = torch.ops.torch_sparse.ind2ptr(sorted_parts.long(), num_parts)
    else:
        # Trivial partition: identity order, one partition covering all nodes.
        perm = torch.arange(n)
        ptr = torch.tensor([0, n])

    if log:
        print(f'Done! [{time.perf_counter() - start:.2f}s]')

    return perm, ptr

def _spinner_scores(labels: Tensor, labeled_col: Tensor, rowptr: Tensor,
                    penalty_term: Tensor, num_parts: int, beta: float):
    r"""Step 1 of a Spinner iteration: score every label for every vertex.

    For a vertex with ``d`` CSR entries, the score of label ``l`` is
    ``beta + (#entries labelled l) / d - beta * penalty_term[l] / num_parts``.

    Returns ``(graph_scores, interested, interested_load)``:

    * ``graph_scores[l]`` sums the score that each vertex currently labelled
      ``l`` gives to ``l``.
    * ``interested[l]`` lists the vertices whose best label is ``l`` and is
      strictly better than their current label.
    * ``interested_load[l]`` is the summed degree of those vertices, i.e. the
      load that wants to move into ``l``.

    Vertices without CSR entries are skipped.
    """
    graph_scores = torch.zeros(num_parts, dtype=torch.float)
    interested = [[] for _ in range(num_parts)]
    interested_load = [0] * num_parts
    # The penalty part depends only on the label, so compute it once.
    label_penalty = beta * penalty_term / num_parts

    # Plain Python lists make the per-vertex scalar reads far cheaper than
    # indexing tensors element by element.
    offsets = rowptr.tolist()
    for v, cur_label in enumerate(labels.tolist()):
        begin, end = offsets[v], offsets[v + 1]
        if begin == end:
            continue
        degree = end - begin
        counts = torch.bincount(labeled_col[begin:end], minlength=num_parts)
        score = beta + counts / degree - label_penalty
        graph_scores[cur_label] += score[cur_label]

        best_val, best_label = torch.max(score, dim=0)
        # A tie with the current label keeps the vertex where it is.
        if best_val != score[cur_label]:
            target = int(best_label)
            interested[target].append(v)
            interested_load[target] += degree

    return graph_scores, interested, interested_load


def _spinner_migrate(labels: Tensor, interested, interested_load,
                     remaining_capacity: Tensor) -> None:
    r"""Step 2 of a Spinner iteration: relabel a random subset of candidates.

    A candidate for label ``l`` is admitted with probability
    ``p = remaining_capacity[l] / interested_load[l]``, clipped to ``[0, 1]``.
    The remaining capacity is measured in edges, so it is compared with the
    candidates' load (summed degree) rather than with their count.
    ``int(p * len(interested[l]))`` candidates, drawn uniformly without
    replacement from torch's global RNG, are moved to ``l``. ``labels`` is
    updated in place.
    """
    load = torch.tensor(interested_load, dtype=torch.double)
    remaining = remaining_capacity.to(torch.double)
    # clamp(min=1) only affects labels with no candidates (load 0). Those
    # labels get probability 0 through the mask, which avoids 0/0 = NaN.
    prob = torch.where(load > 0, remaining / load.clamp(min=1.0),
                       torch.zeros_like(load)).clamp(0.0, 1.0)

    for label, (p, vertices) in enumerate(zip(prob.tolist(), interested)):
        num_migrate = int(p * len(vertices))
        if num_migrate == 0:
            continue
        candidates = torch.tensor(vertices, dtype=torch.long)
        chosen = candidates[torch.randperm(len(vertices))[:num_migrate]]
        labels[chosen] = label


def _spinner_labels(rowptr: Tensor, col: Tensor, num_parts: int,
                    capacity: float, beta: float, halting_eps: float,
                    halting_window: int, max_iter: int, log: bool) -> Tensor:
    r"""Run the Spinner iterations and return one int16 label per vertex."""
    num_nodes = rowptr.numel() - 1
    num_edges = col.numel()
    # Absolute per-partition capacity: C = c * |E| / k.
    max_load = capacity * num_edges / num_parts
    labels = torch.randint(0, num_parts, (num_nodes,), dtype=torch.short)
    degree = (rowptr[1:] - rowptr[:-1]).long()
    col = col.long()
    prev_graph_scores = torch.zeros(num_parts, dtype=torch.float)
    window = 0

    for it in range(1, max_iter + 1):
        # Label of the column vertex of every CSR entry.
        labeled_col = labels[col]
        # Partition loads b(l): summed degree of the vertices labelled l.
        balance_per_label = torch.zeros(num_parts, dtype=torch.long).index_add_(
            0, labels.long(), degree).to(torch.float)
        penalty_term = balance_per_label / max_load
        remaining_capacity = max_load - balance_per_label

        if log:
            print(f'ITER: {it}, MAX_LOAD: {balance_per_label.max()}, p (MAX_LOAD/(|E|/k)): {balance_per_label.max() / (num_edges / num_parts)}, ')
            print(f' penalty_term: {penalty_term}, ')

        graph_scores, interested, interested_load = _spinner_scores(
            labels, labeled_col, rowptr, penalty_term, num_parts, beta)
        _spinner_migrate(labels, interested, interested_load, remaining_capacity)

        cur_score = graph_scores.sum()
        prev_score = prev_graph_scores.sum()
        # Relative change of the total score. The reference is all zeros on the
        # first iteration, so step is inf (or NaN) and cannot halt there.
        step = (1 - cur_score / prev_score).abs()
        if log:
            print(f'DELTA: {step}, CUR_SCORE: {cur_score}, PREV_SCORE: {prev_score}')

        if step <= halting_eps:
            window += 1
            if window >= halting_window:
                break
        else:
            # The reference score is refreshed only on a significant change, so
            # each step is measured against the last such iteration.
            window = 0
            prev_graph_scores = graph_scores.clone()

    return labels


def spinner(rowptr: Tensor, col: Tensor, num_parts: int,
            capacity: float = 1.05, beta: float = 1.00,
            halting_eps: float = 0.01, halting_window: int = 5,
            log: bool = False, max_iter: int = 300) -> Tuple[Tensor, Tensor]:
   r"""Computes the Spinner partition of a CSR graph in pure Python.

   This is a single-threaded reference implementation of the Spinner
   label-propagation loop; ``mt_spinner`` uses the native kernel instead.

   Args:
       rowptr, col: CSR row pointer and column indices. The load of a vertex
           is its number of CSR entries, ``rowptr[v + 1] - rowptr[v]``.
       num_parts: number of partitions. ``num_parts <= 1`` puts every node in
           a single partition. At most 32768 is supported because labels are
           stored as int16.
       capacity: slack factor ``c``. The per-partition capacity is
           ``C = c * len(col) / num_parts``.
       beta: weight of the load-balance penalty in the label score.
       halting_eps, halting_window: stop once the relative change of the total
           score has stayed ``<= halting_eps`` for ``halting_window``
           consecutive iterations.
       log: print progress, per-iteration diagnostics and the elapsed time.
       max_iter: maximum number of iterations.

   Returns:
       ``(perm, ptr)``: ``perm`` orders the nodes by partition label, and
       partition ``i`` is ``perm[ptr[i]:ptr[i + 1]]``. For ``num_parts <= 1``
       the result is ``ptr == [0, num_nodes]``.

   Note: initial labels and migrations draw from torch's global RNG. Call
   ``torch.manual_seed`` for reproducible results.
   """
   if log:
       t = time.perf_counter()
       print(f'[GriNNder] >>> Spinner partitioning of {num_parts} parts...')

   num_nodes = rowptr.numel() - 1

   if num_parts <= 1:
       # Same shortcut as mt_spinner: one partition holding every node.
       perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
   else:
       assert num_parts <= 32768, 'We do not support more than 32768 partitions'
       labels = _spinner_labels(rowptr, col, num_parts, capacity, beta,
                                halting_eps, halting_window, max_iter, log)
       if log:
           print(torch.bincount(labels, minlength=num_parts))
       _, perm = labels.sort()
       part_sizes = torch.bincount(labels.long(), minlength=num_parts)
       ptr = torch.cat([part_sizes.new_zeros(1), part_sizes.cumsum(0)])

   if log:
       print(f'Done! [{time.perf_counter() - t:.2f}s]')

   return perm, ptr

def permute(data: Data, perm: Tensor, log: bool = True) -> Data:
    r"""Reorders the nodes of :obj:`data` according to :obj:`perm`.

    Dense tensors whose first dimension equals ``data.num_nodes`` are indexed
    with ``perm``. :class:`SparseTensor` attributes are replaced by
    ``value.permute(perm)``. All other attributes are left untouched.

    Note: ``data`` is modified in place and also returned; no copy is made.

    Raises:
        NotImplementedError: if a dense tensor's first dimension matches
            ``data.num_edges`` but not ``data.num_nodes``. Edge-level
            permutation is not supported.
    """
    if log:
        start = time.perf_counter()
        print('[GriNNder] >>> Permuting data...', end=' ', flush=True)

    for name, attr in data:
        if isinstance(attr, Tensor):
            leading = attr.size(0)
            if leading == data.num_nodes:
                data[name] = attr[perm]
            elif leading == data.num_edges:
                raise NotImplementedError
        elif isinstance(attr, SparseTensor):
            data[name] = attr.permute(perm)

    if log:
        print(f'Done! [{time.perf_counter() - start:.2f}s]')

    return data

if __name__ == '__main__':
    # Manual smoke test: partition the reddit graph with the pure-Python Spinner.
    print('[GriNNder] Spinner Unit Test')

    n_parts = 16
    capacity = 1.05
    halting_eps = 0.001
    halting_window = 5

    # Loading data
    print('[GriNNder] >>> loading data...')
    data, _, _ = get_data('/datasets/grinnder', 'reddit')
    adj_t = data.adj_t.copy()
    # Release the dataset reference; only the adjacency is used below.
    del data

    print('[GriNNder] >>> loaded sparse data!')
    print(adj_t)
    rowptr, col, _ = adj_t.csr()
    del adj_t

    print('[GriNNder] >> converted to csr format!')
    print(rowptr)
    print(col)

    # Spinner: forward the settings configured above so they take effect
    # instead of the function defaults.
    spinner(rowptr, col, n_parts, capacity=capacity,
            halting_eps=halting_eps, halting_window=halting_window,
            log=True)

    print('[GriNNder] Spinner Unit Test Done')