import heapq
from collections import Counter, defaultdict, deque
from copy import deepcopy
from queue import Queue

import numpy as np

def bfs_heights(predecessor, t):
    """Return breadth-first distances to ``t``, walking edges backwards.

    ``predecessor[u]`` lists the nodes that have an edge into ``u``. Starting
    at ``t`` and following those lists, ``heights[v]`` becomes the number of
    edges on a shortest path from ``v`` to ``t`` in that edge set.

    Only ``t`` and the nodes actually reached get keys. The result is a
    ``defaultdict(int)``, so looking up any other node yields (and inserts) 0.
    Every node reached from ``t`` must itself be a key of ``predecessor``.
    """
    heights = defaultdict(int)
    heights[t] = 0
    # A deque is the plain FIFO for BFS; queue.Queue adds per-operation
    # locking meant for inter-thread communication.
    frontier = deque([t])
    while frontier:
        u = frontier.popleft()
        for v in predecessor[u]:
            # Membership test instead of reading the defaultdict, so
            # unvisited nodes are not touched before they are labelled.
            if v not in heights:
                heights[v] = heights[u] + 1
                frontier.append(v)
    return heights

def CreatePreflow(graph, V, s, t, order='ASCEND'):
    """Saturate every edge leaving ``s`` and build the initial push-relabel state.

    Args:
        graph: capacities as ``graph[u][v]``; every node in ``range(V)`` must
            be a key (presumably a dict of dicts -- confirm with callers).
        V: number of nodes, labelled ``0..V-1``.
        s: source node.
        t: sink node.
        order: ``'ASCEND'`` keys the active-node heap on height (lowest
            first); any other value keys it on negated height (highest first).

    Returns:
        ``(flow, rgraph, excess, heights, active_nodes_queue, T,
        t_side_counter, threshold)`` where

        - ``flow`` maps each node to a ``defaultdict(int)`` of edge flows;
        - ``rgraph`` is a deep copy of ``graph`` with the source edges
          saturated and matching reverse capacity added;
        - ``excess`` is a ``defaultdict(int)`` of node excesses;
        - ``heights`` holds BFS distances to ``t`` over ``graph``'s edges,
          with ``heights[s] = V`` and 0 for nodes that cannot reach ``t``;
        - ``active_nodes_queue`` is a heap of ``(priority, node)`` for nodes
          with positive excess (never ``s`` or ``t``);
        - ``T`` is every node except ``s``;
        - ``t_side_counter`` counts nodes per height, excluding heights
          0 and ``V``;
        - ``threshold`` is the largest populated height in that counter, or 0
          when it is empty.
    """
    rgraph = deepcopy(graph)
    predecessor = {i: [] for i in range(V)}
    for i in range(V):
        for j in graph[i]:
            # NOTE(review): zero-capacity entries are included here, unlike the
            # residual BFS in PushRelabel; labels stay valid but may be low.
            predecessor[j].append(i)
    flow = {i: defaultdict(int) for i in range(V)}
    excess = defaultdict(int)
    active_nodes_queue = []

    for p in graph[s]:
        capacity = graph[s][p]
        if capacity > 0:
            flow[s][p] = capacity
            rgraph[s][p] = 0
            # Add to (rather than overwrite) any capacity the graph already
            # has on the reverse edge p -> s.
            rgraph[p][s] = rgraph[p].get(s, 0) + capacity
            excess[p] = capacity

    heights = bfs_heights(predecessor, t)
    heights[s] = V
    for p in heights:
        # The sink only absorbs flow and the source is never relabelled, so
        # neither may be scheduled as an active node.
        if p not in (s, t) and excess[p] > 0:
            priority = heights[p] if order == 'ASCEND' else -heights[p]
            heapq.heappush(active_nodes_queue, (priority, p))

    # Reading heights[p] for every node deliberately materialises height 0 for
    # nodes that cannot reach t (heights is a defaultdict).
    t_side_counter = Counter([heights[p] for p in range(V)])
    t_side_counter.pop(V, None)
    t_side_counter.pop(0, None)
    # default=0 covers graphs with no intermediate labelled nodes (e.g. s -> t only).
    threshold = max(
        (key for key, count in t_side_counter.items() if count > 0),
        default=0,
    )
    T = set(range(V))
    T.remove(s)
    return flow, rgraph, excess, heights, active_nodes_queue, T, t_side_counter, threshold

def FindCut2(T, graph, V):
    """List the edges of ``graph`` that cross from the source side into ``T``.

    The source side is every node of ``range(V)`` not contained in ``T``. An
    edge ``(tail, head)`` is reported when ``tail`` is on the source side,
    ``head`` is in ``T`` and the original capacity ``graph[tail][head]`` is
    positive. Edges come out in node order, then in ``graph[tail]`` order.
    """
    source_side = (node for node in range(V) if node not in T)
    return [
        (tail, head)
        for tail in source_side
        for head in graph[tail]
        if graph[tail][head] > 0 and head in T
    ]

def SendExcessToSource(rgraph, flow, excess, S, s):
    """Push any excess left on the source side ``S`` back to ``s``.

    Runs push-relabel restricted to the nodes of ``S`` with ``s`` acting as
    the sink, always processing the highest-labelled active node first.
    ``rgraph``, ``flow`` and ``excess`` are updated in place.

    Args:
        rgraph: residual capacities ``rgraph[u][v]``; mutated.
        flow: per-edge flow ``flow[u][v]`` (inner mappings must default to 0,
            as built by CreatePreflow); mutated.
        excess: per-node excess (mapping defaulting to 0); mutated.
        S: source-side node collection; must contain ``s``.
        s: source node.

    Returns:
        ``(push_counter, relabel_counter)`` for this phase.

    Raises:
        AssertionError: if an active node has no residual edge inside ``S``.
    """
    # Initial labels: BFS distance to s inside S over residual edges that
    # still have capacity -- the same rule PushRelabel uses when it recomputes
    # labels after a gap. Counting saturated edges would under-estimate labels
    # and cost extra relabels.
    predecessor = {i: [] for i in S}
    for i in S:
        for j in rgraph[i]:
            if j in S and rgraph[i][j] > 0:
                predecessor[j].append(i)
    heights = bfs_heights(predecessor, s)

    push_counter, relabel_counter = 0, 0

    # Max-heap on height (negated key).
    active_nodes_queue = []
    for p in S:
        if excess[p] > 0:
            heapq.heappush(active_nodes_queue, (-heights[p], p))

    while active_nodes_queue:
        _, active_node = heapq.heappop(active_nodes_queue)
        # Discharge the node completely before taking the next one.
        while excess[active_node] > 0:
            pushed = False
            neighbors = [q for q in rgraph[active_node] if rgraph[active_node][q] > 0 and q in S]
            assert len(neighbors) > 0
            for q in neighbors:
                if heights[q] == heights[active_node] - 1:
                    # push
                    push_counter += 1
                    pushed = True
                    push_flow = min(rgraph[active_node][q], excess[active_node])
                    assert push_flow > 0
                    # Cancel opposing flow first, then record the remainder forward.
                    flow[active_node][q] += max(0, push_flow - flow[q][active_node])
                    flow[q][active_node] = max(0, flow[q][active_node] - push_flow)
                    rgraph[active_node][q] -= push_flow
                    rgraph[q][active_node] += push_flow

                    excess[active_node] -= push_flow
                    # q becomes active only on its first unit of excess; s absorbs.
                    if excess[q] == 0 and q != s:
                        heapq.heappush(active_nodes_queue, (-heights[q], q))
                    excess[q] += push_flow
                if excess[active_node] == 0:
                    break
            if not pushed:
                # relabel: one above the lowest residual neighbour
                relabel_counter += 1
                heights[active_node] = min(heights[q] for q in neighbors) + 1
    return push_counter, relabel_counter


def PushRelabel(graph, V, s, t, find_flow=True, heights=None, order='DESCEND'):
    """Find a minimum s-t cut (and optionally a flow) with push-relabel.

    Phase one builds a preflow with CreatePreflow, then discharges active
    nodes restricted to the sink side ``T``. Whenever a relabel empties a
    height level (a gap), labels are recomputed by BFS from ``t`` over the
    positive residual edges. ``T`` then becomes exactly the nodes that can
    still reach ``t``. Nodes with excess but no residual edge into ``T`` are
    dropped from ``T``.

    Args:
        graph: capacities as ``graph[u][v]`` for nodes ``0..V-1``.
        V: number of nodes.
        s: source node.
        t: sink node.
        find_flow: when true, run SendExcessToSource afterwards so excess
            left on the source side is returned to ``s``.
        heights: optional initial labels overriding the BFS labels.
            NOTE(review): the per-height counter used for gap detection is
            still built from the BFS labels, so supplied heights are assumed
            to agree with them -- confirm with callers.
        order: ``'ASCEND'`` processes the lowest active node first; any other
            value processes the highest first (matching CreatePreflow).

    Returns:
        ``(flow, cuts, S, T, excess, push_counter, relabel_counter, heights)``
        where ``cuts`` lists the positive-capacity edges from ``S`` into ``T``
        and ``S`` is the complement of ``T`` in ``range(V)``.
    """
    flow, rgraph, excess, created_heights, active_nodes_queue, T, T_counter, threshold = CreatePreflow(graph, V, s, t, order)

    if heights is None:
        heights = created_heights

    push_counter, relabel_counter = 0, 0
    while active_nodes_queue:
        _, active_node = heapq.heappop(active_nodes_queue)
        # Skip stale entries for nodes no longer in T, and never treat the
        # sink as active: it only absorbs flow and must keep its label.
        if active_node == t or active_node not in T:
            continue

        while excess[active_node] > 0 and active_node in T:
            pushed = False
            neighbors = [q for q in rgraph[active_node] if rgraph[active_node][q] > 0 and q in T]
            for q in neighbors:
                if heights[q] == heights[active_node] - 1:
                    # push
                    push_counter += 1
                    pushed = True
                    push_flow = min(rgraph[active_node][q], excess[active_node])
                    assert push_flow > 0
                    # Cancel opposing flow first, then record the remainder forward.
                    flow[active_node][q] += max(0, push_flow - flow[q][active_node])
                    flow[q][active_node] = max(0, flow[q][active_node] - push_flow)
                    rgraph[active_node][q] -= push_flow
                    rgraph[q][active_node] += push_flow

                    excess[active_node] -= push_flow
                    if excess[q] == 0 and q not in (s, t):
                        # Same priority rule as CreatePreflow: anything other
                        # than 'ASCEND' is highest-first. (The original
                        # `elif order == 'DESCEND'` silently stopped
                        # enqueueing for any other value.)
                        priority = heights[q] if order == 'ASCEND' else -heights[q]
                        heapq.heappush(active_nodes_queue, (priority, q))
                    excess[q] += push_flow
                if excess[active_node] == 0:
                    break
            if not pushed:
                # relabel: the node leaves its current height level
                relabel_counter += 1
                T_counter[heights[active_node]] -= 1
                if T_counter[heights[active_node]] == 0:
                    # Gap: recompute exact labels by BFS from t over positive
                    # residual edges; T shrinks to the nodes that reach t.
                    threshold = heights[active_node] - 1
                    predecessor = {i: [] for i in range(V)}
                    for i in range(V):
                        for j in rgraph[i]:
                            if rgraph[i][j] > 0:
                                predecessor[j].append(i)
                    heights = bfs_heights(predecessor, t)
                    T = set(heights.keys())
                    T_counter = Counter(heights.values())
                else:
                    if not neighbors:
                        # No residual edge into T: this node is source-side.
                        T.remove(active_node)
                        break
                    heights[active_node] = min(heights[q] for q in neighbors) + 1
                    # Counter returns 0 for missing keys, so this also
                    # covers a previously unused height.
                    T_counter[heights[active_node]] += 1
                    threshold = max(threshold, heights[active_node])

    cuts = FindCut2(T, graph, V)
    S = {i for i in range(V) if i not in T}
    if find_flow:
        fix_push_counter, fix_relabel_counter = SendExcessToSource(rgraph, flow, excess, S, s)
        push_counter += fix_push_counter
        relabel_counter += fix_relabel_counter
    return flow, cuts, S, T, excess, push_counter, relabel_counter, heights

