import random
import multiprocessing as mp
import threading
import warnings
from collections import deque, OrderedDict, defaultdict as DefaultDict
from queue import Empty, Full

import torch
import torch.nn as nn
import numpy as np
import networkx as nx
import gurobipy as gp
from tqdm.autonotebook import tqdm

from matplotlib import colormaps
import plotly.graph_objects as go
import plotly.express as px

from poly import Polyhedron, encode_bv

# Release gurobipy's shared default environment at import time; the solver calls in this
# module receive explicitly created, silent environments instead (Decomp.get_env).
# NOTE(review): presumably done to avoid the default env's license checkout / console
# banner — confirm.
gp.disposeDefaultEnv()


def get_colors(data, cmap="viridis", **kwargs):
    """Map numeric values to hex colour strings using a matplotlib colormap.

    Values are min-max normalised to [0, 1] before lookup; a constant input maps
    every value to the colormap's lowest colour.

    Args:
        data: sequence of numbers, or anything ``np.asarray`` converts to a numeric
            array (lists, NumPy arrays, CPU tensors).
        cmap: name of a colormap registered in ``matplotlib.colormaps``.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        list of ``"#rrggbb"`` strings, one per value; empty for empty input.
    """
    a = np.asarray(data)
    # Test emptiness on the converted array: ``not data`` raises for multi-element
    # NumPy arrays and for tensors, and is deprecated/ambiguous for empty arrays.
    if a.size == 0:
        return []
    a = a - np.min(a)
    am = np.max(a)
    a = a / (am if am > 0 else 1)
    rgba = colormaps[cmap](a)
    # Drop the alpha channel and scale to 0..255 (truncating, as before).
    rgb = (rgba[..., :3] * 255).astype(int)
    return [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb]


def id(x):
    """Identity function: hand back ``x`` untouched (used as ``Decomp``'s default embedder)."""
    return x


class BVNode:
    """Branching node of the ternary sign-vector trie used by ``PolyManager``.

    The node branches on the entry ``bv[0, key]`` of a 2-D sign vector ``bv``
    (row 0, column ``key``), whose entries must be -1, 0 or 1. Each child slot
    holds another ``BVNode``, a leaf (an ``int`` index into
    ``PolyManager.index2poly``) or ``None`` when empty.
    """

    def __init__(self, key):
        self.key = key  # column of bv this node branches on
        self.left = None  # subtree where bv[0, key] == -1
        self.middle = None  # subtree where bv[0, key] == 0
        self.right = None  # subtree where bv[0, key] == 1

    def _slot(self, bv):
        """Return the name of the child attribute selected by ``bv[0, self.key]``.

        Raises:
            ValueError: if that entry is not -1, 0 or 1.
        """
        value = bv[0, self.key]
        if value == -1:
            return "left"
        if value == 0:
            return "middle"
        if value == 1:
            return "right"
        raise ValueError(f"bv[0, {self.key}] = {value!r}; expected -1, 0 or 1")

    def get_child(self, bv):
        """Return the child (``BVNode``, leaf index or ``None``) that ``bv`` selects."""
        return getattr(self, self._slot(bv))

    def set_child(self, bv, node):
        """Store ``node`` (a ``BVNode`` or leaf index) in the slot that ``bv`` selects."""
        setattr(self, self._slot(bv), node)


## Trie
## Each edge in the tree sets a dimension to a value
## Leaf nodes are just indices of polyhedra in index2poly
## TODO: Replace hashmaps in Decomp class with this
class PolyManager:
    """Path-compressed ternary trie mapping sign vectors to polyhedra.

    Internal nodes are ``BVNode`` objects and leaves are ``int`` positions in
    ``index2poly``, whose entries must expose their sign vector as ``.bv``.
    Node keys strictly increase along every root-to-leaf path, and all leaves
    below a node with key ``k`` agree on columns ``< k``. A lookup therefore only
    inspects the distinguishing columns and confirms the hit with one full
    comparison.

    Callers own ``index2poly``/``poly2index``: after ``add_bv(bv, index)``,
    ``index2poly[index]`` must hold the polyhedron for ``bv`` before the trie is
    queried or extended again, because later operations read it.
    """

    def __init__(self):
        self.root = BVNode(0)
        self.index2poly = list()
        self.poly2index = dict()

    def _get_bv(self, bv):
        """Return the stored polyhedron whose ``.bv`` equals ``bv``, or ``None``."""
        node = self.root
        while isinstance(node, BVNode):
            node = node.get_child(bv)
        if isinstance(node, int):
            poly = self.index2poly[node]
            if (poly.bv == bv).all():
                return poly
        return None

    @staticmethod
    def _first_difference(a, b):
        """Return the first column where rows ``a[0]`` and ``b[0]`` differ, or ``None``."""
        for i in range(b.shape[1]):
            if a[0, i] != b[0, i]:
                return i
        return None

    @staticmethod
    def _any_leaf(node):
        """Return some leaf index stored under ``node`` (``None`` if the subtree is empty)."""
        while isinstance(node, BVNode):
            node = next((c for c in (node.left, node.middle, node.right) if c is not None), None)
        return node

    def add_bv(self, bv, index):
        """Insert leaf ``index`` for sign vector ``bv``.

        Args:
            bv: 2-D sign vector (row 0 is used) with entries in {-1, 0, 1}.
            index: position of the matching polyhedron in ``index2poly``.

        Returns:
            The stored polyhedron if an identical ``bv`` is already present (nothing
            is inserted then); otherwise ``None``.

        Raises:
            KeyError: if the trie holds a child that is neither a node nor a leaf.
            ValueError: if ``bv`` has an entry outside {-1, 0, 1} on the search path.
        """
        # 1. Follow bv down to a leaf or an empty slot.
        node = self.root
        child = node.get_child(bv)
        while isinstance(child, BVNode):
            node = child
            child = node.get_child(bv)
        if isinstance(child, int):
            rep_index = child
        elif child is None:
            rep_index = self._any_leaf(node)
            if rep_index is None:  # only an empty root has no leaves below it
                node.set_child(bv, index)
                return None
        else:
            raise KeyError(f"Unexpected trie entry {child!r}")

        # 2. Compare against a representative leaf of the reached subtree over the
        #    whole vector (not just the columns after the last branching node).
        rep = self.index2poly[rep_index]
        diff = self._first_difference(rep.bv, bv)
        if diff is None:
            return rep

        # 3. Re-descend along bv to the place where column ``diff`` belongs.
        node = self.root
        while node.key != diff:
            child = node.get_child(bv)
            if isinstance(child, BVNode) and child.key <= diff:
                node = child
                continue
            if child is None:
                node.set_child(bv, index)
            else:
                # Every leaf under ``child`` agrees with ``rep`` on column ``diff``,
                # so split here with a new node branching on that column.
                branch = BVNode(diff)
                branch.set_child(rep.bv, child)
                branch.set_child(bv, index)
                node.set_child(bv, branch)
            return None
        # An existing node already branches on ``diff``; bv's slot there is empty.
        node.set_child(bv, index)
        return None

    def contains_bv(self, bv):
        """Return True if a polyhedron with sign vector ``bv`` is stored."""
        return self._get_bv(bv) is not None

    def __iter__(self):
        return iter(self.index2poly)

    def __len__(self):
        return len(self.index2poly)

    def len(self):
        """Number of registered polyhedra (kept for backward compatibility; prefer ``len()``)."""
        return len(self.index2poly)

    def get_poly_attrs(self, attrs):
        """Return ``{attr: [getattr(poly, attr) for each registered poly]}`` for every name in ``attrs``."""
        return {attr: [getattr(poly, attr) for poly in self] for attr in attrs}


class Decomp:
    ## TODO: Phase out embedder
    def __init__(self, net, embedder=None):
        """Create an empty decomposition of ``net``'s input space.

        Args:
            net: network object; this constructor reads ``net.layers`` (an ordered
                mapping of modules), ``net.input_shape`` (a tuple), ``net.device`` and
                ``net.dtype``.
            embedder: optional callable applied to inputs before they are located
                (``add_point``, ``path``, ``get_points_poly``); defaults to the identity.
        """
        self.net = net
        self.embedder = embedder or id

        self.polys = OrderedDict()  # tag -> Polyhedron
        self.points = dict()  # flattened input tuple -> Polyhedron
        self.poly2points = DefaultDict(set)  # tag -> set of flattened input tuples
        self.root = BVNode(0)
        self.index2poly = list()  # registered polyhedra in insertion order
        self.poly2index = dict()  # Polyhedron -> position in index2poly

        net_layers = list(net.layers.values())
        # A layer contributes to the sign vector when the next layer is a ReLU.
        self.bv_layers = [
            i
            for i, (layer, next_layer) in enumerate(zip(net_layers[:-1], net_layers[1:]))
            if isinstance(next_layer, nn.ReLU)
        ]

        # Probe forward pass on a zero input to enumerate every element of each bv
        # layer's output: entry k is (layer index, element multi-index).
        # NOTE(review): presumably aligned with the positions of point2bv's output — confirm.
        # The probe needs no autograd graph, so it runs under no_grad.
        self.bvi2maski = []
        with torch.no_grad():
            x = torch.zeros((1,) + net.input_shape, device=net.device, dtype=net.dtype)
            for i, layer in enumerate(net_layers):
                x = layer(x)
                if i in self.bv_layers:
                    it = np.nditer(x.detach().cpu().numpy(), flags=["multi_index"])
                    for _ in it:
                        self.bvi2maski.append((i, it.multi_index))

    @property
    def dim(self):
        return np.prod(self.net.input_shape)

    @torch.no_grad()
    def bv_iterator(self, batch):
        x = (
            torch.tensor(batch, device=self.net.device, dtype=self.net.dtype)
            if isinstance(batch, np.ndarray)
            else batch
        ).reshape((-1, *self.net.input_shape))
        for i, layer in enumerate(self.net.layers.values()):
            x = layer(x)
            if i in self.bv_layers:
                # yield torch.sign(x)
                yield torch.sign(x)  # * (torch.abs(x) < 1e-12)
                if i == self.bv_layers[-1]:
                    break

    def point2bv(self, batch):
        return torch.hstack(list(self.bv_iterator(batch)))

    def add_polyhedron(self, p, overwrite=False):
        """Register polyhedron ``p`` and return the instance stored under ``p.tag``.

        A polyhedron not yet present is appended to ``index2poly`` and recorded in
        ``poly2index`` and ``polys``. With ``overwrite=True``, the stored entries for
        ``p`` are replaced by ``p`` itself even when an equal polyhedron already
        existed.
        """
        is_new = p not in self
        if is_new:
            position = len(self)
            self.poly2index[p] = position
            self.index2poly.append(p)
            self.polys[p.tag] = p
        if overwrite:
            self.polys[p.tag] = p
            self.index2poly[self.poly2index[p]] = p
        return self.polys[p.tag]

    @torch.no_grad()
    def point2poly(self, point):
        """Return the polyhedron containing ``point``.

        If its sign vector is already registered, the stored instance is returned.
        Otherwise a fresh polyhedron is built and is *not* added to ``self``.
        """
        bv = self.point2bv(point)
        try:
            return self[bv]
        except KeyError:
            return self.bv2poly(bv, check_exists=False)

    @torch.no_grad()
    def add_point(self, data, **kwargs):
        """Locate ``data``, register its polyhedron and remember the point.

        Args:
            data: input tensor (detached before use); it is passed through
                ``self.embedder`` for the lookup.
            **kwargs: attributes to set on the polyhedron, but only those it does not
                already have.

        Returns:
            The registered polyhedron containing ``data``.
        """
        poly = self.add_polyhedron(self.point2poly(self.embedder(data.detach())))
        for name, value in kwargs.items():
            if not hasattr(poly, name):
                setattr(poly, name, value)
        # Points are keyed by the raw (un-embedded) input, flattened to a tuple.
        point_key = tuple(data.detach().cpu().numpy().flatten())
        self.points[point_key] = poly
        self.poly2points[poly.tag].add(point_key)
        return poly

    def bv2poly(self, bv, check_exists=True):
        """Return a polyhedron for sign vector ``bv``.

        When ``check_exists`` is true and ``bv`` is registered, the stored instance is
        returned. Otherwise a new, unregistered ``Polyhedron(self.net, bv, decomp=None)``
        is built.
        """
        if not check_exists or bv not in self:
            return Polyhedron(self.net, bv, decomp=None)
        return self[bv]

    def add_bv(self, bv, check_exists=True):
        """Register the polyhedron for sign vector ``bv`` (reusing a stored one if allowed) and return it."""
        return self.add_polyhedron(self.bv2poly(bv, check_exists=check_exists))

    def poly_grid(self, bounds=2, res=100):
        """Register the polyhedron containing each point of the network's input grid.

        Args:
            bounds: forwarded to ``self.net.get_grid``.
            res: forwarded to ``self.net.get_grid``.
        """
        _, _, inputVal = self.net.get_grid(bounds, res)
        # as_tensor converts straight to the network's dtype/device. torch.Tensor(...)
        # first casts to the default float dtype (float32), which silently loses
        # precision for float64 networks.
        grid = torch.as_tensor(inputVal, device=self.net.device, dtype=self.net.dtype)
        for data in tqdm(grid, desc="Creating Polyhedra From Grid"):
            self.add_point(data[None, :])

    def __getitem__(self, key):
        """Look up a registered polyhedron.

        Args:
            key: a ``Polyhedron`` (matched by ``tag``), an ``int`` tag, or a sign
                vector as ``np.ndarray`` / ``torch.Tensor``. Sign vectors are squeezed
                and encoded with ``encode_bv``.

        Raises:
            NotImplementedError: for ``str`` keys (string lookup is not supported).
            KeyError: if no matching polyhedron is registered.
        """
        if isinstance(key, Polyhedron) and key.tag in self.polys:
            return self.polys[key.tag]
        elif isinstance(key, int) and key in self.polys:
            return self.polys[key]
        elif isinstance(key, (np.ndarray, torch.Tensor)):
            array = key.detach().cpu().numpy() if isinstance(key, torch.Tensor) else key
            tag = encode_bv(array.squeeze())
            if tag in self.polys:
                return self.polys[tag]
        elif isinstance(key, str):
            # Lookup by str(poly) was never finished; refuse explicitly.
            raise NotImplementedError
        raise KeyError(f"Polyhedron with key {key} not in decomp")

    def get_points_poly(self, key):
        data_point = tuple(self.embedder(key).detach().cpu().numpy().squeeze())
        if data_point in self.points:
            return self.points[data_point]
        bv = self.point2bv(key)
        if bv in self:
            return self[bv]
        else:
            raise ValueError("Polyhedron not in decomp")

    def __contains__(self, key):
        """Return True when ``self[key]`` finds a polyhedron (that is, raises no KeyError)."""
        try:
            self[key]
        except KeyError:
            return False
        return True

    def __iter__(self):
        for p in self.index2poly:
            yield p

    def __len__(self):
        return len(self.index2poly)

    @torch.no_grad()
    def path(self, left, right, batch_size=256):
        """Trace the polyhedra crossed by the straight segment from ``left`` to ``right``.

        Both endpoints are passed through ``self.embedder`` first. The segment is cut
        into ``batch_size`` equal sub-segments, all processed together as one batch.
        For each row, the search bisects between ``lefts`` and ``rights``. It looks for
        a point separated from the last recorded point (``lastpoints``) by exactly one
        sign flip, records ``(point, self.point2poly(point))`` and continues from there
        toward the sub-segment end (``ends``).

        If a row's bracket shrinks below ``4e-6`` without isolating a single flip, the
        search jumps to the bracket's right point. When that jump crosses more than one
        boundary, the right point is recorded anyway, the crossed count is added to
        ``skipped`` and warnings are issued.

        Args:
            left: start point (tensor; embedded with ``self.embedder``).
            right: end point (tensor; embedded with ``self.embedder``).
            batch_size: number of sub-segments searched in parallel.

        Returns:
            list of ``(point, Polyhedron)`` pairs sorted by distance from ``left``.
            An entry is kept only when its polyhedron differs (``nflips > 0``) from the
            next entry's. NOTE(review): the last entry in sorted order is therefore
            never returned — confirm this is intended.

        Side effects: prints the flip count between the endpoint polyhedra and drives
        a tqdm progress bar, which is not closed here.
        """
        # Per-row count of sign positions that are strictly opposite (+1 vs -1)
        # between a and b, summed over all bv layers. Returns early once every row
        # exceeds 1, because callers only distinguish 0, 1 and >1.
        # NOTE(review): .sum(axis=1) assumes 2-D (batch, units) activations — confirm.
        def faster_diff(a, b):
            totals = torch.zeros(a.shape[0], device=self.net.device)
            for mask_a, mask_b in zip(self.bv_iterator(a), self.bv_iterator(b)):
                totals += (mask_a * mask_b == -1).sum(axis=1)
                if (totals > 1).all():
                    return totals
            return totals.flatten()

        left, right = self.embedder(left), self.embedder(right)

        # batch_size + 1 evenly spaced points; row k brackets sub-segment k.
        all_points = [left + ((right - left) * i / batch_size) for i in range(batch_size + 1)]
        lefts, rights = torch.vstack(all_points[:-1]), torch.vstack(all_points[1:])
        # lastpoints: last recorded point per row; ends: fixed end of each sub-segment.
        lastpoints, ends = lefts.clone(), rights.clone()
        # One list of (point, polyhedron) records per row, seeded with its start point.
        polys = [[(i, self.point2poly(i))] for i in list(lefts)]
        # polys.extend([(i, Polyhedron(self.net, self.point2bv(i))) for i in list(rights)])

        startpoly = self.point2poly(left)
        endpoly = self.point2poly(right)
        total_difference = endpoly.nflips(startpoly)
        print("Total Difference:", total_difference)
        # Progress is measured along one sub-segment (all rows have the same length).
        pbar = tqdm(total=round(torch.linalg.norm(lefts[0] - rights[0]).item(), 4), desc="Path Progress")
        # Identity map row -> polys list (left over from the slicing variant below).
        polys_indices = torch.arange(batch_size, device=self.net.device)
        skipped = 0
        # last_rights, last_lefts = torch.zeros_like(rights), torch.zeros_like(lefts)
        while True:
            # A row is finished once no sign differs between its last record and its end.
            remain_rows = faster_diff(lastpoints, ends) > 0  # TODO: Should be able to speed this up, ends never changes
            remain_rows_sum = remain_rows.sum()
            if remain_rows_sum == 0:
                break

            # ## Clean Version
            # means = (lefts + rights) / 2
            # nflips = faster_diff(lastpoints, means)
            # lefts[nflips == 0] = means[nflips == 0]
            # lastpoints[nflips == 1] = means[nflips == 1]
            # rights[nflips == 1] = ends[nflips == 1]
            # nflips2 = faster_diff(lefts, means)
            # lefts[(nflips > 1) & (nflips2 == 0)] = means[(nflips > 1) & (nflips2 == 0)]
            # rights[(nflips > 1) & (nflips2 > 0)] = means[(nflips > 1) & (nflips2 > 0)]

            ## Slicing Version
            # polys_indices = polys_indices[remain_rows]
            # lefts = lefts[remain_rows]
            # rights = rights[remain_rows]
            # lastpoints = lastpoints[remain_rows]
            # ends = ends[remain_rows]
            # means = (lefts + rights) / 2
            # nflips = faster_diff(lastpoints, means)
            # lefts[nflips == 0] = means[nflips == 0]
            # lastpoints[nflips == 1] = means[nflips == 1]
            # rights[nflips == 1] = ends[nflips == 1]
            # nflips2 = faster_diff(lefts, means)
            # lefts[(nflips > 1) & (nflips2 == 0)] = means[(nflips > 1) & (nflips2 == 0)]
            # rights[(nflips > 1) & (nflips2 > 0)] = means[(nflips > 1) & (nflips2 > 0)]

            ## Indexing Version
            # Bisection midpoint of every row's current bracket.
            means = (lefts + rights) / 2
            # if (means == lefts).all(dim=1).sum() != 0:
            #     print("Warning: Means == Lefts")
            #     print(means[(means == lefts).all(dim=1).nonzero().flatten()])
            #     raise ValueError("Means == Lefts")
            # if (means == rights).all(dim=1).sum() != 0:
            #     print("Warning: Means == Rights")
            #     print(means[(means == rights).all(dim=1).nonzero().flatten()])
            # raise ValueError("Means == Rights")

            ## Inplace Version
            nflips = faster_diff(lastpoints, means)
            ## If no flips, move left to mean plus a little
            lefts[[remain_rows & (nflips == 0)]] = means[[remain_rows & (nflips == 0)]].clone()  # + 1e-6
            ## If one flip, log the point and move the right to the end
            lastpoints[[remain_rows & (nflips == 1)]] = means[[remain_rows & (nflips == 1)]].clone()
            lefts[[remain_rows & (nflips == 1)]] = means[[remain_rows & (nflips == 1)]].clone()  ##! Newly changed
            rights[[remain_rows & (nflips == 1)]] = ends[[remain_rows & (nflips == 1)]].clone()
            ## If more than one flip
            nflips2 = faster_diff(lefts, means)
            ## If no flips between left and mean, we can move the left to the mean plus a little
            lefts[[remain_rows & (nflips > 1) & (nflips2 == 0)]] = means[
                [remain_rows & (nflips > 1) & (nflips2 == 0)]
            ].clone()  # + 1e-6
            ## If more than one flip between left and mean, we can move the right to the mean minus a little
            rights[[remain_rows & (nflips > 1) & (nflips2 > 0)]] = means[
                [remain_rows & (nflips > 1) & (nflips2 > 0)]
            ].clone()  # - 1e-6

            # Record each single-flip midpoint with its polyhedron.
            # NOTE(review): unlike the updates above, this is not masked by remain_rows.
            for i in torch.argwhere(nflips == 1).flatten():
                polys[polys_indices[i]].append((means[i], self.point2poly(means[i])))

            # Brackets that collapsed without isolating a single flip: jump to the right
            # point, recording it (and counting skipped boundaries) if it crossed several.
            too_close = (torch.linalg.norm(lefts - rights, dim=1) < 4e-6) & (nflips != 1)
            for i in torch.argwhere(too_close).flatten():
                right_poly = self.point2poly(rights[i])
                ncrossed = self.point2poly(lastpoints[i]).nflips(right_poly)
                if ncrossed > 1:
                    polys[polys_indices[i]].append((rights[i], right_poly))
                    warnings.warn(f"Warning: Crossed {ncrossed} Boundaries")
                    skipped += ncrossed
                    warnings.warn("Warning: Crossed Multiple Boundaries")

            lefts[too_close] = rights[too_close].clone()  # + 1e-6
            lastpoints[too_close] = lefts[too_close].clone()
            rights[too_close] = ends[too_close].clone()

            # Progress = how far the slowest row's last record is from its end.
            remaining_dists = torch.linalg.norm(lastpoints - ends, dim=1)
            worst_index = remaining_dists.argmax()
            pbar.n = round(pbar.total - remaining_dists[worst_index].item(), 4)
            pbar.set_postfix_str(
                # f"# Polys: {sum(len(p) for p in polys)}, # Worst Polys: {len(polys[worst_index])}, Worst LR Dist: {torch.linalg.norm(lefts[worst_index] - rights[worst_index]):.4E}, nflips1: {nflips[worst_index]}, nflips2: {nflips2[worst_index]}, remainrow: {remain_rows[worst_index]}, d: {faster_diff(means[worst_index, None], rights[worst_index, None]).item()}, d: {faster_diff(lastpoints[worst_index, None], rights[worst_index, None]).item()} "
                f"# Polys: {sum(len(p) for p in polys)} # Alive: {remain_rows_sum} # Skipped: {skipped}",
                refresh=False,
            )
        # Flatten all rows, order by distance from the start and keep entries whose
        # polyhedron differs from the following entry's.
        polys = [x for y in polys for x in y]
        polys.sort(key=lambda x: torch.linalg.norm(left - x[0]))
        polys = [p1 for p1, p2 in zip(polys[:-1], polys[1:]) if p1[1].nflips(p2[1]) > 0]
        return polys

    def get_graph(self, node_color=None, edge_color=None, node_size=None, cmap="viridis", match_locations=False):
        """Build a complete NetworkX graph over the registered polyhedra.

        Nodes are keyed by ``poly.tag``. Every pair of polyhedra is joined by an edge
        carrying their ``nflips`` value and the cosine similarity (``Wsim``) of their
        flattened ``W`` matrices. The attribute names used (``label``, ``title``,
        ``physics``, ``value``, ``color``, ``size``, ``x``/``y``) look like they target
        an interactive renderer. NOTE(review): presumably pyvis — confirm.

        Args:
            node_color: ``"Wl2"`` or ``"volume"`` to colour nodes by that polyhedron
                quantity; any other value leaves nodes uncoloured.
            edge_color: ``"Wsim"`` to colour edges by weight similarity.
            node_size: ``"volume"`` to size nodes by ``poly.ch.volume``.
            cmap: matplotlib colormap name passed to ``get_colors``.
            match_locations: pin nodes at their scaled 2-D centers (requires 2-D
                polyhedra).

        Returns:
            networkx.Graph

        Raises:
            ValueError: if ``match_locations`` is set and the polyhedra are not 2-D.
        """
        G = nx.Graph()
        poly_list = list(self)

        print("Adding Nodes")
        if match_locations:
            if poly_list and poly_list[0].W.shape[1] != 2:
                raise ValueError("Polyhedra must be 2D to match locations")
            xs = np.array([poly.center[0] for poly in poly_list])
            ys = np.array([poly.center[1] for poly in poly_list])
            # Scale by the largest absolute coordinate on either axis so every node
            # lands in [-1000, 1000]; fall back to 1 if all centers are at the origin.
            maxc = max(np.max(np.abs(xs), initial=0), np.max(np.abs(ys), initial=0))
            maxc = maxc if maxc > 0 else 1

            for poly in poly_list:
                G.add_node(
                    poly.tag,
                    x=poly.center[0] * 1000 / maxc,
                    y=-poly.center[1] * 1000 / maxc,
                    label=" ",  # str(poly),
                    title=str(poly),
                    physics=False,
                )
        else:
            for poly in poly_list:
                G.add_node(poly.tag, label=" ", title=f"{str(poly)}\nPoints: {len(self.poly2points[poly.tag])}")
        nodes = list(G.nodes)
        print("Adding Edges")
        for i in tqdm(range(len(nodes))):
            poly1 = self.polys[nodes[i]]
            for j in range(i + 1, len(nodes)):
                poly2 = self.polys[nodes[j]]
                shared = poly1.nflips(poly2)
                # Elsewhere in this class (path) nflips() is compared with ints, so it
                # appears to return a count; a sized collection is accepted too.
                try:
                    num_shared = len(shared)
                except TypeError:
                    num_shared = int(shared)
                W1v = poly1.W.flatten()
                W2v = poly2.W.flatten()
                norm_product = torch.linalg.norm(W1v) * torch.linalg.norm(W2v)
                # Cosine similarity of the weight matrices; 0.99 when either is all zeros.
                Wsim = (torch.dot(W1v, W2v) / norm_product).item() if norm_product > 0 else 0.99
                G.add_edge(
                    nodes[i],
                    nodes[j],
                    title=f"{num_shared} Shared\nSimilarity: {Wsim:.2f}\n{shared}",
                    label=None,
                    value=4 - num_shared,
                    weight=4 - num_shared,
                    Wsim=Wsim,
                )
        if edge_color == "Wsim":
            colors = get_colors([G.edges[edge]["Wsim"] for edge in G.edges], cmap=cmap)
            for c, edge in zip(colors, G.edges):
                G.edges[edge]["color"] = c

        if node_color == "Wl2":
            node_values = [poly.Wl2 for poly in poly_list]
        elif node_color == "volume":
            node_values = [poly.ch.volume for poly in poly_list]
        else:
            node_values = None
        if node_values is not None:
            for c, tag in zip(get_colors(node_values, cmap=cmap), nodes):
                G.nodes[tag]["color"] = c

        if node_size == "volume":
            sizes = [poly.ch.volume for poly in poly_list]
            maxsize = max(sizes, default=0)
            maxsize = maxsize if maxsize > 0 else 1  # avoid dividing by zero for all-zero volumes
            for size, tag in zip(sizes, nodes):
                G.nodes[tag]["size"] = 10 + 1000 * size / maxsize
        return G

    def clean_data(self):
        for poly in self:
            poly.clean_data()

    @torch.no_grad()
    def adjacent_polyhedra(self, poly):
        """Register and return the neighbours of ``poly`` across its supporting hyperplanes.

        For every index in ``poly.shis`` whose sign in ``poly.bv`` is non-zero, that
        sign is negated and the resulting sign vector is registered via ``add_bv``.
        ``add_bv`` reuses the stored polyhedron when the vector is already known.

        Args:
            poly: polyhedron whose ``bv`` (cloned, so presumably a torch tensor —
                confirm) and ``shis`` are read.

        Returns:
            set of the neighbouring polyhedra, as stored in ``self``.
        """
        neighbours = set()
        for shi in poly.shis:
            if poly.bv[0, shi] == 0:
                continue
            bv = poly.bv.clone()
            bv[0, shi] = -bv[0, shi]
            # Previously the registered polyhedron was discarded and an empty set returned.
            neighbours.add(self.add_bv(bv))
        return neighbours

    def get_env(self):
        """Create, configure and start a silent Gurobi environment (no log file, no console output)."""
        env = gp.Env(logfilename="", empty=True)
        for name in ("OutputFlag", "LogToConsole"):
            env.setParam(name, 0)
        env.start()
        return env

    def searcher_single(
        self,
        start=None,
        max_depth=float("inf"),
        max_polys=float("inf"),
        queue=None,
        pop=lambda x: x.pop(),
        bound=1e5,
        collect_info=None,
        **kwargs,
    ):
        """Explore the decomposition in-process by flipping one supporting hyperplane at a time.

        Each queued task ``(node, shi, depth)`` stands for the neighbour obtained by
        negating ``node.bv[0, shi]``. New polyhedra are computed and registered, and
        their own supporting hyperplanes are queued. This continues until the queue is
        empty, ``max_polys`` polyhedra are registered, or the user interrupts.

        Args:
            start: starting polyhedron, an input tensor to locate, or None for the
                zero input.
            max_depth: neighbours of a polyhedron are queued only while its depth is
                below this.
            max_polys: stop once this many polyhedra are registered.
            queue: container for pending tasks (default ``deque``).
            pop: callable removing the next task from ``queue``; the default pops
                from the right end, giving LIFO order on a deque.
            bound: forwarded to ``get_shis``.
            collect_info: if truthy, forwarded to ``get_shis``, which then also returns
                an info value (discarded here).
            **kwargs: forwarded to ``get_shis``.

        Returns:
            dict with ``"Search Depth"`` (depth of the last task taken, 0 if none),
            ``"Avg # Facets Uncorrected"`` (running mean of ``len(shis)``) and
            ``"Search Time"`` (tqdm elapsed seconds).

        Raises:
            ValueError: if the decomposition is not empty or the start point lies on
                a hyperplane.
        """
        if len(self) > 0:
            raise ValueError("Decomposition already has polyhedra")
        if queue is None:
            queue = deque()
        if start is None:
            start = self.add_point(torch.zeros(self.net.input_shape, device=self.net.device, dtype=self.net.dtype))
        if isinstance(start, torch.Tensor):
            start = self.add_point(start)
        if (start.bv == 0).any():
            raise ValueError("Start point must not be on a hyperplane")
        start._shis = start.get_shis(bound=bound, **kwargs)
        env = self.get_env()
        for shi in start.shis:
            queue.append((start, shi, 1))
        rolling_average = len(start.shis)
        # Defined up front so the summary is valid even if the loop body never runs
        # (e.g. max_polys <= 1, or the start polyhedron reports no facets).
        depth = 0
        pbar = tqdm(desc="Search Progress", mininterval=5)
        pbar.n = 1  # the start polyhedron is already registered
        try:
            while queue and len(self) < max_polys:
                node, shi, depth = pop(queue)
                new_bv = node.bv.clone()
                # Flipping a zero sign would not produce a neighbour.
                assert new_bv[0, shi] != 0
                new_bv[0, shi] *= -1
                if new_bv in self:
                    continue

                p = self.bv2poly(new_bv)

                try:
                    ## TODO: Remove collect_info, make option to clear the node info afterwards
                    if collect_info:
                        p._shis, _shi_info = p.get_shis(env=env, bound=bound, collect_info=collect_info, **kwargs)
                    else:
                        p._shis = p.get_shis(env=env, bound=bound, **kwargs)
                    p.get_center_inradius(env=env)
                    p.get_interior_point(env=env)
                    # Bare accesses only force computation; presumably cached properties — confirm.
                    p.interior_point_norm
                    p.Wl2
                    if self.dim <= 6:
                        p.volume
                except ValueError as error:
                    print("\nBad SHI Calculation:", node, shi, error)
                    node._shis.remove(shi)
                    continue

                p = self.add_polyhedron(p)

                random.shuffle(p._shis)

                if depth < max_depth:
                    for new_shi in p.shis:
                        if new_shi != shi:
                            queue.append((p, new_shi, depth + 1))
                node.clean_data()
                pbar.update(n=1)
                rolling_average = (rolling_average * (pbar.n - 1) + len(p.shis)) / pbar.n
                pbar.set_postfix_str(
                    f"Depth: {depth}  Queue: {len(queue)}  Faces: {len(p.shis)}  Avg: {rolling_average:.2f} IP Norm: {p.interior_point_norm:.2f}  Finite: {p.finite}",
                    refresh=False,
                )

        except KeyboardInterrupt:
            print("\nSearch Interrupted\n")
        finally:
            elapsed = pbar.format_dict["elapsed"]
            pbar.close()
        search_info = {
            "Search Depth": depth,
            "Avg # Facets Uncorrected": rolling_average,
            "Search Time": elapsed,
        }
        return search_info

    def search_calculations(self, bound, collect_info, calculation_queue, poly_queue, **kwargs):
        """Worker loop (run as an ``mp.Process`` target): compute polyhedra for queued sign vectors.

        The loop consumes ``(bv, shi, depth, node_index)`` tuples from
        ``calculation_queue``. It stops on a non-tuple sentinel, or on a ``ValueError``
        raised by ``get``. For each task it:

        * builds a polyhedron with ``bv2poly``;
        * fills its halfspaces, center/inradius, interior point, interior-point norm
          and ``Wl2``;
        * computes the supporting hyperplane indices if ``p._shis`` is still None;
        * puts ``(p, shi, depth, node_index)`` on ``poly_queue``.

        If any of those computations raises ``ValueError``, it puts
        ``(None, shi, str(error), node_index)`` instead. Note that the error text then
        occupies the depth slot. Every loop iteration calls
        ``calculation_queue.task_done()`` exactly once, so ``join()`` on that queue
        stays balanced.

        Because this runs in a separate process, ``self`` here is that process's own
        copy of the decomposition. ``bv2poly``'s existence check therefore sees that
        copy, not the parent's live state.

        Args:
            bound: forwarded to ``get_shis``.
            collect_info: if truthy, forwarded to ``get_shis`` (the extra info value
                it then returns is discarded).
            calculation_queue: input queue of tasks; ``bv`` must not be a torch tensor
                (asserted).
            poly_queue: output queue for results.
            **kwargs: forwarded to ``get_shis``.
        """
        ## Receives sign-vector tasks, returns computed polyhedra

        env = self.get_env()
        while True:
            try:
                task = calculation_queue.get()
            except ValueError:
                task = None
            if not isinstance(task, tuple):
                calculation_queue.task_done()
                break
            bv, shi, depth, node_index = task
            assert not isinstance(bv, torch.Tensor)
            p = self.bv2poly(bv)
            assert not isinstance(p.bv, torch.Tensor)

            try:
                p._halfspaces, p._W, p._b = p.get_hs_numpy(p.bv)
                p.get_center_inradius(env=env)
                p.get_interior_point(env=env)
                p._interior_point_norm = np.linalg.norm(p.interior_point).item()
                p._Wl2 = np.linalg.norm(p.W).item()

                # if self.dim <= 6:
                #     p.volume
                # NOTE(review): assumes a fresh Polyhedron starts with _shis = None — confirm.
                if p._shis is None:
                    if collect_info:
                        p._shis, shi_info = p.get_shis(env=env, bound=bound, collect_info=collect_info, **kwargs)
                    else:
                        p._shis = p.get_shis(env=env, bound=bound, **kwargs)
            except ValueError as error:
                # Bad SHI Calculation
                poly_queue.put((None, shi, str(error), node_index))
                calculation_queue.task_done()
                continue

            # Presumably drops cached data so the object is cheaper to send back — confirm.
            p.clean_data()

            # Randomise the order in which this polyhedron's neighbours get explored.
            random.shuffle(p._shis)
            poly_queue.put((p, shi, depth, node_index))

            calculation_queue.task_done()

    def search_producer(self, task_queue, calculation_queue, check_last=40):
        ## TODO: Move this into search_consumer, link search_consumer directly to search_calculations
        last_bvs = deque(maxlen=check_last)
        while True:
            try:
                task = task_queue.get()
            except ValueError:
                task = None
            if not isinstance(task, tuple):
                task_queue.task_done()
                print("   Exiting Producer")
                break
            node, shi, depth = task
            new_bv = node.bv.copy()
            if shi is not None:
                if node.bv[0, shi] == 0:
                    task_queue.task_done()
                    continue
                new_bv[0, shi] *= -1
                if not (new_bv in self or any((new_bv == last_bv).all() for last_bv in last_bvs)):
                    last_bvs.append(new_bv)
                    while True:
                        try:
                            calculation_queue.put((new_bv, shi, depth, self.poly2index[node]), timeout=5)
                            break
                        except Full:
                            pass
                        except ValueError:
                            break
            else:
                calculation_queue.put((new_bv, None, depth, None))
            task_queue.task_done()

    def search_consumer(self, pop, queue, queue_lock, task_queue):
        while True:
            with queue_lock:
                while len(queue) == 0:
                    queue_lock.wait()
                task = pop(queue)
            if not isinstance(task, tuple):
                # task_queue.put(task)
                print("   Exiting Consumer")
                break
            task_queue.put(task)

    def parallel_add(self, points, nworkers=None, timeout=5, bound=1e5, **kwargs):
        """Locate ``points`` and compute their polyhedra in worker processes.

        Each point is first registered with ``add_point``. This happens in the calling
        thread, while the consumer thread's arguments are being built. Each resulting
        polyhedron then travels as ``(poly, None, 0)`` along the pipeline:

            consumer thread -> task_queue -> producer thread -> calculation_queue
            -> ``nworkers`` processes running ``search_calculations``

        The computed polyhedra are merged back with ``add_polyhedron(..., overwrite=True)``.

        Args:
            points: sized iterable of input tensors (``len(points)`` is used).
            nworkers: number of worker processes (default ``mp.cpu_count()``).
            timeout: unused.
            bound: forwarded to ``get_shis`` in the workers.
            **kwargs: forwarded to ``search_calculations`` (and from there to ``get_shis``).

        Returns:
            list with one entry per result received: the polyhedron object a worker
            returned, or ``None`` where a worker reported an error (or ``get`` raised
            ``ValueError``). NOTE(review): entries come in completion order, which need
            not match the order of ``points`` — confirm callers don't rely on alignment.

        Raises:
            KeyboardInterrupt: re-raised after the interrupt message; workers and
                threads are shut down in ``finally`` either way.
        """
        nworkers = nworkers or mp.cpu_count()
        print(f"Running on {nworkers} workers")

        poly_queue, calculation_queue, task_queue = (
            mp.JoinableQueue(maxsize=20),
            mp.JoinableQueue(maxsize=20),
            mp.JoinableQueue(maxsize=20),
        )
        queue_lock = mp.Condition()

        processes = []
        for i in range(nworkers):
            processes.append(
                mp.Process(
                    target=self.search_calculations,
                    args=(bound, False, calculation_queue, poly_queue),
                    kwargs=kwargs,
                )
            )
            processes[-1].start()
        producer = threading.Thread(
            target=self.search_producer,
            args=(task_queue, calculation_queue),
        )
        producer.start()
        # The consumer pops from the end of this list; the leading -1 is its stop
        # sentinel, so it is reached only after every point's task has been forwarded.
        consumer = threading.Thread(
            target=self.search_consumer,
            args=(
                lambda x: x.pop(),
                [-1] + list(map(lambda x: (self.add_point(x), None, 0), points)),
                queue_lock,
                task_queue,
            ),
        )
        consumer.start()

        pbar = tqdm(total=len(points), desc="Adding Polys", mininterval=1)

        with queue_lock:
            queue_lock.notify()

        polys = []
        try:
            # Expect exactly one result (or error) per input point.
            while len(polys) < len(points):
                try:
                    p, shi, depth, node_index = poly_queue.get()
                except ValueError:
                    p, shi, depth, node_index = None, None, None, None
                polys.append(p)
                pbar.update(1)

                if p is None:
                    # For worker errors, "depth" holds the error message text.
                    print("Error!", shi, depth, node_index)
                    continue
                # Re-attach the parent's network to the object received from the worker.
                p.net = self.net
                p = self.add_polyhedron(p, overwrite=True)

        except KeyboardInterrupt:
            print("\nSearch Interrupted\n")
            raise KeyboardInterrupt
        finally:
            # -1 stops the producer; one -1 per worker stops each search_calculations loop.
            print("Closing Buffers")
            task_queue.put(-1)
            print("Closing Calculations")
            for _ in range(nworkers):
                calculation_queue.put(-1)
            print("Joining Processes")
            for thread in [producer, consumer]:
                thread.join(timeout=5)
            for process in processes:
                process.join(timeout=5)
                if process.exitcode is None:
                    print("Terminating Process By Force")
                    process.terminate()
            print("Joining Buffers")
            for q in [poly_queue, calculation_queue, task_queue]:
                q.close()
                q.join_thread()
            print("All Processes Terminated")

        return polys

    def searcher(
        self,
        start=None,
        max_depth=float("inf"),
        max_polys=float("inf"),
        queue=None,
        pop=lambda x: x.pop(),
        bound=1e5,
        collect_info=None,
        nworkers=None,
        timeout=5,
        **kwargs,
    ):
        nworkers = nworkers or mp.cpu_count()
        if nworkers == 1:
            print("Searching Single Threaded")
            return self.searcher_single(start, max_depth, max_polys, queue, pop, bound, collect_info, **kwargs)
        ## NOTE: If nworkers>1, the traversal order may not be correct, since we are never synchronizing the workers with a barrier
        if len(self) > 0:
            raise ValueError("Decomposition already has polyhedra")
        print(f"Running on {nworkers} workers")
        if queue is None:
            queue = deque()
        if start is None:
            start = self.add_point(torch.zeros(self.net.input_shape, device=self.net.device, dtype=self.net.dtype))
        if isinstance(start, torch.Tensor):
            start = self.add_point(start)
        start.bv = start.bv.detach().cpu().numpy()
        if (start.bv == 0).any():
            raise ValueError("Start point must not be on a hyperplane")
        start._shis = start.get_shis(env=self.get_env(), bound=bound, **kwargs)
        for shi in start.shis:
            queue.append((start, shi, 1))

        rolling_average = len(start.shis)
        skipped_average = 0
        bad_shi_computations = []
        pbar = tqdm(desc="Search Progress", mininterval=5, total=max_polys if max_polys != float("inf") else None)
        pbar.update(n=1)

        poly_queue, calculation_queue, task_queue = (
            mp.JoinableQueue(maxsize=20),
            mp.JoinableQueue(maxsize=20),
            mp.JoinableQueue(maxsize=20),
        )
        queue_lock = mp.Condition()

        processes = []
        for i in range(nworkers):
            processes.append(
                mp.Process(
                    target=self.search_calculations,
                    args=(bound, collect_info, calculation_queue, poly_queue),
                    kwargs=kwargs,
                )
            )
            processes[-1].start()
        producer = threading.Thread(
            target=self.search_producer,
            args=(task_queue, calculation_queue),
        )
        producer.start()
        consumer = threading.Thread(
            target=self.search_consumer,
            args=(pop, queue, queue_lock, task_queue),
        )
        consumer.start()
        all_processes = processes + [producer, consumer]
        wasted = 0
        try:
            while True:
                try:
                    p, shi, depth, node_index = poly_queue.get(timeout=timeout)
                except Empty:
                    if len(queue) == 0:
                        task_queue.join()
                        calculation_queue.join()
                        print("Joined Queues", flush=True)
                        try:
                            p, shi, depth, node_index = poly_queue.get(block=False)
                        except Empty:
                            break
                    else:
                        continue
                node = self.index2poly[node_index]
                if p is None:
                    bad_shi_computations.append((node, shi, depth))
                    node._shis.remove(shi)
                    continue
                # p.bv = torch.from_numpy(p.bv).to(self.net.device, self.net.dtype)
                assert not isinstance(p.bv, torch.Tensor)

                if p in self:
                    wasted += 1
                    continue

                with queue_lock:
                    if depth < max_depth:
                        for new_shi in p.shis:
                            if new_shi != shi and len(self) < max_polys:
                                queue.append((p, new_shi, depth + 1))
                                queue_lock.notify()

                p.net = self.net

                p = self.add_polyhedron(p)

                pbar.update(n=len(self) - pbar.n)
                rolling_average = (rolling_average * (pbar.n - 1) + len(p.shis)) / pbar.n
                pbar.set_postfix_str(
                    f"Depth: {depth}  Queue: {len(queue)}  Faces: {len(p._shis)}  Avg: {rolling_average:.2f} IP Norm: {p._interior_point_norm:.2f}  Finite: {p._finite}  Wasted: {wasted}  Mistakes: {len(bad_shi_computations)} Queue Sizes: {(poly_queue.qsize(), calculation_queue.qsize(), task_queue.qsize())}",
                    refresh=False,
                )

                exited_processes = [process for process in all_processes if not process.is_alive()]
                if len(exited_processes) > 1:
                    print("# Processes:", len(processes) - len(exited_processes))
                    raise ValueError("Processes Exited")

                if len(self) >= max_polys:
                    break

        except KeyboardInterrupt:
            print("\nSearch Interrupted\n")
            raise KeyboardInterrupt
        finally:
            print("Closing Queue")
            with queue_lock:
                queue.appendleft(-1)
                queue.append(-1)
                queue_lock.notify()
            try:
                while True:
                    task_queue.get(block=False)
            except Empty:
                pass
            print("Closing Calculations Signal")
            task_queue.put(-1)
            print("Closing Buffers")
            for q in [poly_queue, calculation_queue, task_queue]:
                q.close()
                q.join_thread()
            print("Joining Processes")
            for thread in [producer, consumer]:
                thread.join(timeout=5)
            print("Terminating Processes")
            pbar.close()
            for process in tqdm(processes, desc="Joining Processes"):
                process.terminate()
            print("All Processes Terminated")

        search_info = {
            "Search Depth": depth,
            "Avg # Facets Uncorrected": rolling_average,
            "Avg # Facets Skipped": skipped_average,
            "Search Time": pbar.format_dict["elapsed"],
            "Bad SHI Computations": bad_shi_computations,
        }
        print("Completed Search")
        return search_info

    def bfs(self, **kwargs):
        """Run ``self.searcher`` in breadth-first order.

        The pop strategy removes the oldest pending entry via ``popleft``,
        so the searcher's queue must support it (presumably a ``deque`` —
        confirm against ``searcher``'s default). All ``kwargs`` are forwarded.
        """

        def take_oldest(pending):
            return pending.popleft()

        return self.searcher(pop=take_oldest, **kwargs)

    def dfs(self, **kwargs):
        """Run ``self.searcher`` in depth-first order.

        The pop strategy removes the most recently queued entry (``pop()``
        from the right end). All ``kwargs`` are forwarded unchanged.
        """

        def take_newest(pending):
            return pending.pop()

        return self.searcher(pop=take_newest, **kwargs)

    def escape_search(self, **kwargs):
        """Run ``self.searcher`` expanding the entry farthest from the origin first.

        Each queued item is a tuple whose first element is a polyhedron. The
        pop strategy removes the item with the largest
        ``interior_point_norm`` (first such item on ties) with a linear scan
        of a plain list. All ``kwargs`` are forwarded.
        """

        def take_farthest(pending):
            # TODO: a max-heap would avoid the O(n) scan per pop.
            best_idx = max(range(len(pending)), key=lambda idx: pending[idx][0].interior_point_norm)
            return pending.pop(best_idx)

        return self.searcher(queue=[], pop=take_farthest, **kwargs)

    def random_walk(self, **kwargs):
        """Run ``self.searcher`` expanding a uniformly random pending entry each step.

        Uses a plain list as the queue; all ``kwargs`` are forwarded.
        """

        def take_random(pending):
            # randrange(len(pending)) covers every index 0..len-1. The previous
            # randrange(0, len - 1) could never choose the last entry and raised
            # ValueError when exactly one entry was pending.
            return pending.pop(random.randrange(len(pending)))

        return self.searcher(
            queue=list(),
            pop=take_random,
            **kwargs,
        )

    def get_poly_attrs(self, attrs):
        """Collect attribute values across all polyhedra, column-wise.

        Returns a dict mapping each name in ``attrs`` to the list of
        ``getattr(poly, name)`` values, in iteration order of ``self``.
        """
        columns = {}
        for name in attrs:
            columns[name] = [getattr(poly, name) for poly in self]
        return columns

    def get_dual_graph(self, relabel=False):
        """Build the dual graph of the stored polyhedra.

        Nodes are the polyhedra in ``self`` (with a ``label`` attribute of
        ``str(poly)``). For each ``shi`` in ``poly.shis``, an edge with
        attribute ``shi`` joins ``poly`` to the polyhedron whose sign vector
        equals ``poly.bv`` with entry ``[0, shi]`` negated, if such a
        polyhedron is present in ``self``.

        Args:
            relabel: If True, nodes are renamed via ``self.poly2index``.

        Returns:
            networkx.Graph
        """
        G = nx.Graph()
        for poly in self:
            G.add_node(poly, label=str(poly))
        for poly in tqdm(self, desc="Creating Dual Graph"):
            for shi in poly.shis:
                bv = poly.bv
                # Temporarily flip one sign in place (avoids a copy). Restore it in
                # `finally` so an exception during the lookup cannot leave `poly`
                # corrupted, and add the edge only after restoring, so `poly` is never
                # hashed/compared while its sign vector is altered (matters if
                # Polyhedron hashing/equality depends on `bv` — NOTE(review): confirm).
                bv[0, shi] *= -1
                try:
                    found = bv in self
                    neighbor = self[bv] if found else None
                finally:
                    bv[0, shi] *= -1
                if found:
                    G.add_edge(poly, neighbor, shi=shi)
        if relabel:
            G = nx.relabel_nodes(G, self.poly2index)
        return G

    def plot_dual_graph(self):
        """Return the dual graph annotated with display strings, relabeled to indices.

        Every node gets ``label``/``title`` set to ``str(poly)`` and every edge
        gets ``label``/``title`` set to ``str`` of its ``shi`` attribute; nodes
        are then renamed via ``self.poly2index``.
        """
        graph = self.get_dual_graph()
        for poly in graph.nodes:
            text = str(poly)
            graph.nodes[poly].update(label=text, title=text)
        for uv in graph.edges:
            edge_attrs = graph.edges[uv]
            shi_text = str(edge_attrs["shi"])
            edge_attrs.update(label=shi_text, title=shi_text)
        return nx.relabel_nodes(graph, self.poly2index)

    def recover_from_dual_graph(self, G, initial_bv, source=0):
        """Rebuild polyhedra by walking a dual graph outward from one known region.

        ``initial_bv`` is added as the polyhedron at node ``source``. Edges are
        then visited in BFS order starting at ``source``. Each edge's target
        polyhedron is obtained by negating entry ``[0, shi]`` of the source
        polyhedron's sign vector, where ``shi`` is the edge attribute.
        Finally each polyhedron's ``_shis`` is set to the ``shi`` labels of its
        incident edges.

        Args:
            G: Dual graph with a ``shi`` attribute on every edge (not modified;
                a copy is annotated).
            initial_bv: Sign vector of the polyhedron at ``source``.
            source: Node of ``G`` that ``initial_bv`` corresponds to.

        Returns:
            The copy of ``G`` with a ``poly`` attribute on every reached node.

        Raises:
            KeyError: If some node of ``G`` is not reachable from ``source``
                (it never receives a ``poly`` attribute).
        """
        G = G.copy()
        initial_p = self.add_bv(initial_bv)
        G.nodes[source]["poly"] = initial_p
        # Traverse from `source` (previously hard-coded to node 0, which broke
        # any call with source != 0).
        for edge in tqdm(nx.edge_bfs(G, source=source), desc="Recovering Polyhedra", total=G.number_of_edges()):
            poly1, shi = G.nodes[edge[0]]["poly"], G.edges[edge]["shi"]
            bv = poly1.bv
            # Sign vectors may be torch tensors or numpy arrays (the search path
            # asserts they are not tensors), so copy with the matching API.
            poly2_bv = bv.clone() if isinstance(bv, torch.Tensor) else bv.copy()
            assert poly2_bv[0, shi] != 0
            poly2_bv[0, shi] *= -1
            poly2 = self.add_bv(poly2_bv)

            G.nodes[edge[1]]["poly"] = poly2

        for node in G:
            self[G.nodes[node]["poly"]]._shis = [G.edges[edge]["shi"] for edge in G.edges(node)]

        return G

    def plot(self, show_points=False, label_regions=False, color=None, highlight_regions=None, **kwargs):
        """Build a 2D plotly Figure with one filled trace per polyhedron.

        Args:
            show_points: Overlay the keys of ``self.points`` as black markers.
            label_regions: Place ``str(poly)`` as text at ``poly.center``.
            color: ``"Wl2"`` colors regions by ``poly.Wl2``; anything else uses a
                graph coloring of the dual graph over the Plotly qualitative palette.
            highlight_regions: Polyhedra (or their ``str``) to draw in red.
            **kwargs: Forwarded to every ``poly.plot2d`` call.

        Returns:
            plotly.graph_objects.Figure
        """
        fig = go.Figure()
        polys = list(self)
        if color == "Wl2":
            colors = get_colors([poly.Wl2 for poly in polys])
        else:
            color_scheme = px.colors.qualitative.Plotly
            dual_graph = self.get_dual_graph()
            try:
                coloring = nx.algorithms.coloring.equitable_color(dual_graph, len(color_scheme))
            except nx.NetworkXAlgorithmError:
                # equitable_color requires more colors than the maximum degree; fall
                # back to a greedy coloring wrapped onto the palette (adjacent regions
                # may then share a color, but the plot still renders).
                coloring = nx.algorithms.coloring.greedy_color(dual_graph)
            colors = [color_scheme[coloring[i] % len(color_scheme)] for i in polys]
        for c, poly in tqdm(zip(colors, polys), desc="Plotting Polyhedra", total=len(polys)):
            if (highlight_regions is not None) and ((poly in highlight_regions) or (str(poly) in highlight_regions)):
                c = "red"
            p_plot = poly.plot2d(
                name=f"{poly}",
                fillcolor=c,
                line_color=c,
                mode="lines",  ## Comment out to mouse over intersections
                **kwargs,
            )
            if p_plot is not None:
                fig.add_trace(p_plot)
            if label_regions and poly.center is not None:
                fig.add_trace(
                    go.Scatter(x=[poly.center[0]], y=[poly.center[1]], mode="text", text=str(poly), showlegend=False)
                )
        if show_points:
            x = np.array(list(self.points.keys()))
            fig.add_trace(
                go.Scatter(
                    x=x[:, 0],
                    y=x[:, 1],
                    mode="markers",
                    marker=dict(color="black"),
                )
            )
        # Zoom to 1.1x the median extent of the finite regions. With no finite
        # region np.median([]) would yield NaN (plus a RuntimeWarning), so leave
        # plotly's autorange in place instead.
        finite_extents = [np.max(np.abs(p.interior_point)) for p in self if p.finite]
        axis_layout = {}
        if finite_extents:
            maxcoord = np.median(finite_extents) * 1.1
            axis_layout = dict(
                xaxis=dict(range=(-maxcoord, maxcoord)),
                yaxis=dict(range=(-maxcoord, maxcoord)),
            )
        fig.update_layout(
            showlegend=True,
            plot_bgcolor="white",
            **axis_layout,
        )
        return fig

    def plot3d(
        self,
        show_points=False,
        label_regions=False,
        color=None,
        highlight_regions=None,
        show_axes=False,
        project=True,
        **kwargs,
    ):
        """Build a 3D plotly Figure with traces from every polyhedron's ``plot3d``.

        For each polyhedron, ``poly.plot3d`` is called once normally and, when
        ``project`` is not None, once more with ``project=project``. Dict results
        contribute ``"outline"``/``"mesh"`` traces, which are added after all
        other traces (outlines before meshes); any other non-None result is
        added immediately.

        Args:
            show_points: Overlay the keys of ``self.points`` as black markers.
            label_regions: Place ``str(poly)`` as text at ``poly.center``, at the
                height of the first output of ``self.net`` there.
            color: ``"Wl2"`` colors regions by ``poly.Wl2``; anything else uses a
                graph coloring of the dual graph over the Plotly qualitative palette.
            highlight_regions: Polyhedra (or their ``str``) to draw in red.
            show_axes: Visibility of the scene's x/y/z axes.
            project: Forwarded to the second ``poly.plot3d`` call; None skips it.
            **kwargs: Forwarded to every ``poly.plot3d`` call.

        Returns:
            plotly.graph_objects.Figure
        """
        fig = go.Figure()
        polys = list(self)
        if color == "Wl2":
            colors = get_colors([poly.Wl2 for poly in polys])
        else:
            color_scheme = px.colors.qualitative.Plotly
            dual_graph = self.get_dual_graph()
            try:
                coloring = nx.algorithms.coloring.equitable_color(dual_graph, len(color_scheme))
            except nx.NetworkXAlgorithmError:
                # equitable_color requires more colors than the maximum degree; fall
                # back to a greedy coloring wrapped onto the palette.
                coloring = nx.algorithms.coloring.greedy_color(dual_graph)
            colors = [color_scheme[coloring[i] % len(color_scheme)] for i in polys]
        outlines, meshes = [], []

        def collect(p_plot):
            # Defer dict parts so all outlines, then all meshes, are layered last;
            # plain traces are added right away.
            if p_plot is None:
                return
            if isinstance(p_plot, dict):
                if "mesh" in p_plot:
                    meshes.append(p_plot["mesh"])
                if "outline" in p_plot:
                    outlines.append(p_plot["outline"])
            else:
                fig.add_trace(p_plot)

        for c, poly in tqdm(zip(colors, polys), desc="Plotting Polyhedra", total=len(polys)):
            if (highlight_regions is not None) and ((poly in highlight_regions) or (str(poly) in highlight_regions)):
                c = "red"
            collect(poly.plot3d(name=f"{poly}", color=c, **kwargs))
            # NOTE(review): only None disables the projected copy; project=False still
            # triggers a second poly.plot3d(project=False) call — confirm intended.
            if project is not None:
                collect(poly.plot3d(name=f"{poly}", color=c, project=project, **kwargs))
            if label_regions and poly.center is not None:
                center = torch.tensor(poly.center, device=self.net.device, dtype=self.net.dtype).T
                # `.flatten()` is 1-D, so take element [0]; the previous `[:, 0]`
                # raised IndexError whenever label_regions was enabled.
                z_center = self.net(center).detach().cpu().numpy().flatten()[0]
                fig.add_trace(
                    go.Scatter3d(
                        x=[poly.center[0]],
                        y=[poly.center[1]],
                        z=[z_center],
                        mode="text",
                        text=str(poly),
                        showlegend=False,
                    )
                )
        for outline in outlines:
            fig.add_trace(outline)
        for mesh in meshes:
            fig.add_trace(mesh)
        if show_points:
            x = np.array(list(self.points.keys()))
            fig.add_trace(
                go.Scatter3d(
                    x=x[:, 0],
                    y=x[:, 1],
                    # NOTE(review): x is (N, d) here but is transposed before self.net,
                    # and the result is wrapped in a list; verify the net's input
                    # convention and that z ends up with N values matching x/y.
                    z=[
                        self.net(torch.tensor(x, device=self.net.device, dtype=self.net.dtype).T)
                        .detach()
                        .cpu()
                        .numpy()[:, 0]
                    ],
                    mode="markers",
                    marker=dict(color="black"),
                )
            )
        # Zoom x/y to 1.1x the median extent of finite regions; with no finite
        # region np.median([]) would be NaN, so keep plotly's autorange then.
        finite_extents = [np.max(np.abs(p.interior_point)) for p in self if p.finite]
        axis_range = {}
        if finite_extents:
            maxcoord = np.median(finite_extents) * 1.1
            axis_range = dict(range=(-maxcoord, maxcoord))
        fig.update_layout(
            scene=dict(
                xaxis=dict(visible=show_axes, **axis_range),
                yaxis=dict(visible=show_axes, **axis_range),
                zaxis=dict(visible=show_axes),
            ),
        )
        return fig
