"""
Adapted from D2Pruning
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from sklearn.neighbors import kneighbors_graph
from sklearn.metrics import pairwise_distances
import abc
import numpy as np
import time
from collections import defaultdict
from tqdm import tqdm
import sys

class SamplingMethod(object):
  """Base interface for sampling methods (from D2Pruning).

  Subclasses are expected to provide ``__init__`` and ``select_batch_``;
  callers use ``select_batch``, which forwards its keyword arguments to
  ``select_batch_``.
  """
  # Python 2 spelling of a metaclass declaration. Under Python 3 this is an
  # ordinary class attribute, so the @abc.abstractmethod markers below are not
  # enforced at instantiation time. Kept unchanged so that existing subclasses
  # which skip one of the abstract methods remain instantiable.
  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def __init__(self, X, y, seed, **kwargs):
    """Store the feature array, the labels and the random seed."""
    self.X = X
    self.y = y
    self.seed = seed

  def flatten_X(self):
    """Return ``self.X`` with every axis after the first merged into one.

    Inputs that already have at most two axes are returned as-is (the same
    object, not a copy).
    """
    shape = self.X.shape
    flat_X = self.X
    if len(shape) > 2:
      # np.prod replaces np.product, an alias that was removed in NumPy 2.0.
      # int() keeps the target shape a tuple of plain Python integers.
      flat_X = np.reshape(self.X, (shape[0], int(np.prod(shape[1:]))))
    return flat_X

  @abc.abstractmethod
  def select_batch_(self, **kwargs):
    """Selection hook for subclasses; accepts what ``select_batch`` forwards."""
    return

  def select_batch(self, **kwargs):
    """Public entry point: delegate to ``select_batch_`` with the same kwargs."""
    return self.select_batch_(**kwargs)

  def to_dict(self):
    """Return a serialisable description of the sampler; the base has none."""
    return None


class InfoGraphDensitySampler(SamplingMethod):
  """Graph-based greedy sampler, adapted from GraphDensitySampler (D2Pruning).

  Construction computes pairwise distances over the flattened features, builds
  a k-nearest-neighbour graph and stores importance-weighted edge weights in
  ``self.connect_post``. ``select_batch_`` then repeatedly picks the sample
  with the highest ``self.graph_density`` and lowers the scores of its graph
  neighbours.

  NOTE(review): nothing in this class assigns ``self.graph_density`` (the
  per-sample score vector) or ``self.connect``. ``select_batch_`` requires
  ``graph_density``, and the refinement passes of ``Build_GCCG_Graph`` read
  it, so it has to be provided by a subclass or by the caller -- confirm where
  it is meant to be computed.
  """

  def __init__(self, X, y, seed, gamma=None, importance_scores=None, args=None):
    """Store the inputs and build the weighted neighbour graph.

    Args:
      X: feature array with samples on the first axis; trailing axes are
        flattened for the distance computation.
      y: labels; stored on the instance, not otherwise used by this class.
      seed: random seed; stored on the instance, not otherwise used here.
      gamma: scale applied to distances in the greedy neighbour penalty.
        Defaults to ``1 / X.shape[1]``.
      importance_scores: per-sample scores indexable by sample index whose
        elements expose ``.item()``.
      args: object providing ``graph_mode``, ``graph_sampling_mode`` and
        ``n_neighbor_post``. Despite the ``None`` default it is required.
    """
    self.name = 'infomax_graph_density'
    max_solver_it = 5
    self.args = args
    self.X = X
    # Kept for parity with the SamplingMethod base initialiser.
    self.y = y
    self.seed = seed
    # NOTE(review): X=None skips the flattening but the graph construction
    # below still needs flat_X, so X is effectively required.
    if self.X is not None:
      self.flat_X = self.flatten_X()
    if gamma is not None:
      self.gamma = gamma
    else:
      self.gamma = 1. / self.X.shape[1]
    self.graph_mode = args.graph_mode
    # Stored only; not read anywhere else in this class.
    self.graph_sampling_mode = args.graph_sampling_mode
    self.Build_GCCG_Graph(args.n_neighbor_post, importance_scores, max_solver_it)

  def Build_GCCG_Graph(self, n_neighbor=10, importance_scores=None, max_solver_it=1):
    """Compute pairwise distances and the weighted kNN graph.

    Always sets ``self.distances``. ``self.connect_post`` is set only when
    ``importance_scores`` is given and ``self.graph_mode`` is ``'sum'`` or
    ``'product'``.

    Args:
      n_neighbor: requested number of neighbours per sample.
      importance_scores: per-sample scores used for the first weighting pass.
      max_solver_it: maximum number of weighting passes. Passes after the
        first weight edges by ``self.graph_density`` and are skipped while
        that attribute is unavailable.
    """
    self.distances = pairwise_distances(self.flat_X, self.flat_X)
    if importance_scores is None or self.graph_mode not in ('sum', 'product'):
      return

    epsilon = 0.0000001
    n_samples = self.flat_X.shape[0]
    if n_neighbor > n_samples:
      # Original shrink rule, floored at 1 so small datasets do not end up
      # with a zero or negative neighbour count.
      n_neighbor = max(n_samples - 10, 1)
    # kneighbors_graph excludes each point from its own neighbourhood by
    # default, so at most n_samples - 1 neighbours can be requested.
    n_neighbor = min(n_neighbor, n_samples - 1)

    connect = kneighbors_graph(self.flat_X, n_neighbor, p=2)
    connect = connect.todense()
    rows, cols = connect.nonzero()
    # Materialise the edge list: a bare zip() is a single-pass iterator, which
    # silently turned every pass after the first into a no-op.
    edges = list(zip(rows, cols))

    for iteration in range(max_solver_it):
      if iteration > 0 and getattr(self, 'graph_density', None) is None:
        # Refinement passes read self.graph_density, which this class never
        # assigns. Without it, stop after the importance-weighted pass, which
        # is what the exhausted iterator effectively did before.
        break
      for i, j in edges:
        distance = self.distances[i, j]
        # NOTE(review): for distance == 0 this factor is 1 - exp(epsilon),
        # i.e. slightly negative, so such edges fail the "> 0" neighbour test
        # in select_batch_. Confirm whether -(distance + epsilon) was meant.
        dissimilarity = 1.0 - np.exp(-distance + epsilon)
        if iteration == 0:
          score_j = max(importance_scores[j].item(), epsilon)
          score_i = max(importance_scores[i].item(), epsilon)
        else:
          score_j = self.graph_density[j].item()
          score_i = self.graph_density[i].item()
        # The edge is written in both directions; each direction carries the
        # score of the node it points to.
        connect[i, j] = dissimilarity * score_j
        connect[j, i] = dissimilarity * score_i
      self.connect_post = connect

  def select_batch_(self, N, **kwargs):
    """Greedily select up to ``N`` sample indices.

    Mutates ``self.graph_density`` in place. Requires ``self.graph_density``
    to be set beforehand (see the class-level note).

    Args:
      N: number of samples to select; capped at the number of scores so the
        loop terminates when more samples are requested than exist.

    Returns:
      List of selected indices, in set iteration order (not selection order).

    Raises:
      NotImplementedError: if the graph is stored as a dict, for which no
        neighbour lookup is implemented.
    """
    # Prefer an externally provided graph; otherwise use the one built during
    # construction. The original read self.connect unconditionally, which is
    # never assigned in this class.
    if hasattr(self, 'connect'):
      self.connect_post = self.connect
    N = min(N, len(self.graph_density))

    # We suggest using the greedy selection strategy as a post-processing
    # step: we find it more stable than simply picking the samples with the
    # largest final scores, especially when the selection ratio is small.
    post_processing = True
    batch = set()
    if post_processing:
      while len(batch) < N:
        selected = np.argmax(self.graph_density)
        if isinstance(self.connect_post, dict):
          # The original branch was a bare `pass`, which left `neighbors`
          # undefined and failed with a NameError on the next line.
          raise NotImplementedError(
              'dict-based graphs are not supported by select_batch_')
        neighbors = (self.connect_post[selected, :] > 0).nonzero()[1]
        # Penalise neighbours of the pick, scaled by exp(-gamma * distance).
        penalty = np.exp(-self.distances[selected, neighbors] * self.gamma)
        self.graph_density[neighbors] = (
            self.graph_density[neighbors]
            - penalty * self.graph_density[selected])
        batch.add(selected)
        # Push everything selected so far below the current minimum so it is
        # not picked again.
        self.graph_density[list(batch)] = self.graph_density.min() - 1
    else:
      while len(batch) < N:
        selected = np.argmax(self.graph_density)
        batch.add(selected)
        self.graph_density[list(batch)] = self.graph_density.min() - 1e10
    return list(batch)

