# Lint as: python3
"""Experimental code for CARMAB algorithm"""

from absl import app
from absl import flags
import numpy as np

FLAGS = flags.FLAGS

# Command-line flags for the experiment; absl parses them in app.run(main).
flags.DEFINE_integer('max_episodes', 1000, 'Max episodes for the algorithm')
# Number of most recent actions encoded in a graph node; an action's
# congestion is how often it occurs among these `window` actions.
flags.DEFINE_integer('window', 3, 'Congestion window size')
# Appears in the exploration bonus as log(k * window * t_e / delta).
flags.DEFINE_float('delta', 0.1, 'Confidence parameter in (0,1)')
flags.DEFINE_integer('k', 3, 'Number of actions')
# Number of simulated time steps; play stops once t reaches this value.
flags.DEFINE_integer('horizon', 100000, 'Horizon')
flags.DEFINE_bool('print_action_error', False,
                  'Print error between learned and true base rewards.')

###################### KARP'S ALG ####################################


class Edge:
  """A weighted directed edge, stored in its destination's incoming list.

  Attributes:
    from_node: index of the node the edge leaves.
    weight: cost of traversing the edge.
  """

  def __init__(self, u, w):
    self.from_node, self.weight = u, w


def addedge(edges, u, v, w):
  """Record the directed edge u -> v with weight w.

  Edges are bucketed by destination: an Edge holding (u, w) is appended to
  edges[v].
  """
  incoming = edges[v]
  incoming.append(Edge(u, w))


# Fills the exact-length walk table used by Karp's algorithm.
def shortestpath(edges, dp, previous, n):
  """Fill Karp's walk-length DP tables in place, starting from node 0.

  Afterwards dp[i][j] is the minimum total weight of a walk from node 0 to
  node j that uses exactly i edges, or -1 if no such walk exists, and
  previous[i][j] is the predecessor of j on that walk (-1 when there is none).

  Args:
    edges: incoming adjacency lists; edges[j] holds objects with `from_node`
      and `weight` attributes, one per edge into node j.
    dp: pre-allocated (n + 1) x n table; every cell is overwritten.
    previous: pre-allocated (n + 1) x n table; every cell is overwritten.
    n: number of nodes.

  Note: -1 doubles as the "unreachable" marker, so results are only
  unambiguous when no reachable walk weighs exactly -1 (e.g. when all edge
  weights are positive).
  """
  for row in range(n + 1):
    for col in range(n):
      dp[row][col] = -1
      previous[row][col] = -1

  # With zero edges only node 0 is reachable, at cost 0.
  dp[0][0] = 0

  for length in range(1, n + 1):
    prior = dp[length - 1]
    for node in range(n):
      for edge in edges[node]:
        source = edge.from_node
        if prior[source] == -1:
          continue  # No (length - 1)-edge walk reaches the edge's source.
        candidate = prior[source] + edge.weight
        if dp[length][node] == -1 or candidate < dp[length][node]:
          dp[length][node] = candidate
          previous[length][node] = source


# Karp's minimum mean cycle search. Returns a node sequence (not the mean
# value itself) whose walk contains a cycle.
def min_avg_weight(edges, n):
  """Karp's algorithm: locate a minimum mean-weight cycle.

  Using the tables filled by `shortestpath` (walks from node 0), every node v
  that has an n-edge walk gets the Karp value

      max over 0 <= j < n with dp[j][v] != -1 of
          (dp[n][v] - dp[j][v]) / (n - j),

  and the node with the smallest value is selected (lowest index on ties;
  node 0 if no node has an n-edge walk). The minimum-weight n-edge walk from
  node 0 to the selected node is then rebuilt from the predecessor table.

  Args:
    edges: incoming adjacency lists; edges[v] holds objects with `from_node`
      and `weight` attributes, one per edge into node v.
    n: number of nodes, labelled 0..n-1.

  Returns:
    The walk as a list of node ids, ordered from its start to the selected
    node. It has n + 1 entries when the selected node has an n-edge walk from
    node 0 (with only n distinct nodes some node repeats, so the walk
    contains a cycle); otherwise it is just [selected node].
  """
  dp = [[None] * n for _ in range(n + 1)]
  previous = [[None] * n for _ in range(n + 1)]
  shortestpath(edges, dp, previous, n)

  # Karp value per node. None (not -1) marks "no n-edge walk", so a genuine
  # negative average is never mistaken for, or clobbered by, the marker.
  avg = [None] * n
  for i in range(n):
    if dp[n][i] == -1:
      continue
    ratios = [(dp[n][i] - dp[j][i]) / (n - j)
              for j in range(n)
              if dp[j][i] != -1]
    if ratios:
      avg[i] = max(ratios)

  # Smallest valid Karp value; strict < keeps the lowest index on ties. Unlike
  # seeding with avg[0], this still works when node 0 itself has no value.
  best_i = 0
  best_value = None
  for i, value in enumerate(avg):
    if value is not None and (best_value is None or value < best_value):
      best_value = value
      best_i = i

  # Follow predecessors back from (length n, best_i) until the chain ends at
  # node 0 with zero edges (previous[0][0] == -1).
  edge_sequence = []
  node = best_i
  length = n
  while node != -1 and length >= 0:
    edge_sequence.append(node)
    node = previous[length][node]
    length -= 1
  edge_sequence.reverse()
  return edge_sequence


###################### END KARP'S STUFF ################################


def number_to_vector(num, k, window):
  """Return the lowest `window` base-k digits of num, most significant first.

  Higher digits are discarded, so the result always has exactly `window`
  entries, e.g. number_to_vector(5, 3, 3) == [0, 1, 2].
  """
  return [(num // k**power) % k for power in reversed(range(window))]


def vector_to_number(vector, k, window):
  """Encode base-k digits (most significant first) back into an integer.

  This is the inverse of number_to_vector for vectors of length `window`.
  """
  weight = k**(window - 1)  # Place value of the leading digit.
  result = 0
  for digit in vector:
    result += digit * weight
    weight //= k
  return result


def get_node_neighbors(node, k, window):
  """List the successors of `node` in the action-history graph.

  A node encodes the last `window` actions. Playing action a drops the
  oldest action and appends a, so the k successors are the encodings of
  digits[1:] + [a] for a = 0, ..., k - 1, in that order.
  """
  digits = number_to_vector(node, k, window)
  digits[:-1] = digits[1:]  # Shift left by one, dropping the oldest action.
  successors = []
  for action in range(k):
    digits[window - 1] = action
    successors.append(vector_to_number(digits, k, window))
  return successors


def get_action_congestion(vector, action):
  """Return how many entries of `vector` are equal to `action`."""
  return sum(1 for chosen in vector if chosen == action)


def get_node_reward(node, k, window, r_hat):
  """Return the estimated reward r_hat[action][congestion] for entering node.

  The node's final digit is the action just played; its congestion is the
  number of times that action occurs in the window (at least 1, since the
  count includes the action itself).
  """
  history = number_to_vector(node, k, window)
  last_action = history[-1]
  level = get_action_congestion(history, last_action)
  return r_hat[last_action][level]


def get_estimated_node_reward(node, k, window, r_hat, t_e, n_acc, e):
  """Return an optimistic reward estimate for entering `node`.

  The node's last action `a` and its congestion `c` (occurrences of `a` in
  the window) select the estimate r_hat[a][c]. An exploration bonus is then
  added:

      max(0, 2 * sqrt(log(k * window * t_e / FLAGS.delta) / max(1, N)))

  where N = n_acc[e][a][c].
  """
  history = number_to_vector(node, k, window)
  last_action = history[-1]
  level = get_action_congestion(history, last_action)
  estimate = r_hat[last_action][level]
  log_term = np.log(k * window * t_e / FLAGS.delta)
  visits = max(1, n_acc[e][last_action][level])
  bonus = 2 * np.sqrt(log_term / visits)
  return estimate + max(0, bonus)


def get_correct_node_reward(node, k, window, true_rewards, t):
  """Return the true reward at step t for entering `node`.

  The base reward true_rewards[t][a] of the node's last action `a` is
  divided by that action's congestion, i.e. its number of occurrences in
  the window.
  """
  history = number_to_vector(node, k, window)
  last_action = history[-1]
  share = get_action_congestion(history, last_action)
  return true_rewards[t][last_action] / share


def get_cycle_from_sequence(edge_sequence, k, window, r_hat):
  """Pick the highest-mean cycle contained in a node sequence.

  Every pair of positions i < j with edge_sequence[i] == edge_sequence[j]
  delimits the cycle edge_sequence[i:j]. Its value is the mean of
  get_node_reward over those nodes. The initial candidate is the one-node
  cycle [edge_sequence[0]], valued at that node's reward. A later candidate
  replaces the current best only if its value is strictly greater.

  NOTE(review): cycles are scored with the plain estimates r_hat. The caller
  may have built the sequence from different (e.g. optimistic) weights;
  confirm this is intended.

  Args:
    edge_sequence: list of node ids, e.g. a walk returned by min_avg_weight.
    k: number of actions.
    window: congestion window size.
    r_hat: r_hat[action][congestion] reward estimates.

  Returns:
    The chosen cycle as a list of node ids, or [] if edge_sequence is empty.
  """
  if len(edge_sequence) == 0:
    return []

  # Decode each position's reward once instead of once per candidate cycle.
  rewards = [get_node_reward(node, k, window, r_hat) for node in edge_sequence]

  start_of_cycle, end_of_cycle = 0, 1
  best_cycle_value = rewards[0]
  for i, node in enumerate(edge_sequence):
    for j in range(i + 1, len(edge_sequence)):
      if edge_sequence[j] == node:
        # sum() adds left to right from 0, matching a running total.
        cycle_value = sum(rewards[i:j]) / (j - i)
        if cycle_value > best_cycle_value:
          start_of_cycle, end_of_cycle = i, j
          best_cycle_value = cycle_value

  return list(edge_sequence[start_of_cycle:end_of_cycle])


def _generate_true_rewards(rs, k, horizon):
  """Draw the base rewards and an independent noisy reward row per step.

  Args:
    rs: np.random.RandomState used for every draw.
    k: number of actions.
    horizon: number of time steps.

  Returns:
    (base_rewards, true_rewards). base_rewards has shape (k,), with entries
    drawn uniformly from [0, 1). true_rewards has shape (horizon, k); row t is
    base_rewards plus fresh uniform noise from [-0.1, 0.1), clipped to
    [0, 1]. base_rewards itself is never modified.
  """
  base_rewards = rs.uniform(low=0.0, high=1.0, size=k)
  noise = rs.uniform(low=-0.1, high=0.1, size=(horizon, k))
  # Broadcasting builds a separate row per step. Sharing one array across all
  # steps would make the noise accumulate and overwrite base_rewards.
  true_rewards = np.clip(base_rewards + noise, 0.0, 1.0)
  return base_rewards, true_rewards


def _estimate_rewards(total_reward, counts, k, window):
  """Return r_hat[a][j], the mean observed reward of action a at congestion j.

  Pairs with counts[a][j] == 0 get the value 2. In this simulation observed
  rewards lie in [0, 1], so unexplored pairs look attractive.
  """
  r_hat = [[0.0 for _ in range(window + 1)] for _ in range(k)]
  for a in range(k):
    for j in range(window + 1):
      if counts[a][j] == 0:
        r_hat[a][j] = 2
      else:
        r_hat[a][j] = total_reward[a][j] / counts[a][j]
  return r_hat


def main(argv):
  """Run the CARMAB experiment on simulated congested rewards.

  Every 100 steps, prints the step counter t and the average reward per
  played step so far. With --print_action_error, finally prints, for each
  action, the learned congestion-1 reward minus the true base reward.
  """
  episodes = FLAGS.max_episodes
  window = FLAGS.window
  k = FLAGS.k
  horizon = FLAGS.horizon
  num_nodes = k**window

  seed = 0  # Controls randomness for costs
  rs = np.random.RandomState(seed)
  base_rewards, true_rewards = _generate_true_rewards(rs, k, horizon)

  # The graph's structure never changes between episodes; only weights do.
  neighbors_of = [
      get_node_neighbors(node, k, window) for node in range(num_nodes)
  ]

  # Start Algorithm 1
  cummulative_reward = 0.0
  total_reward = [[0 for _ in range(window + 1)] for _ in range(k)]
  t = 1  # Steps are numbered from 1; t_e = t is also fed to a log().
  # n[e][a][j]: plays of action a at congestion j during episode e.
  n = [[[0 for _ in range(window + 1)] for _ in range(k)]
       for _ in range(episodes)]
  # n_acc[e][a][j]: plays of (a, j) over episodes 0..e-1, i.e. N_e(a, j).
  n_acc = [[[0 for _ in range(window + 1)] for _ in range(k)]
           for _ in range(episodes)]
  for e in range(episodes):
    if t >= horizon:
      break
    t_e = t
    if e > 0:
      # Running total instead of re-summing every past episode.
      for a in range(k):
        for j in range(window + 1):
          n_acc[e][a][j] = n_acc[e - 1][a][j] + n[e - 1][a][j]
    r_hat = _estimate_rewards(total_reward, n_acc[e], k, window)

    # Optimistic value of entering each node, computed once per node.
    node_value = [
        get_estimated_node_reward(node, k, window, r_hat, t_e, n_acc, e)
        for node in range(num_nodes)
    ]
    max_reward = 0
    for neighbors in neighbors_of:
      for neighbor in neighbors:
        max_reward = max(max_reward, node_value[neighbor])
    # weight = max_reward + 1 - value is >= 1 on every edge. This turns
    # "maximize mean reward" into "minimize mean weight", and no walk weight
    # can collide with the -1 "unreachable" marker in Karp's DP.
    edges = [[] for _ in range(num_nodes)]
    for node, neighbors in enumerate(neighbors_of):
      for neighbor in neighbors:
        addedge(edges, node, neighbor, max_reward + 1 - node_value[neighbor])
    sequence = min_avg_weight(edges, num_nodes)
    cycle = get_cycle_from_sequence(sequence, k, window, r_hat)

    # Play the cycle until this episode's count for the pair just played
    # reaches max(1, N_e) for that pair, or the horizon is hit.
    action_it = 0
    while t < horizon:
      vector = number_to_vector(cycle[action_it], k, window)
      action = vector[-1]
      # NOTE(review): congestion comes from the planned cycle node, not from
      # the actions actually played; the two may differ at episode starts.
      # Confirm this simplification is intended.
      congestion = get_action_congestion(vector, action)
      reward = true_rewards[t][action] / congestion
      total_reward[action][congestion] += reward
      n[e][action][congestion] += 1
      cummulative_reward += reward
      t += 1
      if t % 100 == 0:
        # Play started at t = 1, so t - 1 steps have been played so far.
        print(t, cummulative_reward / (t - 1))
      if n[e][action][congestion] == max(1, n_acc[e][action][congestion]):
        break
      action_it = (action_it + 1) % len(cycle)

  if FLAGS.print_action_error:
    # Final estimates use every observation, including the last episode.
    # This also works when no episode ran at all.
    counts = [[sum(n[s][a][j] for s in range(episodes))
               for j in range(window + 1)]
              for a in range(k)]
    final_r_hat = _estimate_rewards(total_reward, counts, k, window)
    for j in range(k):
      print('Error between learned and true base reward for', j, 'is',
            final_r_hat[j][1] - base_rewards[j])


if __name__ == '__main__':
  # Hand control to absl, which parses the flags defined in this module and
  # then invokes main(argv).
  app.run(main)

