from maxcut import H_graph, P_coloring
from collections import defaultdict
import networkx as nx

# from visualize_tree import visualize_tree


class ContrastiveTreeLocalSearch:
    """Local search over a rooted binary tree built from a coloured MaxCut instance.

    The constructor turns ``(vertices, edges, coloring)`` into a list of
    weighted triplet constraints ``(i, j, k, w)`` plus an initial tree whose
    leaves are the input vertices and the four gadget points ``"X"``, ``"X_"``,
    ``"Y"`` and ``"Z"``.  ``run`` then repeatedly moves a single vertex leaf
    next to ``"Y"`` or ``"Z"`` as long as this increases the total weight of
    the constraints that the tree satisfies.

    Attributes:
        points: Input vertices followed by the four gadget points.
        constraints: List of ``(i, j, k, w)`` tuples.
        T: ``networkx.DiGraph`` with edges directed from parent to child.
        root: Label of the current root of ``T``.
        constraints_by_vertex: Mapping from a point to the constraints that
            mention it.
    """

    # Gadget leaves introduced by the reduction; the neighborhood never moves them.
    GADGET_POINTS = ("X", "X_", "Y", "Z")

    def __init__(self, vertices, edges, coloring):
        self.points, self.constraints, self.T, self.root = self.reduce_maxcut_to_contrastive_tree(vertices, edges, coloring)
        self.constraints_by_vertex = self._index_constraints(self.constraints)
        # Undo log of the most recent relocate(); consumed by rollback().
        self._last_ops = []

    @staticmethod
    def _index_constraints(constraints):
        """Group constraints by the points they mention.

        Args:
            constraints: Iterable of ``(a, b, c, w)`` tuples.

        Returns:
            A ``defaultdict(list)`` mapping each point to the constraints that
            contain it.  A constraint is listed at most once per point, even
            if it names that point more than once, so that summing weights
            over ``index[v]`` never counts a constraint twice.
        """
        index = defaultdict(list)
        for a, b, c, w in constraints:
            # dict.fromkeys de-duplicates while keeping the (a, b, c) order.
            for point in dict.fromkeys((a, b, c)):
                index[point].append((a, b, c, w))
        return index

    @staticmethod
    def reduce_maxcut_to_contrastive_tree(vertices, edges, coloring):
        """Build the constraint set and the initial tree for a MaxCut instance.

        Args:
            vertices: Iterable of hashable vertex labels.
            edges: Iterable of ``(u, v, w)`` weighted edges.
            coloring: Mapping from vertex to ``0`` or ``1``.  Colour-0 vertices
                start in the subtree containing ``"Y"``, colour-1 vertices in
                the subtree containing ``"Z"``.  A vertex with any other colour
                still appears in ``points`` and in the constraints but is not
                placed in the tree.

        Returns:
            ``(points, constraints, tree, root)`` where ``tree`` is a
            ``networkx.DiGraph`` with parent -> child edges and ``root`` is the
            label of its root node.
        """
        # Materialize once: both inputs are traversed several times below, and
        # a one-shot iterable would otherwise be silently exhausted.
        vertices = list(vertices)
        edges = list(edges)

        points = vertices + ["X", "X_", "Y", "Z"]
        constraints = []

        W = sum(w for _, _, w in edges)
        n = len(vertices)

        # Type A
        constraints.extend(("X", "X_", t, W) for t in ["Y", "Z"] + vertices)

        # Type B
        constraints.append(("X", "Y", "Z", n * W + n + 1))

        # Type C
        for v in vertices:
            constraints.append(("Y", v, "X", W + 1))
            constraints.append(("Z", v, "X", W + 1))

        # Type D: each edge contributes two constraints of half its weight.
        for u, v, w in edges:
            constraints.append((u, "X", v, w / 2))
            constraints.append((v, "X", u, w / 2))

        init = nx.DiGraph()
        internal_count = 0

        def new_internal():
            """Return a fresh internal-node label ``i_1``, ``i_2``, ..."""
            nonlocal internal_count
            internal_count += 1
            return f"i_{internal_count}"

        def build_subtree(nodes):
            """Hang ``nodes`` as leaves of a balanced binary subtree; return its root."""
            if len(nodes) == 1:
                return nodes[0]
            mid = len(nodes) // 2
            # The parent label is allocated before recursing so that labels
            # are numbered in pre-order.
            parent = new_internal()
            left = build_subtree(nodes[:mid])
            right = build_subtree(nodes[mid:])
            init.add_edge(parent, left)
            init.add_edge(parent, right)
            return parent

        left_group = ["Y"] + [v for v in vertices if coloring[v] == 0]
        right_group = ["Z"] + [v for v in vertices if coloring[v] == 1]

        root_y = build_subtree(left_group)
        root_z = build_subtree(right_group)

        # Final shape: ((X, X_), Y-subtree) joined with the Z-subtree at the root.
        root_x = new_internal()
        init.add_edge(root_x, "X")
        init.add_edge(root_x, "X_")

        root_xy = new_internal()
        init.add_edge(root_xy, root_x)
        init.add_edge(root_xy, root_y)

        root_all = new_internal()
        init.add_edge(root_all, root_xy)
        init.add_edge(root_all, root_z)

        # Debug aid: visualize_tree(init, root_all)
        return points, constraints, init, root_all

    def _record_op(self, kind, a, b):
        """Append one undo entry; ``kind`` names the action that reverts the change."""
        self._last_ops.append((kind, a, b))

    def _add_edge(self, parent, child):
        """Add ``parent -> child`` and log the inverse operation."""
        self.T.add_edge(parent, child)
        self._record_op("remove", parent, child)

    def _remove_edge(self, parent, child):
        """Remove ``parent -> child`` and log the inverse operation."""
        self.T.remove_edge(parent, child)
        self._record_op("add", parent, child)

    def _set_root(self, new_root):
        """Change ``self.root`` and log the previous value for rollback."""
        self._record_op("root", self.root, new_root)
        self.root = new_root

    def _detach_leaf(self, v, parent):
        """Unhook ``v`` and splice its now-unary ``parent`` out of the tree."""
        self._remove_edge(parent, v)

        # With v gone, the only remaining child of parent is v's sibling.
        sibling = next(self.T.successors(parent))
        self._remove_edge(parent, sibling)

        grandparent = next(self.T.predecessors(parent), None)
        if grandparent is not None:
            self._remove_edge(grandparent, parent)
            self._add_edge(grandparent, sibling)
        else:
            # parent was the root, so the sibling subtree is now the whole tree.
            self._set_root(sibling)

    def _attach_leaf(self, v, parent, a, b):
        """Re-insert the detached ``parent`` on edge ``(a, b)`` and hang ``v`` under it."""
        if a is None:
            # Placing v above the root: parent becomes the new root.
            self._add_edge(parent, b)
            self._set_root(parent)
        else:
            self._remove_edge(a, b)
            self._add_edge(a, parent)
            self._add_edge(parent, b)

        self._add_edge(parent, v)

    def relocate(self, v, a, b):
        """
        Relocate a leaf v on the edge (a, b). Special case: a, b = (None, root)

        The parent of ``v`` is removed from its current position (its other
        child takes its place) and reused as the new internal node that
        subdivides ``(a, b)``, with ``b`` and ``v`` as its children.  The
        performed edge changes are logged so that ``rollback`` can undo them.

        If the target edge is incident to the current parent of ``v`` the move
        would put ``v`` back where it already is, so the tree is left
        untouched.  If anything fails part-way, the tree is restored before
        the exception propagates.

        Raises:
            ValueError: if ``v`` has no parent in the tree.
            networkx.NetworkXError: if ``(a, b)`` is not an edge of the tree.
        """

        self._last_ops.clear()

        parent = next(self.T.predecessors(v), None)
        if parent is None:
            raise ValueError(f"cannot relocate {v!r}: it has no parent in the tree")

        if parent in (a, b):
            # (parent, v), (parent, sibling), (grandparent, parent) and
            # (None, parent) all describe v's current position.
            if a is not None and not self.T.has_edge(a, b):
                raise nx.NetworkXError(f"The edge {a}-{b} not in graph.")
            return

        try:
            self._detach_leaf(v, parent)
            self._attach_leaf(v, parent, a, b)
        except Exception:
            # Never leave a half-rewired tree behind.
            self.rollback()
            raise

    def rollback(self):
        """Undo the most recent ``relocate`` (no-op if there is nothing to undo)."""
        for kind, a, b in reversed(self._last_ops):
            if kind == "add":
                self.T.add_edge(a, b)
            elif kind == "root":
                # For root entries `a` holds the root before the change.
                self.root = a
            else:
                self.T.remove_edge(a, b)
        self._last_ops.clear()

    def _satisfied_weights(self, v):
        """Yield the weight of every constraint mentioning ``v`` that the tree satisfies.

        A constraint ``(i, j, k, w)`` counts as satisfied when ``i`` and ``j``
        have the same lowest common ancestor with ``k``; for three distinct
        leaves of a binary tree this means ``i`` and ``j`` meet strictly below
        the node where ``k`` joins them.
        """
        lca = nx.lowest_common_ancestor
        for i, j, k, w in self.constraints_by_vertex[v]:
            if lca(self.T, i, k) == lca(self.T, j, k):
                yield w

    def local_contribution(self, v, a, b):
        """Return the change in satisfied weight if ``v`` were relocated onto ``(a, b)``.

        Only constraints that mention ``v`` are evaluated.  The move is
        performed tentatively and always rolled back, so the tree is unchanged
        when this method returns or raises.
        """
        contribution = 0
        for w in self._satisfied_weights(v):
            contribution -= w

        self.relocate(v, a, b)
        try:
            for w in self._satisfied_weights(v):
                contribution += w
        finally:
            self.rollback()
        return contribution

    def neighborhood_oracle(self):
        """Return the first improving move ``(v, a, b)``, or ``None`` if there is none.

        Each non-gadget point is tried in ``self.points`` order.  A vertex
        whose lowest common ancestor with ``"X"`` equals that of ``"Y"`` is
        offered the edge above ``"Z"``; every other vertex is offered the edge
        above ``"Y"``.
        """
        lca = nx.lowest_common_ancestor
        # Loop-invariant: local_contribution restores the tree after each probe.
        y_x_ancestor = lca(self.T, "Y", "X")

        for v in self.points:
            if v in self.GADGET_POINTS:
                continue

            b = "Z" if lca(self.T, v, "X") == y_x_ancestor else "Y"
            a = next(self.T.predecessors(b))

            if self.local_contribution(v, a, b) > 0:
                return v, a, b

        return None

    def run(self):
        """Apply improving moves until none is left.

        Returns:
            ``(number of points, number of constraints, number of moves made)``.
        """
        move_count = 0
        while True:
            move = self.neighborhood_oracle()
            if move is None:
                break
            self.relocate(*move)
            move_count += 1
        return len(self.points), len(self.constraints), move_count


def main():
    """Run the local search on H_graph(1) through H_graph(6), printing one summary line each."""
    for i in range(1, 7):
        edges = H_graph(i)

        # Collect both endpoints of every edge (u first, then v, in edge order).
        vertices = {endpoint for u, v, _ in edges for endpoint in (u, v)}

        coloring = {}
        for vertex in vertices:
            coloring[vertex] = P_coloring(vertex)

        stats = ContrastiveTreeLocalSearch(vertices, edges, coloring).run()
        num_vertices, num_constraints, move_count = stats
        print(f"n={i}, {num_vertices} vertices, {num_constraints} constraints, {move_count} moves")


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
