from networkx.algorithms.approximation import christofides, greedy_tsp
import networkx as nx
from utils import tour_length
import time

# Accepted values for chrp's `normalize` argument. Stored as a tuple so this
# shared module-level constant cannot be mutated by accident.
_NORMALIZATION = ('none', 'n', 'opt')

def chrp(G, prediction='prediction', normalize='none'):
  """
  Christofides augmented with predictions (CHR^+).

  Every edge gets an auxiliary weight
  ``weight_prime = weight * (1 - coef * prediction)``. A minimum spanning tree
  with respect to ``weight_prime`` is then passed to networkx's Christofides
  implementation, which builds the tour using the original ``'weight'``.

  Parameters
  ----------
  G : networkx.Graph
    G should be a complete weighted undirected graph. Edge data key
    corresponding to the edge weight should be 'weight'. Weights should satisfy
    triangle inequality. NOTE: G is modified in place -- every edge gains
    (or has overwritten) a 'weight_prime' attribute.

  prediction : string, optional (default: 'prediction')
    Edge data key corresponding to the predictions. Prediction values should be
    between 0 and 1.

  normalize : string, optional (default: 'none')
    How to normalize predictions. Accepted values:
      'none' : coef = 1.
      'n'    : coef = (number of edges of G) / (sum of predictions).
      'opt'  : coef = L / (sum of weight * prediction), where L is the
               tour_length of the greedy_tsp tour of G.

  Returns
  -------
  tour
    The tour returned by networkx's ``christofides``.
  elapsed : float
    Seconds spent in this call, including the normalization step.

  Raises
  ------
  ValueError
    If `normalize` is not one of the accepted values.
  """
  # perf_counter is monotonic and high-resolution; time.time() can jump when
  # the system clock is adjusted, which corrupts elapsed-time measurements.
  start = time.perf_counter()
  # An explicit raise (not assert) so validation survives `python -O`.
  if normalize not in _NORMALIZATION:
    raise ValueError(
      f"normalize must be one of {_NORMALIZATION}, got {normalize!r}")

  if normalize == 'none':
    coef = 1.0
  elif normalize == 'n':
    # NOTE(review): this rescales predictions to sum to the number of *edges*
    # of G. If 'n' is meant as the number of nodes (the edge count of a tour),
    # this should be G.number_of_nodes() -- confirm intent.
    total = sum(data[prediction] for _, _, data in G.edges(data=True))
    # A zero total means every prediction is zero (for the documented [0, 1]
    # range), so coef * prediction is zero whatever coef is; fall back to 1.0
    # instead of dividing by zero.
    coef = G.number_of_edges() / total if total else 1.0
  else:
    opt_apx = tour_length(G, greedy_tsp(G))
    weighted_sum = sum(
      data['weight'] * data[prediction] for _, _, data in G.edges(data=True))
    # A zero sum means every weight * prediction term is zero (assuming
    # non-negative weights and predictions), so weight_prime == weight for any
    # coef; fall back to 1.0 instead of dividing by zero.
    coef = opt_apx / weighted_sum if weighted_sum else 1.0

  # The edge-data dicts are mutated in place; the graph structure is untouched,
  # so modifying them while iterating is safe.
  for _, _, data in G.edges(data=True):
    data['weight_prime'] = data['weight'] * (1 - coef * data[prediction])
  tree = nx.minimum_spanning_tree(G, weight='weight_prime')
  # christofides builds the tour on its default 'weight' key; only the
  # spanning tree is prediction-guided.
  return christofides(G, tree=tree), time.perf_counter() - start