import jax.numpy as jnp
from jax.lax import while_loop, cond


def is_empty(heap):
  """Return whether `heap` holds no items, i.e. its last-element index is -1."""
  last_elem_idx = heap[1]
  return last_elem_idx == -1


def create_heap(item_size, max_heap_length):
  """
  Allocate an empty heap with room for `max_heap_length` items, each an
  `item_size`-dimensional vector.

  Every slot of the backing array starts out filled with inf.

  returns a tuple (storage, last_elem_idx): storage has shape
  (max_heap_length, item_size), and last_elem_idx is the index of the last
  occupied slot (not the length), so -1 marks an empty heap
  """
  empty_slots = jnp.ones((max_heap_length, item_size)) * jnp.inf
  no_items = -1
  return empty_slots, no_items


def heappush(heap, item):
  """
  Insert `item` into `heap`.

  The item is written into the first free slot (last_elem_idx + 1) and is
  then bubbled toward the root with `_siftdown`.

  NOTE(review): no capacity check is done here, so pushing onto a full heap
  is not detected. Confirm the intended behaviour under JAX's out-of-bounds
  `.at[].set` semantics.

  returns the updated (storage, last_elem_idx) tuple
  """
  storage, last_elem_idx = heap
  slot = last_elem_idx + 1
  grown = (storage.at[slot].set(item), slot)
  return _siftdown(grown, 0, slot)


def heappop(heap):
  """
  Remove and return the root item, i.e. the one with the smallest first
  component.

  returns (item, heap), where heap is the updated (storage, last_elem_idx)
  tuple. Popping an empty heap returns an all-inf item and leaves the heap
  unchanged, so is_empty() keeps reporting True.
  """
  storage, last_elem_idx = heap
  new_last_idx = last_elem_idx - 1

  def pop_from_empty(storage):
    # Without this guard the index would fall to -2, after which is_empty()
    # no longer recognises the heap as empty.
    return jnp.full_like(storage[0], jnp.inf), (storage, last_elem_idx)

  def pop_from_nonempty(storage):
    # Detach the last element and free its slot.
    last_element = storage[last_elem_idx]
    storage = storage.at[last_elem_idx].set(jnp.inf)

    def pop_only_element(storage):
      # The detached element was the root itself; nothing to re-heapify.
      return last_element, (storage, new_last_idx)

    def pop_root(storage):
      # Take the root, move the detached element into the hole and sift it.
      return_item = storage[0]
      storage = storage.at[0].set(last_element)
      return return_item, _siftup((storage, new_last_idx), 0)

    return cond(new_last_idx == -1, pop_only_element, pop_root, storage)

  return cond(last_elem_idx == -1, pop_from_empty, pop_from_nonempty, storage)


def _siftdown(heap, start_pos, pos):
  storage, last_elem_idx = heap
  new_item = storage[pos]

  # while pos > start_pos:
  #   parent_pos = (pos - 1) >> 1
  #   parent = storage[parent_pos]

  #   # Always compare on the first index
  #   if jnp.all(new_item[0] < parent[0]):
  #     storage = storage.at[pos].set(parent)
  #     pos = parent_pos
  #     continue

  #   break

  def cond_fun(args):
    _, pos, stop_cond = args
    return ~stop_cond & (pos > start_pos)

  def body_fun(args):
    storage, pos, _ = args

    parent_pos = (pos - 1) >> 1
    parent = storage[parent_pos]

    storage, pos, stop_cond = cond(
      jnp.all(new_item[0] < parent[0]),
      lambda: (storage.at[pos].set(parent), parent_pos, False),
      lambda: (storage, pos, True))

    return storage, pos, stop_cond

  storage, pos, _ = while_loop(
    cond_fun, body_fun, (storage, pos, False)
  )

  storage = storage.at[pos].set(new_item)

  return (storage, last_elem_idx)


def _siftup(heap, pos):
  """
  Restore the heap ordering after the item at `pos` has been replaced.

  Port of CPython's `heapq._siftup` to `jax.lax.while_loop`. The smaller
  child (compared on the first component) is repeatedly moved up into the
  hole until the hole reaches a leaf. The original item is placed there, and
  `_siftdown` then bubbles it back toward `pos`.

  returns the updated (storage, last_elem_idx) tuple
  """
  storage, last_elem_idx = heap

  start_pos = pos
  new_item = storage[pos]

  # Occupied slots are 0..last_elem_idx *inclusive* (heapq's `endpos` equals
  # last_elem_idx + 1), so both bounds below use `<=`. A strict `<` would
  # ignore a child sitting in the last occupied slot and break the heap
  # invariant (e.g. popping [1, 3, 2, 4] would leave 3 at the root, not 2).
  def cond_fun(args):
    _, _, child_pos = args
    return child_pos <= last_elem_idx

  def body_fun(args):
    storage, pos, child_pos = args

    right_pos = child_pos + 1

    # Prefer the right child unless the left one is strictly smaller.
    child_pos = cond(
      (right_pos <= last_elem_idx)
      & ~jnp.all(storage[child_pos][0] < storage[right_pos][0]),
      lambda: right_pos,
      lambda: child_pos,
    )

    # Move the chosen child up into the hole; the hole moves down.
    storage = storage.at[pos].set(storage[child_pos])
    pos = child_pos
    child_pos = 2 * pos + 1

    return storage, pos, child_pos

  storage, pos, _ = while_loop(
    cond_fun,
    body_fun,
    (storage, pos, 2 * pos + 1)  # start from the left child of `pos`
  )

  storage = storage.at[pos].set(new_item)

  return _siftdown((storage, last_elem_idx), start_pos, pos)
