"""Bourgain embedding computation."""
from typing import Any, Optional

import networkx as nx
import numpy as np


def get_random_anchor_sets(n, c=0.5, rng=None):
  """Samples random anchor sets of exponentially decreasing size.

  With ``m = floor(log2(n))``, for every scale ``i`` in ``range(m)`` this draws
  ``int(c * m)`` sets of ``int(n / 2**(i + 1))`` distinct node ids, sampled
  uniformly without replacement from ``range(n)``.

  Args:
    n: The number of nodes in the graph.
    c: Repetition factor; each scale gets ``int(c * m)`` anchor sets.
    rng: Optional random source exposing ``choice`` (for example a
      ``np.random.Generator``). Defaults to the global ``np.random`` state.

  Returns:
    A list of 1-D integer arrays, ordered from the largest scale to the
    smallest. The list is empty when ``n < 2`` or when ``int(c * m) == 0``.

  Raises:
    ValueError: If ``n`` is negative.
  """
  if n < 0:
    raise ValueError(f"n must be non-negative, got {n}")
  if n < 2:
    # log2(0) is -inf, and int() of it would raise; log2(1) is 0. Either way
    # there are no scales to sample.
    return []
  sampler = np.random if rng is None else rng
  m = int(np.log2(n))
  copies = int(c * m)
  anchor_sets = []
  for i in range(m):
    # Since i + 1 <= m <= log2(n), every size here is at least 1.
    anchor_size = int(n / np.exp2(i + 1))
    for _ in range(copies):
      anchor_sets.append(sampler.choice(n, size=anchor_size, replace=False))
  return anchor_sets


def compute_distances_to_anchor_set(n: int, senders: np.ndarray,
                                    receivers: np.ndarray,
                                    anchor_set: np.ndarray) -> np.ndarray:
  """Computes an inverse-distance proximity from every node to a node set.

  Every anchor node is contracted into a single super-node with id ``n``.
  One unweighted single-source shortest-path search is then run from that
  super-node over the undirected graph built from ``senders``/``receivers``.
  This gives each node its hop distance ``d`` to the nearest anchor.

  Args:
    n: The number of nodes in the graph; node ids are ``0 .. n-1``.
    senders: The sender nodes of the graph (integer node ids).
    receivers: The receiver nodes of the graph (integer node ids).
    anchor_set: Ids of the nodes to which distance is computed, used as an
      index into ``range(n)``.

  Returns:
    A float array of shape ``(n,)``. Anchor nodes get 1.0, and a node at hop
    distance ``d`` from the anchor set gets ``1 / (1 + d)``. Nodes that cannot
    reach the anchor set get 0.0; when the anchor set is empty, every node
    does.
  """
  # Relabel every anchor to the super-node id ``n``, so a single-source search
  # from ``n`` yields the distance to the nearest anchor.
  node_map = np.arange(n)
  node_map[anchor_set] = n
  senders = node_map[senders]
  receivers = node_map[receivers]

  graph = nx.Graph()
  graph.add_nodes_from(np.unique(node_map))
  # Add the super-node explicitly. With an empty anchor set it would otherwise
  # be missing from the graph, and shortest_path_length would raise
  # NodeNotFound.
  graph.add_node(n)
  graph.add_edges_from(zip(senders, receivers))
  distances_dict = nx.shortest_path_length(graph, source=n)

  distances = np.zeros((n,))
  distances[anchor_set] = 1.0
  for node, distance in distances_dict.items():
    # Skip the super-node itself. The original anchor ids no longer appear in
    # the contracted graph, so they keep the 1.0 assigned above.
    if node < n:
      distances[node] = 1. / (1 + distance)
  return distances


def compute_bourgain_embedding(n: int,
                               senders: np.ndarray,
                               receivers: np.ndarray,
                               c=0.5) -> np.ndarray:
  """Computes the Bourgain embedding.

  Random anchor sets are sampled with ``get_random_anchor_sets``. For each
  anchor set, the per-node values from ``compute_distances_to_anchor_set``
  become one embedding coordinate.

  Args:
    n: The number of nodes in the graph
    senders: The sender nodes of the graph
    receivers: The receiver nodes of the graph
    c: Repetition factor, forwarded to ``get_random_anchor_sets``.

  Returns:
    A float array of shape ``(n, k)``, where ``k`` is the number of sampled
    anchor sets. Column ``i`` holds the node values for anchor set ``i``.
  """
  anchor_sets = get_random_anchor_sets(n, c)
  # Preallocate so the result keeps shape (n, 0) when no anchor sets are
  # sampled; np.stack would fail on an empty list.
  embedding = np.zeros((n, len(anchor_sets)))
  for i, anchor_set in enumerate(anchor_sets):
    embedding[:, i] = compute_distances_to_anchor_set(n, senders, receivers,
                                                      anchor_set)
  return embedding
