import gudhi

# This part of the code appears in the proof of Theorem 1

# Two graphs G and H on the vertices 0..7, each with a vertex-based filtration:
# vertex i enters at G_values[i] (respectively H_values[i]).
VG = list(range(8))
EG = [
  (0, 1), (1, 2), (1, 3), (2, 3),
  (3, 4), (4, 5), (4, 6), (4, 7), (6, 7),
]
G_values = [1, 3, 2, 3, 4, 1, 2, 2]

VH = list(range(8))
EH = [
  (0, 1), (1, 2), (1, 3), (2, 3),
  (3, 4), (3, 5), (5, 6), (5, 7), (6, 7),
]
H_values = [1, 3, 2, 4, 1, 3, 2, 2]

# Compute PDs for G and H under their vertex-based (lower-star) filtrations.
def _lower_star_diagram(vertices, vertex_values, edges):
  """Build the lower-star filtration of a graph and compute its persistence.

  Vertex ``vertices[i]`` enters at ``vertex_values[i]``; each edge enters at
  the larger of its two endpoints' values. Endpoint labels are used directly
  as indices into ``vertex_values``, so labels are assumed to be 0..n-1 in the
  same order as ``vertices``.

  Returns:
    ``(simplex_tree, diagram, edge_values)`` where ``diagram`` is the output of
    ``SimplexTree.persistence`` and ``edge_values[i]`` is the filtration value
    given to ``edges[i]``.
  """
  st = gudhi.SimplexTree()
  for v, val in zip(vertices, vertex_values):
    st.insert([v], filtration=val)

  edge_values = []
  for edge in edges:
    e_val = max(vertex_values[edge[0]], vertex_values[edge[1]])
    st.insert(edge, filtration=e_val)
    edge_values.append(e_val)

  # Safety net: raises any simplex whose value is below one of its faces.
  st.make_filtration_non_decreasing()
  # min_persistence=-1 keeps every bar (including zero-length ones);
  # persistence_dim_max=True also computes homology in the top dimension of
  # the complex (dimension 1 for a graph, i.e. cycles).
  dgms = st.persistence(min_persistence=-1, persistence_dim_max=True)
  return st, dgms, edge_values


stG, G_dgms, G_e_values = _lower_star_diagram(VG, G_values, EG)
print(G_dgms)

stH, H_dgms, H_e_values = _lower_star_diagram(VH, H_values, EH)
print(H_dgms)

# Persistence diagrams are multisets: equal diagrams need not be listed in
# the same order, so compare them order-independently.
print(sorted(G_dgms) == sorted(H_dgms))

# Given a vertex-based filtration, compute the vertex values from which its backward function
# (as a filtration, in the sense of Proposition 4) is built. The negation and shift are applied
# afterwards: both backward filtrations (for G and H) must use the same offset, so that adjustment
# is done by extra code that runs after filtration_to_backward is called.
def filtration_to_backward(G, f):
  """Assign every vertex the value it needs in the backward filtration.

  A non-isolated vertex gets the largest filtration value among its incident
  edges (both endpoints of every edge are considered). An isolated vertex
  keeps its own value ``f[0][i]``. Since each vertex value is then >= the
  value of every incident edge, the negated values ``c - value`` form a valid
  filtration: each vertex enters no later than its incident edges.

  Args:
    G: ``[vertices, edges]``. Vertex labels are used directly as list
      indices, so they are assumed to be ``0..n-1`` in the order of ``G[0]``.
    f: ``[vertex_values, edge_values]`` with ``edge_values[i]`` the value of
      edge ``G[1][i]``.

  Returns:
    ``(v_max, vertex_value)``: the maximum vertex value and the list of
    per-vertex values.

  Raises:
    ValueError: if ``G`` has no vertices (``max`` of an empty list).
  """
  vertices, edges = G[0], G[1]
  vertex_f, edge_f = f[0], f[1]

  # None (rather than -1) marks "no incident edge seen yet", so negative
  # filtration values are handled correctly.
  vertex_value = [None] * len(vertices)
  for i, (v, w) in enumerate(edges):
    e_val = edge_f[i]
    # Both endpoints are incident to the edge; labels are 0-based indices.
    for u in (v, w):
      if vertex_value[u] is None or e_val > vertex_value[u]:
        vertex_value[u] = e_val

  # Isolated vertices (no incident edge) keep their own filtration value.
  for i, val in enumerate(vertex_value):
    if val is None:
      vertex_value[i] = vertex_f[i]

  v_max = max(vertex_value)
  return v_max, vertex_value

# Backward filtrations: negate the values from filtration_to_backward and shift
# both graphs by the same constant so that their diagrams are comparable.
G = [VG, EG]
H = [VH, EH]
fG = [G_values, G_e_values]
fH = [H_values, H_e_values]

vmax1, vval1 = filtration_to_backward(G, fG)
vmax2, vval2 = filtration_to_backward(H, fH)

# Common shift: the largest value over BOTH graphs is mapped to 1.
v_max = max(vmax1, vmax2)
v_output1 = [round(-x + v_max + 1, 3) for x in vval1]
v_output2 = [round(-x + v_max + 1, 3) for x in vval2]
e_output1 = [round(-x + v_max + 1, 3) for x in fG[1]]
e_output2 = [round(-x + v_max + 1, 3) for x in fH[1]]


def _filtration_diagram(vertices, vertex_values, edges, edge_values):
  """Compute persistence of a graph with explicit vertex and edge values.

  ``vertices[i]`` enters at ``vertex_values[i]`` and ``edges[i]`` at
  ``edge_values[i]``.

  Returns:
    ``(simplex_tree, diagram)`` where ``diagram`` is the output of
    ``SimplexTree.persistence``.
  """
  st = gudhi.SimplexTree()
  for v, val in zip(vertices, vertex_values):
    st.insert([v], filtration=val)
  for edge, val in zip(edges, edge_values):
    st.insert(edge, filtration=val)
  # Raises any simplex whose value is below one of its faces.
  st.make_filtration_non_decreasing()
  # min_persistence=-1 keeps every bar; persistence_dim_max=True also
  # computes homology in the top dimension (cycles, for a graph).
  return st, st.persistence(min_persistence=-1, persistence_dim_max=True)


# Compute PDs for G and H using the backward filtration
stG, G_dgms = _filtration_diagram(VG, v_output1, EG, e_output1)
print(G_dgms)

stH, H_dgms = _filtration_diagram(VH, v_output2, EH, e_output2)
print(H_dgms)

# Persistence diagrams are multisets: equal diagrams need not be listed in
# the same order, so compare them order-independently.
print(sorted(G_dgms) == sorted(H_dgms))