from copy import deepcopy
import heapq
from PushRelabel import PushRelabel
from imagesegmentation import parseArgs, imageSegmentation
import os
from collections import defaultdict
import time


def CreateSeedFlowOnOldCut(new_graph, old_flow, S, T, old_excess, s, t):
    """Turn an old flow into a seed pseudo-flow that saturates the cut (S, T).

    Works in place on ``old_flow`` and ``old_excess`` and returns ``None``.

    1. Round down: any arc whose old flow exceeds its capacity in
       ``new_graph`` is clipped to that capacity; the removed amount is
       credited back to the tail's excess and debited from the head's.
    2. Saturate the cut: every S->T arc is filled to its capacity and every
       T->S arc is emptied, adjusting excesses accordingly.  This is the
       same notion of "saturated" that ``CheckPseudoFlow`` verifies when it
       is given S and T (S->T full, T->S zero).

    Excess is never updated for the terminals ``s`` and ``t``.
    """
    terminals = (s, t)

    # rounding down
    for p in old_flow:
        for q in old_flow[p]:
            delta = max(old_flow[p][q] - new_graph[p][q], 0)
            if p not in terminals:
                old_excess[p] += delta
            if q not in terminals:
                old_excess[q] -= delta
            old_flow[p][q] -= delta

    # Hoisted so the per-arc membership tests are O(1) even if S/T are lists.
    S_members, T_members = set(S), set(T)

    # saturate the cut: fill every S->T arc ...
    for p in S:
        for q in new_graph[p]:
            if q in T_members:
                delta = new_graph[p][q] - old_flow[p][q]
                if p != s:
                    old_excess[p] -= delta
                if q != t:
                    old_excess[q] += delta
                old_flow[p][q] = new_graph[p][q]

    # ... and empty every T->S arc, otherwise flow crossing back over the
    # cut keeps it unsaturated.
    for p in T:
        for q in new_graph[p]:
            if q in S_members:
                back = old_flow[p].get(q, 0)
                if back > 0:
                    if p not in terminals:
                        old_excess[p] += back
                    if q not in terminals:
                        old_excess[q] -= back
                    old_flow[p][q] = 0

    return

def CreateRGraph(graph, flow, V):
    """Build the residual graph of ``flow`` with respect to ``graph``.

    The residual capacity of arc (p, q) is
    ``graph[p][q] - flow[p][q] + flow[q][p]``: unused forward capacity plus
    the opposite flow that could be cancelled.  Missing entries count as 0.

    Arcs covered for every node ``p`` in ``0 .. V-1``:

    * every capacity arc in ``graph[p]`` (even if ``flow`` has no entry for
      it, e.g. when ``flow`` only stores positive values);
    * every arc listed in ``flow[p]`` and its reverse arc, so capacity
      created by flow on arcs that have no reverse capacity arc is kept.

    ``graph`` and ``flow`` are only read (``.get`` never inserts into
    defaultdicts).  Returns ``{node: defaultdict(int)}`` for ``0 .. V-1``.
    """
    def residual(p, q):
        return graph[p].get(q, 0) - flow[p].get(q, 0) + flow[q].get(p, 0)

    rgraph = {i: defaultdict(int) for i in range(V)}
    for p in range(V):
        row = rgraph[p]
        for q in graph[p]:
            row[q] = residual(p, q)
        for q in flow[p]:
            row[q] = residual(p, q)
            # Flow p->q opens residual capacity q->p.
            if q in rgraph:
                rgraph[q][p] = residual(q, p)
    return rgraph

def CreateSeedFlow(new_graph, old_flow, old_excess, s, t):
    """Clip ``old_flow`` in place to the capacities of ``new_graph``.

    Whenever the flow on an arc is larger than the arc's new capacity, the
    surplus is removed from the arc, given back to the tail node's excess
    and taken from the head node's excess.  The terminals ``s`` and ``t``
    never have their excess touched.  Returns ``None``.
    """
    for tail in old_flow:
        out_flow = old_flow[tail]
        for head in out_flow:
            surplus = out_flow[head] - new_graph[tail][head]
            if surplus < 0:
                surplus = 0
            if tail != s and tail != t:
                old_excess[tail] += surplus
            if head != s and head != t:
                old_excess[head] -= surplus
            out_flow[head] -= surplus
    return

def CheckPseudoFlow(flow, graph, excess, V, s, t, S=None, T=None):
    """Assert that ``flow``/``excess`` form a consistent pseudo-flow on ``graph``.

    Checks, raising ``AssertionError`` with a diagnostic message on the
    first violation:

    * every stored value ``flow[p][q]`` is at most ``graph[p][q]``;
    * no arc pair carries positive flow in both directions;
    * for every node in ``excess`` other than ``s``/``t``, ``excess[p]``
      equals inflow (over ``graph`` arcs into p) minus outflow (over
      ``graph[p]``);
    * if both ``S`` and ``T`` are given, every S->T arc of ``graph`` is
      saturated and its reverse arc carries no flow.

    Nodes are indexed ``0 .. V-1``.  Built on ``assert``, so the checks are
    skipped under ``python -O``.
    """
    predecessor = {i: [] for i in range(V)}
    for i in range(V):
        for j in graph[i]:
            predecessor[j].append(i)

    for p in flow:
        for q in flow[p]:
            assert flow[p][q] <= graph[p][q], (
                f"flow {p}->{q} = {flow[p][q]} exceeds capacity {graph[p][q]}")
            assert flow[p][q] == 0 or flow[q][p] == 0, (
                f"flow in both directions between {p} and {q}: "
                f"{flow[p][q]} / {flow[q][p]}")

    for p in excess:
        if p not in (s, t):
            inflow = sum(flow[q][p] for q in predecessor[p])
            outflow = sum(flow[p][q] for q in graph[p])
            assert inflow - outflow == excess[p], (
                f"excess mismatch at node {p}: recorded {excess[p]}, "
                f"actual {inflow - outflow}")

    if S is not None and T is not None:
        # Hoisted set so each membership test is O(1) even if T is a list.
        T_members = set(T)
        for p in S:
            for q in graph[p]:
                if q in T_members:
                    assert flow[p][q] == graph[p][q], (
                        f"cut arc {p}->{q} not saturated: flow {flow[p][q]}, "
                        f"capacity {graph[p][q]}")
                    assert flow[q][p] == 0, (
                        f"reverse cut arc {q}->{p} carries flow {flow[q][p]}")
    return

def CreateAuxiliaryGraph(rgraph, excess, V, s, t, eta):
    """Wrap a residual graph with a super source and a super sink.

    Adds node ``V`` (super source) with an arc of capacity ``eta`` into
    ``s`` and node ``V + 1`` (super sink) with an arc of capacity ``eta``
    out of ``t``, then copies every residual arc of ``rgraph``.

    NOTE: per-node excess/deficit arcs (super_s -> p for positive
    ``excess[p]``, p -> super_t for negative ``excess[p]``) are disabled, so
    ``excess`` is currently unused; it is kept to preserve the signature.

    Returns (auxiliary_graph, super_s, super_t).
    """
    super_s, super_t = V, V + 1
    auxiliary_graph = {i: defaultdict(int) for i in range(V + 2)}

    auxiliary_graph[super_s][s] = eta
    auxiliary_graph[t][super_t] = eta

    # Copy only arcs that exist in rgraph[p].  Looping over every node as a
    # head would build a dense V x V table and insert zero entries into the
    # caller's defaultdicts.
    for p in rgraph:
        for q, capacity in rgraph[p].items():
            auxiliary_graph[p][q] = capacity

    return auxiliary_graph, super_s, super_t


def FixExcessAndDeficit(flow, graph, excess, V, s, t, eta):
    """Time a push-relabel run on the residual graph of ``flow``.

    Builds the residual graph of ``flow`` on ``graph``, wraps it with a
    super source/sink (``CreateAuxiliaryGraph``, arcs of capacity ``eta``)
    and runs ``PushRelabel`` on it.  The resulting flow is discarded; only
    the push/relabel counters and the elapsed time are printed.

    Returns the elapsed wall-clock time in seconds.
    """
    rgraph = CreateRGraph(graph, flow, V)
    auxiliary_graph, super_s, super_t = CreateAuxiliaryGraph(rgraph, excess, V, s, t, eta)
    begin = time.time()
    result = PushRelabel(auxiliary_graph, V + 2, super_s, super_t, False)
    end = time.time()
    # Every other call site in this file unpacks eight values from
    # PushRelabel; a fixed 7-way unpack raises ValueError.  Index the
    # counters (positions 5 and 6) instead.
    auxiliary_push_counter, auxiliary_relabel_counter = result[5], result[6]
    print("pushes: ", auxiliary_push_counter, "relabels: ", auxiliary_relabel_counter)
    print(end - begin)
    return end - begin

def GetTrueEta(graph, flow, S, T):
    """Measure how far ``flow`` is from saturating the cut (S, T).

    The result is the unused capacity summed over every S->T arc, plus the
    flow carried back over every T->S arc.  It is 0 when all S->T arcs are
    full and all T->S arcs are empty.
    """
    unused_forward = sum(graph[u][v] - flow[u][v]
                         for u in S for v in graph[u] if v in T)
    return sum((flow[u][v] for u in T for v in graph[u] if v in S),
               unused_forward)

def SaturateFlow(graph, flow, excess, V, s, t, S, T, eta=60):
    """Push ``flow`` toward a flow that saturates a minimum cut.

    Builds the residual graph of ``flow`` and adds a super source (node
    ``V``) and super sink (node ``V + 1``) with these arcs:

    * super_s -> s and t -> super_t, both with capacity ``10 * eta``;
    * super_s -> p with capacity ``excess[p]`` for every p in ``S`` with
      positive excess;
    * q -> super_t with capacity ``-excess[q]`` for every q in ``T`` with
      negative excess (deficit).

    ``PushRelabel`` is run on that graph.  The augmenting flow it returns is
    folded back into ``flow`` and ``excess`` in place, cancelling opposite
    flow first.  The recorded excess of ``s`` and ``t`` is then reset to 0.

    Returns (new_S, new_T, elapsed_seconds, push_counter, relabel_counter).
    new_S/new_T is the cut reported by PushRelabel with the super terminals
    removed.
    """
    rgraph = CreateRGraph(graph, flow, V)
    super_s, super_t = V, V + 1
    rgraph[super_s] = defaultdict(int)
    rgraph[super_t] = defaultdict(int)
    rgraph[super_s][s] = 10 * eta
    rgraph[t][super_t] = 10 * eta
    for p in S:
        if excess[p] > 0:
            rgraph[super_s][p] = excess[p]
    for q in T:
        if excess[q] < 0:
            # The deficit of q itself.  This previously read excess[p], a
            # stale variable left over from the loop above.
            rgraph[q][super_t] = -excess[q]
    begin = time.time()
    augmenting_flow, _, new_S, new_T, _, push_counter, relabel_counter, _ = PushRelabel(rgraph, V+2, super_s, super_t, True)
    end = time.time()

    # Fold the augmenting flow on every real arc it used back into flow.
    # This covers original arcs and reverse residual arcs alike; arcs into
    # the super terminals (index >= V) are skipped.
    for p in range(V):
        for q in augmenting_flow[p]:
            amount = augmenting_flow[p][q]
            if amount <= 0 or q >= V:
                continue
            forward, backward = flow[p][q], flow[q][p]
            flow[p][q] = max(forward + amount - backward, 0)
            flow[q][p] = max(backward - amount, 0)
            excess[p] -= amount
            excess[q] += amount

    excess[s], excess[t] = 0, 0
    new_S.remove(super_s)
    new_T.remove(super_t)
    assert t in new_T and s in new_S
    print("saturating the flow takes time: ", end - begin)
    print("push counter: ", push_counter, "relabel_counter", relabel_counter)
    return new_S, new_T, end - begin, push_counter, relabel_counter

def WarmStartPRTwo(args):
    """Experiment: warm-start each image from the previous image's max-flow.

    For every image after the first:

    1. The previous image's flow is rounded down to the new capacities
       (``CreateSeedFlow``) and checked (``CheckPseudoFlow``).
    2. It is pushed toward saturation with ``SaturateFlow``, using the
       previous image's cut as (S, T).

    The current image's solution then becomes the seed for the next image.
    The push/relabel result files are created with a header only; nothing
    else is written to them here.
    """
    folder, group, size, algo, loadseed = args.folder, args.group, args.size, args.algo, args.loadseed
    image_dir = folder + '/' + group + '_cropped'
    image_list = os.listdir(image_dir)
    V = size * size + 2
    SOURCE, SINK = V - 2, V - 1

    result_dir = folder + '/' + group + '_pr_results' + '/'
    os.makedirs(result_dir, exist_ok=True)
    push_dir = result_dir + str(size) + '_push.txt'
    relabel_dir = result_dir + str(size) + '_relabel.txt'
    header = "image_name\tpr\twarm_start\tfind_cut\n"

    # `with` closes both files even if a phase below raises.
    with open(push_dir, 'w') as push_f, open(relabel_dir, 'w') as relabel_f:
        push_f.write(header)
        relabel_f.write(header)

        old_flow, old_S, old_T, old_excess = None, None, None, None
        for new_image in image_list:
            true_flow, _, true_S, true_T, true_excess, push_count, relabel_count, graph, pr_time = imageSegmentation(
                new_image, folder,
                group, (size, size),
                algo, loadseed)
            print("cold-start push and count: ", push_count, relabel_count)

            if old_flow is not None:
                CreateSeedFlow(graph, old_flow, old_excess, SOURCE, SINK)
                CheckPseudoFlow(old_flow, graph, old_excess, V, SOURCE, SINK)
                old_excess[SOURCE], old_excess[SINK] = 0, 0
                total_excess = sum(abs(old_excess[i]) for i in range(V))
                print("total excess/deficit to round down:" + str(total_excess / 2))
                eta = GetTrueEta(graph, old_flow, true_S, true_T)
                print("eta from cut saturating: ", eta)
                # SaturateFlow requires the starting cut; use the previous
                # image's cut, as WarmStartPR does.
                SaturateFlow(graph, old_flow, old_excess, V, SOURCE, SINK, old_S, old_T, eta)

            # The current solution seeds the next image.  This previously
            # assigned to a misspelled `old_flows`, so the seed never advanced.
            old_flow = deepcopy(true_flow)
            old_S = set(true_S)
            old_T = set(true_T)
            old_excess = deepcopy(true_excess)
    return

def AveragingFlow(new_graph, old_flows):
    """Average several flows into one integral pseudo-flow on ``new_graph``.

    1. For each arc, sum the positive flow over ``old_flows`` and divide by
       the number of flows, rounding down.  Floor division is used, so
       large integer flows do not lose precision through float conversion.
    2. Cancel opposite flows so that at most one direction of an arc pair
       is positive.  Then clip each direction to its capacity in
       ``new_graph``.
    3. Compute each node's excess (inflow minus outflow) under the result.

    Returns (average_flow, excess).  ``average_flow`` is keyed by the nodes
    of ``new_graph`` with ``defaultdict(int)`` rows.
    """
    average_flow = {p: defaultdict(int) for p in new_graph}
    excess = {p: 0 for p in new_graph}
    for flow in old_flows:
        for p in flow:
            for q in flow[p]:
                if flow[p][q] > 0:
                    average_flow[p][q] += flow[p][q]

    num_flows = len(old_flows)
    for p in average_flow:
        for q in average_flow[p]:
            average_flow[p][q] = int(average_flow[p][q] // num_flows)

    for p in average_flow:
        for q in average_flow[p]:
            # Keep only the net flow of the pair (p, q), then respect the
            # new capacities in both directions.
            val = average_flow[p][q] - average_flow[q][p]
            average_flow[p][q] = max(0, val)
            average_flow[q][p] = average_flow[p][q] - val
            average_flow[p][q] = min(average_flow[p][q], new_graph[p][q])
            average_flow[q][p] = min(average_flow[q][p], new_graph[q][p])

    for p in average_flow:
        for q in average_flow[p]:
            excess[p] -= average_flow[p][q]
            excess[q] += average_flow[p][q]

    return average_flow, excess

def WarmStartPRThree(args):
    """Experiment: warm-start from the average of several cold-start flows.

    Cold-start solves up to the first three images and averages their flows
    (``AveragingFlow``).  That average is used to warm-start image 0 with
    ``SaturateFlow``, taking the last averaged image's cut as (S, T).

    NOTE(review): image 0 is itself one of the averaged images, and the
    choice of starting cut is an assumption.  Confirm both match the
    intended experiment.  The push/relabel result files only receive a
    header here.
    """
    folder, group, size, algo, loadseed = args.folder, args.group, args.size, args.algo, args.loadseed
    image_dir = folder + '/' + group + '_cropped'
    image_list = os.listdir(image_dir)
    V = size * size + 2
    SOURCE, SINK = V - 2, V - 1

    result_dir = folder + '/' + group + '_pr_results' + '/'
    os.makedirs(result_dir, exist_ok=True)
    push_dir = result_dir + str(size) + '_push.txt'
    relabel_dir = result_dir + str(size) + '_relabel.txt'
    header = "image_name\tpr\twarm_start\tfind_cut\n"

    with open(push_dir, 'w') as push_f, open(relabel_dir, 'w') as relabel_f:
        push_f.write(header)
        relabel_f.write(header)

        # Never index past the end of a short image list.
        num_average = min(3, len(image_list))
        old_flows = []
        old_S, old_T = None, None
        for i in range(num_average):
            new_image = image_list[i]
            true_flow, _, true_S, true_T, true_excess, push_count, relabel_count, graph, pr_time = imageSegmentation(
                new_image, folder,
                group, (size, size),
                algo, loadseed)
            print("cold-start push and count: ", push_count, relabel_count)
            old_flows.append(true_flow)
            old_S, old_T = set(true_S), set(true_T)
        if not old_flows:
            return

        for i in range(0, 1):
            new_image = image_list[i]
            true_flow, _, true_S, true_T, true_excess, push_count, relabel_count, graph, pr_time = imageSegmentation(
                new_image, folder,
                group, (size, size),
                algo, loadseed)
            seed_flow, seed_excess = AveragingFlow(graph, old_flows)

            CheckPseudoFlow(seed_flow, graph, seed_excess, V, SOURCE, SINK)
            seed_excess[SOURCE], seed_excess[SINK] = 0, 0

            total_excess = sum(abs(seed_excess[i]) for i in range(V))
            print("total excess/deficit to round down:" + str(total_excess / 2))
            eta = GetTrueEta(graph, seed_flow, true_S, true_T)
            print("eta from cut saturating: ", eta)
            # SaturateFlow requires a starting cut (previously omitted,
            # which raised TypeError).
            SaturateFlow(graph, seed_flow, seed_excess, V, SOURCE, SINK, old_S, old_T, eta)
    return

def CreateTAuxiliaryGraph(rgraph, excess, T, t):
    """Build the flow network used to drain excess trapped on the sink side.

    Its nodes are ``T`` plus two virtual terminals: -1 (super source) and
    -2 (super sink).  Its arcs are:

    * every residual arc of ``rgraph`` whose endpoints both lie in ``T``;
    * super source -> p with capacity ``excess[p]`` for each p in ``T``
      with positive excess;
    * p -> super sink with capacity ``-excess[p]`` for each p in ``T`` with
      negative excess;
    * t -> super sink with capacity (total positive excess + 1).  This is
      written last, so it replaces any deficit arc set for ``t``.

    Returns (T_graph, super_T, super_s, super_t).  ``super_T`` is a copy of
    ``T`` with the two virtual terminals added.
    """
    super_T = deepcopy(T)
    super_s, super_t = -1, -2
    for virtual in (super_s, super_t):
        super_T.add(virtual)

    T_graph = {node: defaultdict(int) for node in super_T}
    for u in T:
        arcs = rgraph[u]
        for v in arcs:
            if v in T:
                T_graph[u][v] = arcs[v]

    total_excess = 0
    for u in T:
        amount = excess[u]
        if amount > 0:
            T_graph[super_s][u] = amount
            total_excess += amount
        elif amount < 0:
            T_graph[u][super_t] = -amount
    print("total excess: ", total_excess)

    T_graph[t][super_t] = total_excess + 1
    return T_graph, super_T, super_s, super_t

def FixTSideExcess(flow, graph, excess, S, T, t):
    """Drain positive excess left on the sink side ``T`` of a cut.

    Steps:

    1. Build the residual network of ``flow`` restricted to nodes of ``T``.
    2. Wrap it with ``CreateTAuxiliaryGraph``.  Virtual super source -1
       feeds every positive-excess node; virtual super sink -2 is fed by
       every deficit node and by ``t``.
    3. Relabel the nodes to dense indices and run ``PushRelabel``.
    4. Fold the resulting flow on real arcs back into ``flow`` and
       ``excess`` in place.

    Nodes reported in ``T_0`` are moved from ``T`` to ``S``; both sets are
    mutated in place.  ``T_0`` is presumably the source side of
    PushRelabel's min cut; confirm against its implementation.

    Returns (S, T, elapsed_seconds, push_counter, relabel_counter).
    """
    # Residual capacity of every arc with both endpoints in T.
    rgraph = {p: defaultdict(int) for p in T}
    for p in T:
        for q in graph[p]:
            if q in T:
                rgraph[p][q] = graph[p][q] - flow[p][q] + flow[q][p]

    T_graph, super_T, super_s, super_t = CreateTAuxiliaryGraph(rgraph, excess, T, t)
    # PushRelabel works on dense integer ids, so map every node of super_T
    # (the real nodes plus the virtual -1/-2) to its position in a list.
    super_T = list(super_T)
    T_map = {p: i for i, p in enumerate(super_T)}
    super_V = len(super_T)
    converted_T_graph = {i: defaultdict(int) for i in range(super_V)}
    for p in T_graph:
        for q in T_graph[p]:
            converted_T_graph[T_map[p]][T_map[q]] = T_graph[p][q]
    begin = time.time()
    T_flow, T_cut, T_0, T_1, T_excess, T_push_counter, T_relabel_counter, _ = PushRelabel(converted_T_graph, super_V, T_map[super_s], T_map[super_t])
    end = time.time()
    # The virtual terminals are not real nodes of the cut.
    T_0.remove(T_map[super_s])
    T_1.remove(T_map[super_t])
    for p in T_flow:
        for q in T_flow[p]:
            if T_flow[p][q] > 0:
                real_p, real_q = super_T[p], super_T[q]
                # Virtual terminals have negative ids; only arcs between real
                # nodes are copied back.  Opposite flow is cancelled first.
                if real_p >= 0 and real_q >= 0:
                    flow[real_p][real_q] += max(0, T_flow[p][q] - flow[real_q][real_p])
                    flow[real_q][real_p] = max(0, flow[real_q][real_p] - T_flow[p][q])
                    excess[real_p] -= T_flow[p][q]
                    # Flow arriving at the sink t is not tracked as excess.
                    if real_q != t:
                        excess[real_q] += T_flow[p][q]

    # Move the nodes reported in T_0 from the sink side to the source side.
    for p in T_0:
        S.add(super_T[p])
        T.remove(super_T[p])
    return S, T, end - begin, T_push_counter, T_relabel_counter

def mirror(flow):
    """Return the transpose of a nested arc map, keeping only positive values.

    For every arc with ``flow[p][q] > 0`` the result has
    ``result[q][p] == flow[p][q]``; every other entry reads as 0.  The
    result's outer keys are those of ``flow``, so every arc head must also
    be an outer key of ``flow``.
    """
    transposed = {node: defaultdict(int) for node in flow}
    for tail, heads in flow.items():
        for head, value in heads.items():
            if value > 0:
                transposed[head][tail] = value
    return transposed

def FixSSideDeficit(flow, graph, excess, S, T, s):
    """Resolve negative excess on the source side by mirroring the problem.

    Reversing every arc and negating every excess turns source-side
    deficits into sink-side excess, with ``s`` acting as the sink.
    ``FixTSideExcess`` already drains that case.  Its result is mapped back
    onto ``flow`` and ``excess`` in place.

    Returns (T, S, elapsed_seconds, push_counter, relabel_counter).  Note
    the (T, S) order.
    """
    flipped_flow = mirror(flow)
    flipped_graph = mirror(graph)
    flipped_excess = {node: -amount for node, amount in excess.items()}

    new_T, new_S, elapsed, pushes, relabels = FixTSideExcess(
        flipped_flow, flipped_graph, flipped_excess, T, S, s)

    # Transpose the repaired flow and un-negate the excess.
    for tail in graph:
        for head in graph[tail]:
            flow[tail][head] = flipped_flow[head][tail]
    for node in excess:
        excess[node] = -flipped_excess[node]
    return new_T, new_S, elapsed, pushes, relabels

def _one_by_one_network(flow, graph, excess, T, t, eta):
    """Build the T-restricted residual network used by FixTSideExcessOneByOne.

    Returns (super_T, T_map, T_graph):

    * ``super_T`` lists the nodes of ``T`` followed by the virtual super
      source -1 and super sink -2;
    * ``T_map`` maps each of those nodes to its index in ``super_T``;
    * ``T_graph`` is the index-keyed network.  It holds the residual arcs
      inside ``T``, ``t -> super sink`` with capacity ``eta``, and
      ``p -> super sink`` with capacity ``-excess[p]`` for every deficit
      node p of ``T``.  A deficit on ``t`` itself replaces the ``eta`` arc.
    """
    super_T = list(T) + [-1, -2]
    T_map = {p: i for i, p in enumerate(super_T)}
    T_graph = {i: defaultdict(int) for i in range(len(super_T))}
    T_graph[T_map[t]][T_map[-2]] = eta
    print("old T size: ", len(T))
    for p in T:
        for q in graph[p]:
            if q in T:
                T_graph[T_map[p]][T_map[q]] = graph[p][q] - flow[p][q] + flow[q][p]
    for p in excess:
        if excess[p] < 0 and p in T:
            T_graph[T_map[p]][T_map[-2]] = -excess[p]
    return super_T, T_map, T_graph


def _fold_one_by_one_flow(T_flow, super_T, flow, excess, t):
    """Add a PushRelabel flow computed on the T-side network into flow/excess."""
    for i in T_flow:
        for j in T_flow[i]:
            amount = T_flow[i][j]
            if amount > 0:
                real_p, real_q = super_T[i], super_T[j]
                # Virtual terminals have negative ids; only real arcs are
                # copied back, cancelling opposite flow first.
                if real_p >= 0 and real_q >= 0:
                    flow[real_p][real_q] += max(0, amount - flow[real_q][real_p])
                    flow[real_q][real_p] = max(0, flow[real_q][real_p] - amount)
                    excess[real_p] -= amount
                    if real_q != t:
                        excess[real_q] += amount


def FixTSideExcessOneByOne(flow, graph, excess, S, T, t, eta=10000):
    """Drain positive excess on the sink side one node at a time.

    Nodes p in ``T`` with positive excess are handled in ``excess``
    iteration order.  For each one:

    1. Add a super-source arc of capacity ``excess[p]`` into p on the
       T-restricted residual network.
    2. Run ``PushRelabel`` and fold its flow into ``flow``/``excess`` in
       place.
    3. Move the nodes reported in ``T_0`` from ``T`` to ``S`` (both mutated
       in place).  ``T_0`` is presumably the source side of the min cut;
       confirm against PushRelabel.
    4. Rebuild the network for the next node.

    Returns (total_time, push_counter, relabel_counter).
    """
    super_s, super_t = -1, -2
    super_T, T_map, T_graph = _one_by_one_network(flow, graph, excess, T, t, eta)

    push_counter, relabel_counter, total_time = 0, 0, 0
    for node in excess:
        if excess[node] > 0 and node in T:
            print("node: ", node, "excess: ", excess[node])
            T_graph[T_map[super_s]][T_map[node]] = excess[node]
            begin = time.time()
            T_flow, _, T_0, T_1, _, T_push_counter, T_relabel_counter, _ = PushRelabel(
                T_graph, len(T) + 2, T_map[super_s], T_map[super_t])
            end = time.time()
            total_time += (end - begin)
            push_counter += T_push_counter
            relabel_counter += T_relabel_counter
            print("fixed excess: ", T_flow[T_map[super_s]][T_map[node]])
            print(f'pushes: {T_push_counter}, relabels: {T_relabel_counter}')
            # The virtual terminals are not real nodes of the cut.
            T_0.remove(T_map[super_s])
            T_1.remove(T_map[super_t])

            _fold_one_by_one_flow(T_flow, super_T, flow, excess, t)

            for i in T_0:
                S.add(super_T[i])
                T.remove(super_T[i])

            # T and the residual capacities changed: rebuild the network.
            super_T, T_map, T_graph = _one_by_one_network(flow, graph, excess, T, t, eta)

    print("time spent on warmstart: ", total_time)
    print("push: ", push_counter, "relabel: ", relabel_counter)
    print("new T size: ", len(T))
    return total_time, push_counter, relabel_counter

def CheckTSide(flow, graph, S, T, excess):
    """Assert that (S, T) is a saturated cut with no excess left on T.

    Each S->T arc must carry flow equal to its capacity, with nothing
    flowing back on the reverse arc.  No node of ``T`` may hold positive
    excess.
    """
    for tail in S:
        capacities = graph[tail]
        for head in capacities:
            if head not in T:
                continue
            assert flow[tail][head] == capacities[head]
            assert flow[head][tail] == 0
    for node in T:
        assert excess[node] <= 0

def WarmStartPR(args):
    """Compare cold-start and warm-start push-relabel over an image group.

    For each image, ``imageSegmentation`` computes the cold-start max-flow.
    Once ``num_average`` earlier flows are available, their average
    (``AveragingFlow``) seeds a warm start.  The seed is repaired in three
    phases, starting from the previous image's cut:

    * ``SaturateFlow`` (find_cut);
    * ``FixTSideExcess`` (fix_t);
    * ``FixSSideDeficit`` (fix_s).

    Time, push and relabel figures for the cold start and each phase are
    written as tab-separated rows to ``<size>_time.txt``,
    ``<size>_push.txt`` and ``<size>_relabel.txt`` under
    ``<folder>/<group>_pr_results/``.
    """
    folder, group, size, algo, loadseed = args.folder, args.group, args.size, args.algo, args.loadseed
    image_dir = folder + '/' + group + '_cropped'
    image_list = os.listdir(image_dir)
    V = size * size + 2
    SOURCE, SINK = V - 2, V - 1

    result_dir = folder + '/' + group + '_pr_results' + '/'
    os.makedirs(result_dir, exist_ok=True)
    push_dir = result_dir + str(size) + '_push.txt'
    relabel_dir = result_dir + str(size) + '_relabel.txt'
    time_dir = result_dir + str(size) + '_time.txt'
    header = "image_name\tpr\twarm_start\tfind_cut\tfix_t\tfix_s\n"

    def write_row(out, name, cold, find_cut, fix_t, fix_s):
        # One row: cold start, warm-start total, then the three phases.
        # Flushed per row so results survive a crash on a later image.
        out.write(name + '\t')
        out.write(f'{cold}\t{find_cut + fix_t + fix_s}\t{find_cut}\t{fix_t}\t{fix_s}\n')
        out.flush()

    num_average = 1
    old_flows = []
    S, T = None, None
    # `with` guarantees all three files are closed even if a phase raises.
    with open(push_dir, 'w') as push_f, \
            open(relabel_dir, 'w') as relabel_f, \
            open(time_dir, 'w') as time_f:
        push_f.write(header)
        relabel_f.write(header)
        time_f.write(header)

        for new_image in image_list:
            true_flow, _, true_S, true_T, true_excess, push_count, relabel_count, graph, pr_time = imageSegmentation(
                new_image, folder,
                group, (size, size),
                algo, loadseed)
            if len(old_flows) < num_average:
                old_flows.append(true_flow)
                S = true_S
                T = true_T
                continue

            flow, excess = AveragingFlow(graph, old_flows)
            excess[SOURCE], excess[SINK] = 0, 0
            total_excess = sum(abs(excess[i]) for i in range(V))
            print("total excess/deficit to round down:" + str(total_excess / 2))
            CheckPseudoFlow(flow, graph, excess, V, SOURCE, SINK)
            eta = GetTrueEta(graph, flow, true_S, true_T)
            S, T, saturate_time, saturate_pushes, saturate_relabels = SaturateFlow(graph, flow, excess, V, SOURCE, SINK, S, T, eta)
            CheckPseudoFlow(flow, graph, excess, V, SOURCE, SINK)

            eta = GetTrueEta(graph, flow, true_S, true_T)
            print(f"now cut {eta} away from being saturated")
            S_def = sum([-excess[p] for p in true_S if excess[p] < 0])
            T_exc = sum([excess[p] for p in true_T if excess[p] > 0])
            print(f'still need to resolve {S_def} in S and {T_exc} in T')
            S, T, T_time, T_pushes, T_relabels = FixTSideExcess(flow, graph, excess, S, T, SINK)
            print(f'T side excess takes time: {T_time}, pushes: {T_pushes}, relabels: {T_relabels}')
            CheckPseudoFlow(flow, graph, excess, V, SOURCE, SINK, S, T)
            T, S, S_time, S_pushes, S_relabels = FixSSideDeficit(flow, graph, excess, S, T, SOURCE)
            print(f'S side deficit takes time: {S_time}, pushes: {S_pushes}, relabels: {S_relabels}')
            CheckPseudoFlow(flow, graph, excess, V, SOURCE, SINK, S, T)
            CheckPseudoFlow(flow, graph, excess, V, SOURCE, SINK, true_S, true_T)

            # Slide the averaging window; the next image starts from this
            # image's true cut.
            old_flows.pop(0)
            old_flows.append(true_flow)
            S = true_S
            T = true_T

            image_name = new_image.split('.')[0]
            write_row(time_f, image_name, pr_time, saturate_time, T_time, S_time)
            write_row(push_f, image_name, push_count, saturate_pushes, T_pushes, S_pushes)
            write_row(relabel_f, image_name, relabel_count, saturate_relabels, T_relabels, S_relabels)
    return

if __name__ == "__main__":
    # Parse the command-line options and run the warm-start experiment.
    WarmStartPR(parseArgs())
