#!/usr/bin/env python3

"""
Program module providing the program class with cycle breaking methods and Clark completion(s).
"""

from platform import node
import networkx as nx
import os
import logging
import numpy as np
import time
import math


from aspmc.parsing.clingoparser.clingoext import ClingoRule

from aspmc.graph.hypergraph import Hypergraph
import aspmc.graph.treedecomposition as treedecomposition

from aspmc.compile.cnf import CNF

from aspmc.parsing.clingoparser.clingoext import Control
import aspmc.programs.backdoor as backdoor
import aspmc.programs.grounder as grounder

import subprocess

import inspect 

src_path = os.path.abspath(os.path.realpath(inspect.getfile(inspect.currentframe())))
src_path = os.path.realpath(os.path.join(src_path, '../../external'))

import aspmc.config as config

from aspmc.programs.naming import *

logger = logging.getLogger("aspmc")

class UnsupportedException(Exception):
    """Raised when the input program relies on a feature that is not supported yet.

    Examples are choice rules with more than one head atom, conditional
    choice rules, or unknown rule types in smodels input.
    """

class Rule(object):
    """A ground rule over integer atoms.

    Atoms are DIMACS-style literals: a positive integer stands for an atom and,
    in the body, a negative integer stands for the default negation of that atom.

    Args:
        head (:obj:`list`): Head atoms. May be empty (e.g. for constraints).
        body (:obj:`list`): Body literals. May be empty (e.g. for facts).

    Attributes:
        head (:obj:`list`): Head atoms. May be empty.
        body (:obj:`list`): Body literals. May be empty.
    """
    def __init__(self, head, body):
        self.head, self.body = head, body

    def __hash__(self):
        key = (tuple(self.head), tuple(self.body))
        return hash(key)

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return self.head == other.head and self.body == other.body
        return NotImplemented

    def __repr__(self):
        def show(lit):
            # negative body literals are printed as "not <atom>"
            return "not " + str(abs(lit)) if lit < 0 else str(abs(lit))
        head_part = "; ".join(map(str, self.head))
        body_part = ", ".join(show(lit) for lit in self.body)
        return head_part + ":- " + body_part

class Program(object):
    """A base class for programs. It has no weights and no queries; it is only used to obtain a CNF representation of the program.

    Allowed features are: 

    * facts
    * normal rules
    * unconditional choice rules with a single head atom
    * constraints

    The cnf is generated by first doing Tp-Unfolding and then applying a Clark completion to the resulting tight program. 

    Args:
        program_str (:obj:`string`): A string containing a part of the program in ASP syntax. 
            May be the empty string.
        program_files (:obj:`list`): A list of strings that are paths to files which contain programs in ASP syntax 
            that should be included. May be an empty list.
        clingo_control (:obj:`Control`): A clingo control object that contains parts of the program. 
            Must already be ground if no non-empty `program_str` or `program_files` are given.
        smodels (:obj:`bool`): If True, the program is read in smodels (lparse) format instead of being 
            grounded with clingo. In that case either a program string or a single file must be given.
    """
    def __init__(self, clingo_control = None, program_str = "", program_files = None, smodels = False):
        """Initialize the program representation and fill it from the given input.

        Args:
            clingo_control (:obj:`Control`, optional): A clingo control object that contains parts
                of the program. Ignored for smodels input. If None, a fresh `Control()` is created
                for this instance.
            program_str (:obj:`str`): Program text. May be the empty string.
            program_files (:obj:`list`, optional): Paths of program files. None means no files.
            smodels (:obj:`bool`): Read the input in smodels format instead of grounding it with clingo.

        Raises:
            UnsupportedException: If `smodels` is set and not exactly one of a non-empty program
                string or a single file is given, or if the program uses unsupported features.
        """
        if program_files is None:
            program_files = []
        # the variable counter
        self._max = 0
        self._nameMap = {}
        # store the clauses here
        self._cnf = CNF()
        # remember which variables are guesses, which are derived, and which are copies of derived atoms.
        self._guess = set()
        self._deriv = set()
        self._copies = {}
        # remember possible auxilliary variables that clingo introduced
        self._auxilliary = set()
        # the list containing all the rules (except guesses)
        self._program = []
        # remember which atoms have to satisfy an exactly one of constraint
        self._exactlyOneOf = set()
        # the tree decomposition of the program
        self._td = None
        if not smodels:
            # A `Control()` default in the signature would be created once, at class definition
            # time, and shared by all instances, so groundings of different programs would
            # accumulate in it. Create a fresh one per instance instead.
            if clingo_control is None:
                clingo_control = Control()
            if (len(program_str) > 0 or len(program_files) > 0):
                grounder.ground(clingo_control, program_str = program_str, program_files = program_files)
            self._normalize(clingo_control)
        else:
            # smodels input comes from exactly one source: a program string or a single file
            nr_inputs = (1 if len(program_str) > 0 else 0) + len(program_files)
            if nr_inputs != 1:
                logger.debug(f"smodels input: files {program_files}, program string of length {len(program_str)}")
                raise UnsupportedException("When instantiating a program from smodels format only one file or a program string may be given.")
            self._parse_smodels(program_str = program_str, program_files = program_files)

    def _parse_smodels(self, program_str = "", program_files = None):
        """Populate the program from a ground program in smodels (lparse) format.

        The input is taken from `program_str` if it is non-empty and otherwise read
        from the first path in `program_files`. It consists of the rule section, the
        symbol table, the ``B+`` and ``B-`` compute sections and an optional ``E``
        section, each terminated by a line ``0``.

        Supported rule types are normal rules (type 1, ``1 head #lits #neg neg... pos...``)
        and choice rules (type 3, ``3 #heads heads... #lits #neg neg... pos...``) with at
        most one head atom and an empty body. Atoms listed under ``B+`` (``B-``) become
        constraints requiring them to be true (false). Atoms listed under ``E`` become
        guesses. Afterwards every atom up to the largest number seen that is not a guess
        is derived; derived atoms without a symbol-table entry are named
        ``projected_away(<atom>)``.

        Args:
            program_str (:obj:`str`): Program text in smodels format. May be empty.
            program_files (:obj:`list`, optional): Paths of smodels files. Only the first one
                is read, and only if `program_str` is empty.

        Raises:
            UnsupportedException: For unsupported rule types, choice rules with several head
                atoms, or choice rules with a non-empty body.
        """
        if program_files is None:
            program_files = []
        if len(program_str) > 0:
            lines = program_str.split('\n')
        else:
            with open(program_files[0], 'r') as in_file:
                # rstrip('\n') rather than line[:-1]: a last line without a trailing
                # newline must keep its final character
                lines = [line.rstrip('\n') for line in in_file]

        def read_atoms(pos):
            """Read one atom per line from lines[pos] up to the terminating '0'.

            Returns the atoms and the index of the line after the terminator.
            """
            atoms = []
            while pos < len(lines) and lines[pos] != '0':
                atoms.append(int(lines[pos]))
                pos += 1
            return atoms, pos + 1

        # rule section, terminated by a line "0"
        pos = 0
        while pos < len(lines):
            fields = [int(v) for v in lines[pos].split()]
            pos += 1
            rule_type = fields[0]
            if rule_type == 0:
                break
            elif rule_type == 1:
                head = fields[1]
                nr_lit = fields[2]
                nr_neg = fields[3]
                body = fields[4:4 + nr_lit]
                self._max = max(head, self._max, max(body, default=0))
                # the first nr_neg body literals are the negative ones
                for j in range(nr_neg):
                    body[j] = -body[j]
                self._program.append(Rule(head = [head], body = body))
                self._deriv.add(head)
            elif rule_type == 3:
                nr_heads = fields[1]
                if nr_heads > 1:
                    raise UnsupportedException("Currently no choice rules with more than one atom in the head are supported.")
                # the number of body literals directly follows the head atoms
                elif nr_heads > 0 and fields[nr_heads + 2] > 0:
                    raise UnsupportedException("Currently conditional choice rules are not supported.")
                if nr_heads == 1:
                    self._guess.add(fields[2])
                    self._max = max(self._max, fields[2])
            else:
                raise UnsupportedException(f"Unsupported rule type {fields[0]}.")

        assert(len(self._deriv.intersection(self._guess)) == 0)

        # symbol table: lines "<atom> <name>", terminated by "0"; names may contain spaces
        while pos < len(lines):
            number, _, name = lines[pos].partition(' ')
            pos += 1
            if number == '0':
                break
            atom = int(number)
            self._nameMap[atom] = name
            self._max = max(self._max, atom)

        # compute statement: atoms under B+ must be true, atoms under B- must be false
        assert(lines[pos] == "B+")
        true_atoms, pos = read_atoms(pos + 1)
        for atom in true_atoms:
            self._program.append(Rule(head = [], body = [ -atom ]))
        assert(lines[pos] == "B-")
        false_atoms, pos = read_atoms(pos + 1)
        for atom in false_atoms:
            self._program.append(Rule(head = [], body = [ atom ]))
        # optional E section: its atoms are added as guesses
        if pos < len(lines) and lines[pos] == "E":
            guesses, pos = read_atoms(pos + 1)
            self._guess.update(guesses)

        self._deriv = set(range(1, self._max + 1))
        self._deriv.difference_update(self._guess)
        for v in self._deriv:
            if v not in self._nameMap:
                self._nameMap[v] = f"projected_away({v})"

    def _remove_tautologies(self, clingo_control):
        """Return the ground rules of `clingo_control` without tautologies.

        Only `ClingoRule` objects are kept. A rule is dropped if one of its head
        atoms also occurs positively in its body.
        """
        kept = []
        for obj in clingo_control.ground_program.objects:
            if not isinstance(obj, ClingoRule):
                continue
            if set(obj.head) & set(obj.body):
                continue
            kept.append(obj)
        return kept

    def _normalize(self, clingo_control):
        """Translate the ground clingo program into the internal representation.

        Every atom occurring in the tautology-free ground program is mapped to a
        fresh variable obtained from `_new_var`. Atoms without a symbolic name
        (auxiliary atoms introduced by clingo) are named ``projected_away(<literal>)``
        and recorded in `self._auxilliary`. An empty rule is replaced by
        ``unsat :- not unsat.``.

        Afterwards:

        * the head atom of an (unconditional) choice rule becomes a guess,
        * all head atoms of a non-choice rule with several heads become guesses that
          must satisfy an exactly-one-of constraint (`self._exactlyOneOf`),
        * all other rules are translated into `Rule` objects over the new variables
          and stored as a set in `self._program`.

        Every variable that is not a guess is derived.

        Raises:
            UnsupportedException: For choice rules with several head atoms or with a
                non-empty body.
        """
        program = self._remove_tautologies(clingo_control)
        _atomToVertex = {} # the tree decomposition solver wants succinct numbering of vertices / no holes

        symbol_map = {}
        for sym in clingo_control.symbolic_atoms:
            symbol_map[sym.literal] = str(sym.symbol)
        # clingo-level atom standing in for "unsat" in empty rules; created on first use
        falsum_atom = None
        for o in program:
            if isinstance(o, ClingoRule):
                if len(o.head) > 1 and o.choice:
                    raise UnsupportedException("Currently no choice rules with more than one atom in the head are supported.")
                if len(o.body) > 0 and o.choice:
                    raise UnsupportedException("Currently conditional choice rules are not supported.")
                o.atoms = set(o.head)
                o.atoms.update(tuple(map(abs, o.body)))
                # if we have the falsum rule we replace it with "unsat :- not unsat."
                if not o.atoms:
                    if falsum_atom is None:
                        # The id must not clash with any atom of the ground program. Auxiliary
                        # atoms are not in symbol_map, so the rules are scanned as well (this
                        # also avoids max() of an empty sequence). All empty rules share this
                        # atom and therefore a single "unsat" variable.
                        used = set(symbol_map)
                        for r in program:
                            used.update(r.head)
                            used.update(abs(x) for x in r.body)
                        falsum_atom = max(used, default=0) + 1
                        _atomToVertex[falsum_atom] = self._new_var("unsat")
                    o.atoms.add(falsum_atom)
                    o.head = [falsum_atom]
                    o.body = [-falsum_atom]
                    o.choice = False
                self._program.append(o)
                for a in o.atoms.difference(_atomToVertex):	# add mapping for atom not yet mapped
                    if a in symbol_map:
                        _atomToVertex[a] = self._new_var(symbol_map[a])
                    else:
                        aux_var = self._new_var(f"projected_away({a})")
                        _atomToVertex[a] = aux_var
                        self._auxilliary.add(aux_var)

        trans_prog = set()
        self._deriv = set(range(1,self._max + 1))
        for r in self._program:
            if r.choice: 
                guess = [ _atomToVertex[r.head[0]] ]
                self._deriv.difference_update(guess)
                self._guess.update(guess)
            elif len(r.head) > 1:
                # NOTE(review): the body of such a rule is not used here - confirm that only
                # body-free multi-head rules are expected at this point.
                guess = [ _atomToVertex[x] for x in r.head ]
                self._deriv.difference_update(guess)
                self._guess.update(guess)
                self._exactlyOneOf.add(frozenset(guess))
            else:
                head = list(map(lambda x: _atomToVertex[x], r.head))
                body = list(map(lambda x: _atomToVertex[abs(x)]*(1 if x > 0 else -1), r.body))
                trans_prog.add(Rule(head,body))
        self._program = trans_prog

    def _new_var(self, name):
        """Allocate the next variable number and register its name.

        An empty `name` is replaced by the variable's number.

        Returns:
            int: The newly allocated variable.
        """
        self._max = var = self._max + 1
        self._nameMap[var] = str(var) if name == "" else name
        return var

    def _copy_var(self, var):
        """Create a fresh copy of `var` and return it.

        The copy is named ``<pred>_copy_<k><args>``. Here ``<pred>`` is the predicate
        part of the original name (before the first ``(``, and before ``_copy_`` if
        `var` is itself a copy), ``<args>`` is the part from the first ``(`` on, and
        ``k`` is the number of variables already recorded for ``<pred><args>`` in
        `self._copies`. The first variable seen for a family counts as its first
        member.

        Returns:
            int: The new variable.
        """
        full = self._nameMap[var]
        paren = full.find("(")
        inputs = full[paren:] if paren >= 0 else ""
        head_part = full[:paren] if paren >= 0 else full
        # only look for the copy marker in the predicate, not inside the arguments
        marker = head_part.find("_copy_")
        pred = head_part[:marker] if marker >= 0 else head_part
        key = pred + inputs
        # setdefault also covers atoms whose own name merely contains "_copy_",
        # which previously raised a KeyError
        family = self._copies.setdefault(key, [var])
        nv = self._new_var(pred + "_copy_" + str(len(family)) + inputs)
        family.append(nv)
        return nv

    def _internal_name(self, var):
        """Return the name stored for `var`, without any conversion."""
        name = self._nameMap[var]
        return name
    
    def _external_name(self, var):
        """Return the name of `var` with internal name parts replaced by external ones.

        Every (internal, external) pair of `conversions` is substituted in turn, so that
        programs printed with these names can be parsed again.
        """
        result = self._nameMap[var]
        for pair in conversions.items():
            result = result.replace(*pair)
        return result

    def _computeComponents(self):
        """Compute the positive dependency graph and its strongly connected components.

        Sets

        * `self.dep`: a directed graph over the variables 1..`self._max` with an edge
          b -> a for every rule with head atom a and positive body literal b,
        * `self._components`: the list of strongly connected components of `self.dep`,
        * `self._condensation`: the condensation of `self.dep` w.r.t. these components.
          Each of its nodes carries the component in its ``"members"`` attribute, and
          ``graph["mapping"]`` maps atoms to condensation nodes.
        """
        self.dep = nx.DiGraph()
        self.dep.add_nodes_from(range(1, self._max + 1))
        self.dep.add_edges_from(
            (b, a) for r in self._program for a in r.head for b in r.body if b > 0
        )
        self._components = list(nx.algorithms.strongly_connected_components(self.dep))
        self._condensation = nx.algorithms.condensation(self.dep, self._components)

    def treeprocess(self):
        """Applies tree processing to the program. 
        
        This means that if there is a part of the dependency graph that is a tree and
        only has one connection to the rest of the dependency graph, then it will be processed.

        Results in one copy for each atom in the tree.

        Expects `self.dep` and `self._condensation` to be up to date, i.e. `_computeComponents`
        must have been called before. Afterwards `self._program` is a list that contains the
        constraints followed by all other rules, grouped by their head atom.

        Returns:
            None        
        """
        # ins[a]: rules with a in the head; outs[a]: rules with a as a positive body literal
        ins = {}
        outs = {}
        for a in self._deriv:
            ins[a] = set()
            outs[a] = set()

        for a in self._guess:
            ins[a] = set()
            outs[a] = set()

        for r in self._program:
            for a in r.head:
                ins[a].add(r)
            for b in r.body:
                if b > 0:
                    outs[b].add(r)
        ts = nx.topological_sort(self._condensation)
        # ancs[v] / decs[v]: predecessors / successors of v in the dependency graph that lie in v's own SCC
        ancs = {}
        decs = {}
        for t in ts:
            comp = self._condensation.nodes[t]["members"]
            for v in comp:
                ancs[v] = set([vp[0] for vp in self.dep.in_edges(nbunch=v) if vp[0] in comp])
                decs[v] = set([vp[1] for vp in self.dep.out_edges(nbunch=v) if vp[1] in comp])
        # start with the atoms whose only neighbour inside their SCC is one atom that is both predecessor and successor
        q = set([v for v in ancs.keys() if len(ancs[v]) == 1 and len(decs[v]) == 1 and list(ancs[v])[0] == list(decs[v])[0]])
        while not len(q) == 0:
            old_v = q.pop()
            # skip atoms whose last neighbour has already been detached
            if len(ancs[old_v]) == 0:
                continue
            # new_v takes over the derivations of old_v that do not use its neighbour anc
            new_v = self._copy_var(old_v)
            self._deriv.add(new_v)
            ins[new_v] = set()
            outs[new_v] = set()
            anc = ancs[old_v].pop()
            # detach old_v from anc; anc may now satisfy the queue condition itself
            ancs[anc].remove(old_v)
            decs[anc].remove(old_v)
            if len(ancs[anc]) == 1 and len(decs[anc]) == 1 and list(ancs[anc])[0] == list(decs[anc])[0]:
                q.add(anc)

            # this contains all rules that do not use anc to derive v
            to_rem = ins[old_v].difference(outs[anc])
            # this contains all rules that use anc to derive v
            # we just keep them as they are
            ins[old_v] = ins[old_v].intersection(outs[anc])
            # any rule that does not use anc to derive v can now only derive new_v
            for r in to_rem:
                head = [b if b != old_v else new_v for b in r.head]
                new_r = Rule(head,r.body)
                ins[new_v].add(new_r)
                for b in r.body:
                    if b > 0:
                        outs[b].remove(r)
                        outs[b].add(new_r)

            # this contains all rules that use v and derive anc
            to_rem = outs[old_v].intersection(ins[anc])
            # this contains all rules that use v and do not derive anc
            # we just keep them as they are
            outs[old_v] = outs[old_v].difference(ins[anc])
            # any rule that uses v to derive anc must use new_v
            for r in to_rem:
                body = [ (b if b != old_v else new_v) for b in r.body]
                new_r = Rule(r.head,body)
                for b in r.head:
                    ins[b].remove(r)
                    ins[b].add(new_r)
                for b in r.body:
                    if b > 0:
                        if b != old_v:
                            outs[abs(b)].remove(r)
                            outs[abs(b)].add(new_r)
                        else:
                            outs[new_v].add(new_r)
            # old_v can still be derived from its copy: old_v :- new_v.
            new_r = Rule([old_v], [new_v])
            ins[old_v].add(new_r)
            outs[new_v].add(new_r)
        # only keep the constraints
        self._program = [r for r in self._program if len(r.head) == 0]
        # add all the other rules
        for a in ins.keys():
            self._program.extend(ins[a])


    def _write_scc(self, comp):
        """Render the component `comp` as ASP facts.

        For every atom v of the component this emits ``p(v).``, followed by
        ``edge(u,v).`` for every predecessor u of v in `self.dep` that lies in the
        component.
        """
        parts = []
        for v in comp:
            parts.append(f"p({v}).\n")
            preds = set(u for u, _ in self.dep.in_edges(nbunch=v) if u in comp)
            parts.extend(f"edge({u},{v}).\n" for u in preds)
        return "".join(parts)

    def _compute_backdoor_clingo(self, idx, timeout = 30.0):
        """Compute a minimum backdoor of condensation node `idx` with clingo.

        The encoding guesses atoms ``abs/1`` of the component, minimizing their number,
        such that every atom of the component is ``ok``. An atom is ok if it is in the
        backdoor, or if all but at most one of its neighbours inside the component
        (predecessors or successors in `self.dep`) are ok.

        Args:
            idx: Node of `self._condensation` whose members form the component.
            timeout (float): Time limit passed on to the backdoor solver.

        Returns:
            The backdoor as returned by `backdoor.ClingoControl.get_backdoor`.
        """
        comp = self._condensation.nodes[idx]["members"]
        lines = [f"{{abs({v})}}." for v in comp]
        lines.append(":~ abs(X). [1,X]")
        lines.append("ok(X) :- abs(X).")
        for v in comp:
            # neighbours of v inside the component, regardless of edge direction
            # (out_edges yields (v, successor), so the successor is the second entry)
            neighbours = list(
                set(u for u, _ in self.dep.in_edges(nbunch=v) if u in comp)
                | set(w for _, w in self.dep.out_edges(nbunch=v) if w in comp)
            )
            for i in range(len(neighbours)):
                others = neighbours[:i] + neighbours[i + 1:]
                if others:
                    lines.append(f"ok({v}) :- {','.join(f'ok({u})' for u in others)}.")
                else:
                    # a single neighbour: v is ok unconditionally
                    lines.append(f"ok({v}).")
        lines.extend(f":- not ok({v})." for v in comp)
        lines.append("#show abs/1.")
        program_str = "\n".join(lines)
        c = backdoor.ClingoControl(program_str)
        res = c.get_backdoor(None, timeout = timeout)[2][0]
        return res

    def _compute_backdoor_fvs(self, idx, timeout = 30.0, approximate = False):
        """Compute a feedback vertex set of condensation node `idx` with the external FVS solver.

        The component is handed to ``fvs/src/build/FeedbackVertexSet`` (below `src_path`)
        on stdin as an undirected edge list, one ``u v`` pair with u <= v per line. The
        first output token is skipped (presumably the size of the set - TODO confirm);
        the remaining tokens are the vertices of the set.

        Args:
            idx: Node of `self._condensation` whose members form the component.
            timeout (float): Time limit in seconds for the exact solver. The approximation
                (``approximate=True``) runs without a time limit.
            approximate (bool): Run the solver in ``Appx`` mode.

        Returns:
            list: The vertices of the feedback vertex set.

        Raises:
            subprocess.TimeoutExpired: If the exact solver exceeds `timeout`. The solver
                process is killed and reaped before the exception propagates.
        """
        comp = self._condensation.nodes[idx]["members"]
        # undirected edges inside the component, keyed by the smaller endpoint
        edges = {}
        for v in comp:
            preds = set(u for u, _ in self.dep.in_edges(nbunch=v) if u in comp)
            for u in preds:
                low, high = (v, u) if u > v else (u, v)
                edges.setdefault(low, set()).add(high)
        graph = "".join(f"{low} {high}\n" for low, highs in edges.items() for high in highs)
        solver = os.path.join(src_path, "fvs/src/build/FeedbackVertexSet")
        if approximate:
            args, limit = [solver, "Appx"], None
        else:
            args, limit = [solver], timeout
        # subprocess.run kills the solver when the timeout expires before re-raising
        # TimeoutExpired; Popen.communicate would have left it running.
        result = subprocess.run(args, input=graph.encode(), stdout=subprocess.PIPE, timeout=limit)
        res = [ int(v) for v in result.stdout.decode().split()[1:] ]
        return res

    def _compute_backdoor(self, idx):
        """Compute a backdoor for the component at condensation node `idx`.

        The method is chosen by ``config.config["backdoors"]`` (``"fvs"`` or ``"clingo"``)
        and limited by ``config.config["backdoort"]`` seconds. If the optimal computation
        times out or raises an IndexError, the approximate FVS solver is used instead.

        Returns:
            list: The atoms of the backdoor.

        Raises:
            ValueError: If ``config.config["backdoors"]`` names an unknown method.
        """
        start = time.time()
        comp = self._condensation.nodes[idx]["members"]
        method = config.config["backdoors"]
        timeout = float(config.config["backdoort"])
        if method not in ("fvs", "clingo"):
            # previously this surfaced later as an UnboundLocalError on `res`
            raise ValueError(f"Unknown backdoor computation method {method!r}.")
        try:
            if method == "fvs":
                res = self._compute_backdoor_fvs(idx, timeout = timeout, approximate = False)
            else:
                res = self._compute_backdoor_clingo(idx, timeout = timeout)
        except (subprocess.TimeoutExpired, IndexError):
            logger.warning("Optimal backdoor computation failed, switching to approximation.")
            res = self._compute_backdoor_fvs(idx, timeout = timeout, approximate = True)
        logger.debug("backdoor comp: " + str(len(comp)))
        logger.debug("backdoor res: " + str(len(res)))
        logger.debug(f"backdoor time: {time.time() - start}")
        return res

    def _backdoor_process(self, comp, backdoor):
        """Break the cycles of one SCC by unfolding it into levels along a backdoor.

        All rules with a head atom in `comp` are replaced by rules over leveled copies
        of the component's atoms. Levels range over 0..len(backdoor). The top level
        len(backdoor) is the original atom itself; lower levels are fresh derived
        variables created with `_copy_var` on demand.

        * A backdoor atom at level i >= 1 is derived from its rules with the body atoms
          taken at level i - 1, and from its own copy at level i - 1 (for i > 1). At
          level 1, only rules without positive backdoor atoms in the body are used.
        * A non-backdoor atom at level i is derived from its rules with the body atoms
          taken at the same level i, and from its own copy at level i - 1 (for i >= 1).
          At level 0, only rules without positive backdoor atoms in the body are used.

        Above the lowest level, only rules with at least one positive body atom from
        `comp` are copied. Body atoms outside `comp` and negative literals are never
        renamed. Presumably this relies on `backdoor` breaking every cycle of `comp`,
        so that level 0 is acyclic - TODO confirm.

        Args:
            comp: The atoms of the strongly connected component.
            backdoor: The backdoor atoms; expected to be a subset of `comp`
                (`ins[a]` is only set up for atoms of `comp`).

        Returns:
            None. `self._program` is updated in place.
        """
        comp = set(comp)
        backdoor = set(backdoor)

        # ins[a]: rules deriving a; all of them are replaced by the unfolded rules
        toRemove = set()
        ins = {}
        for a in comp:
            ins[a] = set()
        for r in self._program:
            for a in r.head:
                if a in comp:
                    ins[a].add(r)
                    toRemove.add(r)

        # copies[a][i]: the copy of a at level i; the top level is the atom itself
        copies = {}
        for a in comp:
            copies[a] = {}
            copies[a][len(backdoor)] = a

        def getAtom(atom, i):
            # negated atoms are kept as they are
            if atom < 0:
                return atom
            # atoms that are not from this component are input atoms and should stay the same
            if atom not in comp:
                return atom
            # sanity checks: these abort the whole process
            if i < 0:
                print("this should not happen")
                exit(-1)
            if atom not in copies:
                print("this should not happen")
                exit(-1)
            if i not in copies[atom]:
                copies[atom][i] = self._copy_var(atom)
                self._deriv.add(copies[atom][i])
            return copies[atom][i]

        toAdd = set()
        # backdoor atoms: levels 1..len(backdoor), bodies taken one level below
        for a in backdoor:
            for i in range(1,len(backdoor)+1):
                head = [getAtom(a, i)]
                for r in ins[a]:
                    if i == 1:
                        # in the first iteration we do not add rules that use atoms from the backdoor
                        add = True
                        for x in r.body:
                            if x > 0 and x in backdoor:
                                add = False
                    else:
                        # in all but the first iteration we only use rules that use at least one atom from the SCC we are in
                        add = False
                        for x in r.body:
                            if x > 0 and x in comp:
                                add = True
                    if add:
                        body = [getAtom(x, i - 1) for x in r.body]
                        new_rule = Rule(head, body)
                        toAdd.add(new_rule)
                if i > 1:
                    # whatever holds at level i - 1 also holds at level i
                    toAdd.add(Rule(head, [getAtom(a, i - 1)]))

        # non-backdoor atoms: levels 0..len(backdoor), bodies taken at the same level
        for a in comp.difference(backdoor):
            for i in range(len(backdoor)+1):
                head = [getAtom(a, i)]
                for r in ins[a]:
                    if i == 0:
                        # in the first iteration we only add rules whose positive body contains no backdoor atom
                        add = True
                        for x in r.body:
                            if x > 0 and x in backdoor:
                                add = False
                    else:
                        # in all other iterations we only use rules that use at least one atom from the SCC we are in
                        add = False
                        for x in r.body:
                            if x >  0 and x in comp:
                                add = True
                    if add:
                        body = [getAtom(x, i) for x in r.body]
                        new_rule = Rule(head, body)
                        toAdd.add(new_rule)
                if i > 0:
                    # whatever holds at level i - 1 also holds at level i
                    toAdd.add(Rule(head, [getAtom(a, i - 1)]))

        self._program = [r for r in self._program if r not in toRemove]
        self._program += list(toAdd)
        
        
    def tpUnfold(self):
        """Applies Tp-Unfolding to the program.

        A variant is applied: the program is first tree processed, then every
        non-trivial strongly connected component of the dependency graph is unfolded
        along a backdoor, and finally tree processing is applied once more.

        If the dependency graph still contains a non-trivial strongly connected
        component afterwards, an error is logged and `exit(-1)` is called.

        Returns:
            None
        """
        self._computeComponents()
        self.treeprocess()
        self._computeComponents()
        for node in nx.topological_sort(self._condensation):
            members = self._condensation.nodes[node]["members"]
            if len(members) <= 1:
                continue
            chosen = self._compute_backdoor(node)
            # Tree processing costs another factor of 2, so a backdoor with more than
            # half of the atoms is replaced by the whole component.
            if len(chosen) > len(members) / 2:
                chosen = members
            self._backdoor_process(members, chosen)
        self._computeComponents()
        self.treeprocess()
        self._computeComponents()
        still_cyclic = any(
            len(self._condensation.nodes[node]["members"]) > 1
            for node in nx.topological_sort(self._condensation)
        )
        if still_cyclic:
            logger.error("Cycle breaking failed: the dependency graph still has a non-trivial SCC")
            exit(-1)

    def binary_cycle_breaking(self, local = False):
        """Break positive cycles by ordering the atoms with binary counters.

        Each atom a gets ceil(log2(n)) guessed bits (``bin_counter`` atoms). Here n is the
        size of a's strongly connected component or, if `local`, the number of atoms of
        that component in the current bag of the tree decomposition. Every rule with a
        head h gets, for each positive body atom b, an additional body atom
        ``less_than(b, h)`` that holds iff b is smaller than h:

        * for atoms in different components it is fixed by reachability in the
          condensation,
        * within a component it is defined by comparing the bits.

        In addition, the constraint ``:- not h, body`` (with the original body) is added
        for each such rule.

        If `local` is True, the tree decomposition is computed if missing, using
        ``config.config["decos"]`` and ``config.config["decot"]``. Bits and comparisons
        are then introduced per bag. Each rule is handled in the first bag (in
        `bag_iter` order) that contains all its atoms, and constraints force the
        comparisons of a bag and its children to agree.

        Args:
            local (bool): Whether to use bag-local counters.

        Returns:
            None. `self._program` is replaced by the new rules.
        """
        if self._td is None and local:
            self._decomposeGraph(solver = config.config["decos"], timeout = config.config["decot"])
        self._computeComponents()
        # here we remember the new rules
        new_rules = []
        # maps a node t to a set of rules that need to be considered in t
        # it actually suffices if every rule is considered only once in the entire td..
        rules = {}
        # a dictionary that remembers the variables that we use for the lessThan predicate between atoms
        lessThan = {}
        # remember which atoms we used for the bits 
        bits = {}
        if local:
            # bits[t][a]: the counter bits of atom a in bag t
            for t in self._td.bag_iter():
                bits[t] = {}
                tmp_vert = t.vertices.copy()
                while not len(tmp_vert) == 0:
                    cur = tmp_vert.pop()
                    if cur in self._condensation.graph["mapping"]:
                        # otherwise cur does not occur positively
                        comp_id = self._condensation.graph["mapping"][cur]
                        comp = self._condensation.nodes[comp_id]["members"]
                        tmp_vert.difference_update(comp)
                        # atoms of cur's component in this bag share one counter width
                        both = t.vertices.intersection(comp)
                        count = math.ceil(math.log(len(both),2))
                        for a in both:
                            bits[t][a] = [ self._new_var(f"bin_counter({self._external_name(a)},{t.idx},{i})") for i in range(count) ]
                            self._guess.update(bits[t][a])
        else:
            # bits[None][a]: the global counter bits of atom a
            bits[None] = {}
            for comp in self._components:
                count = math.ceil(math.log(len(comp),2))
                for a in comp:
                    bits[None][a] = [ self._new_var(f"bin_counter({self._external_name(a)},none,{i})") for i in range(count)]
                    self._guess.update(bits[None][a])

        if local:
            # temporary copy of the program, will be empty after the first pass
            program = list(self._program)
            # first td pass: determine rules 
            for t in self._td.bag_iter():
                # take the rules we need and remove them
                rules[t] = [r for r in program if set(r.head + [ abs(x) for x in r.body ]).issubset(t.vertices)]
                program = [r for r in program if not set(r.head + [ abs(x) for x in r.body ]).issubset(t.vertices)]
        else: 
            rules[None] = list(self._program)

        # a subroutine to generate x < x'
        def generateLessThan(x, xp, local = False, node = None):
            assert(node is not None or not local)
            # setup and check if this has already been handled
            if not (x,xp,node) in lessThan:
                lessThan[(x,xp,node)] = self._new_var(f"less_than({self._external_name(x)},{node.idx if node is not None else 'none'},{self._external_name(xp)})")
                self._deriv.add(lessThan[(x,xp,node)])
            else:
                return lessThan[(x,xp,node)]

            # check if x and xp are in differens components
            xs_comp = self._condensation.graph["mapping"][x]
            xps_comp = self._condensation.graph["mapping"][xp]
            if xs_comp != xps_comp:
                # determine which is in the higher component
                # x's component reaches xp's: x < xp is a fact; the reverse: x < xp is forbidden
                if nx.algorithms.shortest_paths.generic.has_path(self._condensation, xs_comp, xps_comp):
                    new_rules.append(Rule([lessThan[(x,xp,node)]],[]))
                elif nx.algorithms.shortest_paths.generic.has_path(self._condensation, xps_comp, xs_comp):
                    new_rules.append(Rule([],[lessThan[(x,xp,node)]]))
                else: # there is no connection between these at all. should not occur.
                    logger.error("No connection between nodes that need to be connected!")
                    exit(1)
                return lessThan[(x,xp,node)]

            # x and xp are in the same component 
            # obtain the bits and their number
            count = len(bits[node][x])
            x_bits = bits[node][x]
            xp_bits = bits[node][xp]

            # remember all the disjuncts here
            # disjunct i: xp has bit i set, x does not, and no higher bit j > i has x set but xp unset
            head = [ lessThan[(x,xp,node)] ]
            for i in range(count):
                body = [ xp_bits[i], -x_bits[i] ]
                for j in range(i + 1, count):
                    # andVar is derived exactly when x has bit j set and xp does not
                    andVar = self._new_var(f"less_than_bin({self._external_name(x)},{node.idx if node is not None else 'none'},{self._external_name(xp)},{i},{j})")
                    self._deriv.add(andVar)
                    body.append(-andVar)
                    new_rules.append(Rule([andVar], [-xp_bits[j], x_bits[j]]))
                new_rules.append(Rule(head, body))
            return lessThan[(x,xp,node)]

        if local:
            # second td pass: use rules to generate the reduction
            for t in self._td.bag_iter():
                # generate (2), i.e. the constraints that maintain the inequalities between nodes
                # the comparison of two shared atoms must agree between bag t and its child tp
                for tp in t.children:
                    relevant = tp.vertices.intersection(t.vertices)
                    rel_cp = relevant.copy()
                    while len(rel_cp) > 0:
                        cur = rel_cp.pop()
                        comp_id = self._condensation.graph["mapping"][cur]
                        comp = self._condensation.nodes[comp_id]["members"]
                        both = relevant.intersection(comp)
                        for x in both:
                            if x == cur:
                                continue
                            t_atom = generateLessThan(x, cur, local = local, node = t)
                            tp_atom = generateLessThan(x, cur, local = local, node = tp)
                            new_rules.append(Rule([],[t_atom, -tp_atom]))
                            new_rules.append(Rule([],[-t_atom, tp_atom]))
                                
                
                # add the order constraints to the rules in the current node
                for r in rules[t]:
                    if len(r.head) > 0:
                        to_add = []
                        head = r.head[0]
                        for a in r.body:
                            if a > 0:
                                to_add.append(generateLessThan(a, head, local = local, node = t))
                        # if the original body holds, the head must hold as well
                        new_rules.append(Rule([], [-head] + r.body))
                        # NOTE: extends the body of the existing Rule object in place
                        r.body += to_add
                    new_rules.append(r)
        else:
            # add the order constraints to the rules in the current node
            for r in rules[None]:
                if len(r.head) > 0:
                    to_add = []
                    head = r.head[0]
                    for a in r.body:
                        if a > 0:
                            to_add.append(generateLessThan(a, head, local = local, node = None))
                    # if the original body holds, the head must hold as well
                    new_rules.append(Rule([], [-head] + r.body))
                    # NOTE: extends the body of the existing Rule object in place
                    r.body += to_add
                new_rules.append(r)
        self._program = new_rules

    
     

    def less_than_cycle_breaking(self, opt = True):
        """Breaks the positive cycles of the program using a guessed strict order.

        For every pair of atoms `x`, `xp` that has to be compared, a new guessed
        atom `less_than(x,xp)` is introduced (one variable per unordered pair, the
        reverse comparison is its negation). Every rule with a head is extended by
        the order atoms `b < head` for each positive body atom `b`. In addition,
        the constraint `:- not head, body` (over the original body) is kept, so the
        head must still be true whenever the original body holds.

        If two atoms lie in different components of `self._condensation`, their
        order is fixed by constraints according to which component reaches the other.
        Within each component with more than one atom, transitivity of the order is
        enforced by constraints `:- x<y, y<z, not x<z`. Irreflexivity (`x < x` is
        false) is enforced explicitly, so that a rule containing its head in its
        positive body cannot support itself.

        Args:
            opt (:obj:`bool`, optional): If `True`, transitivity constraints are only
                generated for `x` in the set returned by `self._compute_backdoor` and
                for `z` among the successors of `y` in `self.dep`. If `False`, they are
                generated for all triples of atoms of the component. Defaults to `True`.

        Returns:
            None. `self._program` is replaced by the transformed rules.
        """
        import sys

        # NOTE(review): presumably sets up self._condensation and self.dep used below — confirm
        self._computeComponents()
        # here we remember the new rules
        new_rules = []
        # a dictionary that remembers the variables that we use for the lessThan predicate between atoms
        lessThan = {}

        # a subroutine to generate x < x'
        def getLessThan(x, xp):
            # only one variable per unordered pair is used; xp < x is its negation
            negated = x > xp
            if negated:
                x,xp = xp,x
            # check if this has already been handled
            if (x,xp) in lessThan:
                return -lessThan[(x,xp)] if negated else lessThan[(x,xp)]
            # setup: a new atom that is added as a guess
            lessThan[(x,xp)] = self._new_var(f"less_than({x},{xp})")
            self._guess.add(lessThan[(x,xp)])

            if x == xp:
                # the order is strict, so x < x must be false (:- x<x); otherwise a rule
                # with its head in its positive body (a :- a, ...) could support itself
                new_rules.append(Rule([], [lessThan[(x,xp)]]))
                return lessThan[(x,xp)]

            # check if x and xp are in different components
            xs_comp = self._condensation.graph["mapping"][x]
            xps_comp = self._condensation.graph["mapping"][xp]
            if xs_comp != xps_comp:
                # determine which is in the higher component
                if nx.algorithms.shortest_paths.generic.has_path(self._condensation, xs_comp, xps_comp):
                    # :- not x<xp, i.e. x < xp holds
                    new_rules.append(Rule([], [-lessThan[(x,xp)]]))
                elif nx.algorithms.shortest_paths.generic.has_path(self._condensation, xps_comp, xs_comp):
                    # :- x<xp, i.e. x < xp does not hold
                    new_rules.append(Rule([], [lessThan[(x,xp)]]))
                else: # there is no connection between these at all. should not occur.
                    logger.error("No connection between nodes that need to be connected!")
                    # the `exit` builtin is only meant for the interactive shell
                    sys.exit(1)

            return -lessThan[(x,xp)] if negated else lessThan[(x,xp)]

        ts = nx.topological_sort(self._condensation)
        for t in ts:
            comp = self._condensation.nodes[t]["members"]
            if len(comp) > 1:
                # antisymmetry and connexity are true automatically
                # (a single variable per pair, the reverse order is its negation)
                if opt:
                    x_values = self._compute_backdoor(t)
                else:
                    x_values = comp
                # transitivity: :- x<y, y<z, not x<z
                for x in x_values:
                    for y in comp:
                        if x == y:
                            continue
                        if opt:
                            z_values = self.dep.successors(y)
                        else:
                            z_values = comp
                        for z in z_values:
                            if y != z and x != z:
                                new_rules.append(Rule([], [getLessThan(x,y), getLessThan(y,z), -getLessThan(x,z)]))

        # add lessThan atoms to break the cycles reduction
        for r in self._program:
            if len(r.head) > 0:
                to_add = []
                head = r.head[0]
                for a in r.body:
                    if a > 0:
                        to_add.append(getLessThan(a, head))
                # :- not head, body  -- the head must hold whenever the original body does
                new_rules.append(Rule([], [-head] + r.body))
                # the rule only fires if every positive body atom precedes the head
                r.body += to_add
            new_rules.append(r)
        self._program = new_rules



    def _decomposeGraph(self, solver = "flow-cutter", timeout = "-1"):
        """Computes a tree decomposition of the hypergraph of the program.

        The hypergraph has the vertices `1, ..., self._max` and one hyperedge per
        rule, consisting of its head atoms and the atoms of its body literals.
        The hypergraph is stored in `self._graph` and its decomposition in `self._td`.

        Args:
            solver (:obj:`str`, optional): Solver name passed on to
                `treedecomposition.from_hypergraph`. Defaults to `"flow-cutter"`.
            timeout (:obj:`str`, optional): Timeout passed on to
                `treedecomposition.from_hypergraph`. Defaults to `"-1"`.

        Returns:
            None
        """
        self._graph = hypergraph = Hypergraph()
        hypergraph.add_nodes_from(range(1, self._max + 1))
        for rule in self._program:
            # head atoms plus the (unsigned) atoms of the body literals
            hyperedge = set(rule.head)
            hyperedge.update(abs(lit) for lit in rule.body)
            hypergraph.add_edge(hyperedge)
        self._td = treedecomposition.from_hypergraph(hypergraph, solver = solver, timeout = timeout)

    def clark_completion(self):
        """Applies the clark completion to the program. 

        Does not check whether the program is tight! 
        Does not use tree decomposition guidance to obtain a program of possibly smaller treewidth.

        For every derived atom `a` with rules `r_1, ..., r_k`, an auxiliary variable
        `b_i` equivalent to the body of `r_i` is introduced, and `a` is made equivalent
        to `b_1 or ... or b_k` (an atom without rules therefore becomes false).
        Constraints become clauses over their negated body literals, and every
        exactly-one-of group becomes one at-least-one clause plus pairwise
        at-most-one clauses.

        Does not return anything only constructs the cnf.
        The CNF can be obtained by using `get_cnf()`.

        Returns:
            None        
        """
        cnf = CNF()
        self._cnf = cnf
        cnf.nr_vars = self._max

        def aux_var():
            # fresh auxiliary variable above the program's atoms, so that the
            # atom counter of the program itself stays unchanged
            cnf.nr_vars += 1
            cnf.auxilliary.add(cnf.nr_vars)
            return cnf.nr_vars

        # group the rules by their head atoms
        rules_of = {atom: [] for atom in self._deriv}
        for rule in self._program:
            for atom in rule.head:
                rules_of[atom].append(rule)

        for head in self._deriv:
            body_vars = []
            for rule in rules_of[head]:
                body_var = aux_var()
                body_vars.append(body_var)
                # body -> body_var
                cnf.clauses.append([body_var] + [-lit for lit in rule.body])
                # body_var -> every single body literal
                for lit in rule.body:
                    cnf.clauses.append([-body_var, lit])
            # head -> at least one of its bodies
            cnf.clauses.append([-head] + body_vars)
            # every body -> head
            for body_var in body_vars:
                cnf.clauses.append([head, -body_var])

        # constraints: at least one body literal must be false
        for rule in self._program:
            if not rule.head:
                cnf.clauses.append([-lit for lit in rule.body])

        # exactly one atom of every group is true
        for group in self._exactlyOneOf:
            cnf.clauses.append(list(group))
            for low in group:
                for high in group:
                    if low < high:
                        cnf.clauses.append([-low, -high])

        self._finalize_cnf()

    def td_guided_clark_completion(self):        
        """Applies the clark completion to the program. 

        Does not check whether the program is tight! 
        Use tree decomposition guidance on the "ors" to obtain a program of possibly smaller treewidth.

        Every rule with a head gets an auxiliary variable `r.proven` that is
        equivalent to the conjunction of its body. The disjunction over the
        `proven` variables of the rules of a derived atom is not written as a
        single clause. Instead it is built up along a tree decomposition of the
        hypergraph of the program (one hyperedge per rule, see `_decomposeGraph`).
        A new auxiliary variable is introduced whenever several partial
        disjunctions of an atom meet in a bag. Once an atom is forgotten (or
        reaches the root), it is made equivalent to its collected disjunction,
        or false if none of its rules was handled.
        
        Does not return anything only constructs the cnf.
        The CNF can be obtained by using `get_cnf()`.

        The solver to compute the tree decomposition and its timeout can be specified in 
        aspmc.config.

        Returns:
            None        
        """
        self._cnf = CNF()
        self._cnf.nr_vars = self._max
        # local method to get a new auxilliary variable
        # so that the atom counter of the program does not change
        def aux_var():
            self._cnf.nr_vars += 1
            self._cnf.auxilliary.add(self._cnf.nr_vars)
            return self._cnf.nr_vars
        
        # sets self._graph (one hyperedge per rule) and self._td
        self._decomposeGraph(solver = config.config["decos"], timeout = config.config["decot"])
        logger.info(f"Tree Decomposition #bags: {self._td.bags} unfolded treewidth: {self._td.width} #vertices: {self._td.vertices}")
        # at which td node to handle each rule
        rules = {}
        # at which td node each variable occurs last
        # (index, in bag_iter() order, of the last bag containing the variable)
        last = {}
        idx = 0
        # NOTE(review): td_idx[i] is used below to look up the i-th bag of bag_iter();
        # this assumes iterating self._td yields the bags in the same order — TODO confirm
        td_idx = list(self._td)
        for t in self._td.bag_iter():
            for a in t.vertices:
                last[a] = idx
            t.idx = idx
            idx += 1
            rules[t] = []

        for r in self._program:
            # NOTE(review): r.proven is reassigned per head atom, so this assumes
            # at most one head atom per rule (normal programs)
            for a in r.head:
                r.proven = aux_var()
                ands = [-x for x in r.body]
                # body -> r.proven
                self._cnf.clauses.append([ r.proven ] + ands)
                # r.proven -> every body literal
                for at in ands:
                    self._cnf.clauses.append([ -r.proven, -at ])
            # the rule is handled in the earliest bag (in bag_iter() order) at which
            # one of its atoms occurs for the last time
            idx = min([ last[abs(b)] for b in r.body + r.head ])
            rules[self._td.get_bag(td_idx[idx])].append(r)

        # how many rules have we used and what is the last used variable
        # unfinished[t][a]: literal equivalent to the disjunction of the rules of a
        # handled in the subtree below t so far
        unfinished = {}
        # unfinished[tp] is read for every child tp, so children must be visited before their parent
        for t in self._td.bag_iter():
            unfinished[t] = {}
            # converted to a set so the parent can use set operations on it (see below)
            t.vertices = set(t.vertices)
            to_handle = {}
            for a in t.vertices:
                to_handle[a] = []
            for tp in t.children:
                # atoms that are forgotten between the child and this bag are finished
                removed = tp.vertices.difference(t.vertices)
                for a in removed:
                    if a in self._deriv:
                        if a in unfinished[tp]:
                            # a <-> disjunction of all its rules
                            final = unfinished[tp].pop(a)
                            self._cnf.clauses.append([-a, final])
                            self._cnf.clauses.append([a, -final])
                        else: 
                            # no rule of a was handled, so a is false
                            self._cnf.clauses.append([-a])
                # atoms that are still present: pass their partial results up
                rest = tp.vertices.intersection(t.vertices)
                for a in rest:
                    if a in unfinished[tp]:
                        to_handle[a].append(unfinished[tp][a])

            # take the rules we need
            for r in rules[t]:
                for a in r.head:
                    to_handle[a].append(r.proven)

            # handle all the rules we have gathered
            for a in t.vertices:
                if len(to_handle[a]) > 1:
                    # new_last <-> disjunction of the gathered partial results
                    new_last = aux_var()
                    self._cnf.clauses.append([-new_last] + to_handle[a])
                    for at in to_handle[a]:
                        self._cnf.clauses.append([new_last, -at])
                    unfinished[t][a] = new_last
                elif len(to_handle[a]) == 1:
                    # a single partial result needs no new variable
                    unfinished[t][a] = to_handle[a][0]

        # finish the derived atoms that are still present in the root
        for a in self._td.get_root().vertices:
            if a in self._deriv:
                if a in unfinished[self._td.get_root()]:
                    final = unfinished[self._td.get_root()].pop(a)
                    self._cnf.clauses.append([-a, final])
                    self._cnf.clauses.append([a, -final])
                else: 
                    self._cnf.clauses.append([-a])

        # handle the constraints
        constraints = [r for r in self._program if len(r.head) == 0]
        for r in constraints:
            self._cnf.clauses.append([-x for x in r.body])

        # handle the guesses
        for r in self._exactlyOneOf:
            # at least one
            self._cnf.clauses.append(list(r))
            # at most one
            for v in r:
                for vp in r:
                    if v < vp:
                        self._cnf.clauses.append([-v, -vp])

        self._finalize_cnf()


    def td_guided_both_clark_completion(self, adaptive = False, latest = True): 
        """Applies the clark completion to the program. 

        Does not check whether the program is tight! 
        Use tree decomposition guidance on both the "ands" and the "ors"
        to obtain a program of possibly smaller treewidth.

        The program is viewed as a graph of typed nodes:
        derived atoms (OR over the bodies of their rules), rule bodies (AND over
        their literals), constraints (CON: one of the negated body literals must
        hold), exactly-one-of groups (GUESS) and guessed atoms (INPUT). Instead of
        defining each node with one clause over all of its inputs, the inputs are
        combined piecewise along a tree decomposition of this graph. Auxiliary
        variables are introduced for the partial results.
        
        Does not return anything only constructs the cnf.
        The CNF can be obtained by using `get_cnf()`.

        The solver to compute the tree decomposition and its timeout can be specified in 
        aspmc.config.

        Args:
            adaptive (:obj:`bool`, optional): If `True`, every edge between a node and
                one of its inputs is additionally added between a shifted copy of the
                node (id plus the number of variables) and that input. The copies are
                removed from the tree decomposition again before it is used.
                Defaults to `False`.
            latest (:obj:`bool`, optional): If `True`, the inputs of a node are combined
                in the bags in which some of them occur for the last time (in
                `bag_iter()` order). Each node is defined in the last bag containing it.
                If `False`, facts and atoms without rules are fixed directly, inputs are
                combined as soon as they share a bag with their node, and nodes are
                defined once they are forgotten. Defaults to `True`.

        Returns:
            None        
        """
        self._cnf = CNF()
        self._cnf.nr_vars = self._max
        # local method to get a new auxilliary variable
        # so that the atom counter of the program does not change
        def aux_var():
            self._cnf.nr_vars += 1
            self._cnf.auxilliary.add(self._cnf.nr_vars)
            return self._cnf.nr_vars

        # remember whats an and, whats an or and whats a constraint
        # also include the guesses, which guess exactly one of their inputs to be true
        OR = 0
        AND = 1
        CON = 2
        GUESS = 3
        INPUT = 4

        def at_most_one(lits):
            # pairwise exclusion: at most one of lits is true
            for v in lits:
                for vp in lits:
                    if v < vp:
                        self._cnf.clauses.append([-v, -vp])

        def combine(node_type, lits):
            # Returns one literal representing the partial result `lits` of a node of
            # type `node_type`. A singleton set is consumed (its element is popped and
            # returned). Otherwise a new auxiliary variable is defined as the
            # conjunction (AND) or disjunction (OR, CON, GUESS) of `lits`. For GUESS
            # nodes, at most one of `lits` may additionally be true.
            if len(lits) == 1:
                return lits.pop()
            new_lit = aux_var()
            if node_type == AND:
                self._cnf.clauses.append([ new_lit ] + [ -v for v in lits ])
                for v in lits:
                    self._cnf.clauses.append([ -new_lit, v ])
            elif node_type in (OR, CON, GUESS):
                self._cnf.clauses.append([ -new_lit ] + [ v for v in lits ])
                for v in lits:
                    self._cnf.clauses.append([ new_lit, -v ])
                if node_type == GUESS:
                    at_most_one(lits)
            return new_lit

        def merge(node_type, older, newer):
            # combines two partial results of the same node into a set of two literals
            first_lit = combine(node_type, older)
            second_lit = combine(node_type, newer)
            return set([first_lit, second_lit])

        def finalize(a, node_type, lits):
            # adds the clauses defining node `a` from its complete set of inputs `lits`
            if node_type == AND:
                self._cnf.clauses.append([ a ] + [ -v for v in lits ])
                for v in lits:
                    self._cnf.clauses.append([ -a, v ])
            elif node_type == OR:
                self._cnf.clauses.append([ -a ] + [ v for v in lits ])
                for v in lits:
                    self._cnf.clauses.append([ a, -v ])
            elif node_type == CON:
                # at least one of the negated body literals must hold
                self._cnf.clauses.append([ v for v in lits ])
                # the variable of the constraint is not used otherwise, fix it
                self._cnf.clauses.append([-a])
            elif node_type == GUESS:
                # exactly one of the inputs must be true
                self._cnf.clauses.append([ v for v in lits ])
                at_most_one(lits)
                self._cnf.clauses.append([-a])

        nodes = { a : (OR, set()) for a in self._deriv }

        exactly_one_to_var = {}
        for a in self._exactlyOneOf:
            exactly_one_to_var[a] = aux_var()
            nodes[exactly_one_to_var[a]] = (GUESS, set(a))

        for atom in self._guess:
            nodes[atom] = (INPUT, set())

        if latest:
            remaining = self._program
        else:
            # facts are fixed to true directly; rules with a fact as head are not encoded
            seen = set()
            facts = set([ r.head[0] for r in self._program if len(r.body) == 0 ])
            for f in facts:
                self._cnf.clauses.append([f])
            seen.update(facts)
            remaining = [ r for r in self._program if len(r.head) == 0 or r.head[0] not in facts ]

        # every encoded rule gets a node: an AND for rules with a head, a CON for constraints
        for r in remaining:
            r.proven = aux_var()
            if len(r.head) != 0:
                nodes[r.proven] = (AND, set(r.body))
                nodes[abs(r.head[0])][1].add(r.proven)
                if not latest:
                    seen.add(r.head[0])
            else:
                nodes[r.proven] = (CON, set([ -atom for atom in r.body ]))

        if not latest:
            # handle the atoms that do not occur in the head of any rule
            falses = [ a for a in self._deriv if a not in seen and nodes[a][0] == OR and len(nodes[a][1]) == 0 ]
            for f in falses:
                self._cnf.clauses.append([-f])

        # set up the and/or graph
        # only the rules in `remaining` have a (fresh) `proven` variable: with
        # latest=False, rules whose head is a fact were skipped above
        graph = nx.Graph()
        graph.add_nodes_from(range(1, self._cnf.nr_vars + 1))
        for r in remaining:
            if len(r.body) > 0:
                for atom in r.head:
                    graph.add_edge(atom, r.proven)
                    if adaptive:
                        graph.add_edge(atom + self._cnf.nr_vars, r.proven)
                for atom in r.body:
                    graph.add_edge(r.proven, abs(atom))
                    if adaptive:
                        graph.add_edge(r.proven + self._cnf.nr_vars, abs(atom))
        
        for a in self._exactlyOneOf:
            for atom in a:
                graph.add_edge(exactly_one_to_var[a], atom)
                if adaptive:
                    graph.add_edge(exactly_one_to_var[a] + self._cnf.nr_vars, atom)

        td = treedecomposition.from_graph(graph, solver = config.config["decos"], timeout = config.config["decot"])
        if adaptive:
            # remove the shifted copies again
            td.remove(set(range(self._cnf.nr_vars + 1, 2*self._cnf.nr_vars + 1)))
            td.vertices = self._cnf.nr_vars 
        logger.info(f"Tree Decomposition #bags: {td.bags} unfolded treewidth: {td.width} #vertices: {td.vertices}")

        if latest:
            # at which td node each variable occurs last
            last = {}
            idx = 0
            for t in td.bag_iter():
                for a in t.vertices:
                    last[a] = idx
                t.idx = idx
                idx += 1
        
        # remember per bag which nodes have which partial result
        unfinished = {}
        # handle the bags in dfs order (unfinished[tp] is read for every child tp)
        for t in td.bag_iter():
            unfinished[t] = {}
            # first take care of what we got from the children
            for tp in t.children:
                for atom in unfinished[tp]:
                    if atom not in unfinished[t]:
                        unfinished[t][atom] = unfinished[tp][atom]
                    else:
                        unfinished[t][atom] = merge(nodes[atom][0], unfinished[tp][atom], unfinished[t][atom])

            if not latest:
                # then take care of the current bag:
                # consume every input that shares this bag with its node
                for a in t.vertices:
                    node_type, inputs = nodes[a]
                    todo_new = set([ atom for atom in inputs if abs(atom) in t.vertices ])
                    if len(todo_new) == 0:
                        continue
                    inputs.difference_update(todo_new)
                    if a not in unfinished[t]:
                        unfinished[t][a] = todo_new
                    else:
                        unfinished[t][a] = merge(node_type, unfinished[t][a], todo_new)

                # check which nodes are completely done and finalize them
                new_unfinished = {}
                for a in unfinished[t]:
                    if a in t.vertices: # if the variable is still there, we keep it
                        new_unfinished[a] = unfinished[t][a]
                    else: # otherwise we finalize the node
                        finalize(a, nodes[a][0], unfinished[t][a])
                unfinished[t] = new_unfinished
            else:
                # then take care of the current bag
                for a in t.vertices:
                    node_type, inputs = nodes[a]
                    if t.idx == last[a]:
                        # last bag containing a: collect everything and define it
                        if a in unfinished[t]:
                            inputs.add(combine(node_type, unfinished[t][a]))
                            del unfinished[t][a]
                        finalize(a, node_type, inputs)
                        inputs.clear()
                    elif any(t.idx == last[abs(b)] for b in inputs):
                        # some input occurs for the last time: consume the inputs present here
                        todo_new = set([ b for b in inputs if abs(b) in t.vertices ])
                        inputs.difference_update(todo_new)
                        if a not in unfinished[t]:
                            unfinished[t][a] = todo_new
                        else:
                            unfinished[t][a] = merge(node_type, unfinished[t][a], todo_new)

        if not latest:
            # finalize the nodes that are left in the root
            root = td.get_root()
            for a in unfinished[root]:
                finalize(a, nodes[a][0], unfinished[root][a])

        self._finalize_cnf()

    def choose_clark_completion(self):
        """Applies the clark completion to the program. 

        Does not check whether the program is tight! 
        Chooses which of the clark completions to use based on an
        approximate check for the expected treewidth of the resulting CNF.
        Three estimates are computed, each from its own tree decomposition
        (each gets a third of the configured timeout):

        * none: the unguided completion (`clark_completion()`),
        * or: the OR-guided completion (`td_guided_clark_completion()`),
        * both/adaptive: the completely guided completion
          (`td_guided_both_clark_completion(adaptive=True, latest=True)`).

        The unguided completion is chosen if its estimate is at most one worse
        than the best guided estimate; otherwise the OR-guided completion is
        chosen if its estimate is at most one worse than the completely guided one.
        
        Does not return anything only constructs the cnf.
        The CNF can be obtained by using `get_cnf()`.

        The solver to compute the tree decomposition and its timeout can be specified in 
        aspmc.config.

        Returns:
            None
        """
        OR = 0
        AND = 1
        GUESS = 3
        INPUT = 4
        # the three tree decompositions below share the configured timeout
        td_timeout = str(float(config.config["decot"])/3)

        # and/or structure of the program: derived atoms are OR nodes over the
        # rules deriving them, rules are AND nodes over their body atoms,
        # exactly-one-of constraints are GUESS nodes and guessed atoms are inputs
        nodes = { a : (OR, set()) for a in self._deriv }
        cur_max = self._max
        for a in self._exactlyOneOf:
            cur_max += 1
            nodes[cur_max] = (GUESS, set(abs(v) for v in a))
        for atom in self._guess:
            nodes[atom] = (INPUT, set())
        for r in self._program:
            cur_max += 1
            nodes[cur_max] = (AND, set(abs(v) for v in r.body))
            if len(r.head) != 0:
                nodes[abs(r.head[0])][1].add(cur_max)

        def clique_edges(inputs):
            # all ordered pairs of distinct inputs, in nested-loop order, generated
            # lazily instead of concatenating many small lists with sum(..., [])
            return ((v, vp) for vp in inputs for v in inputs if v != vp)

        def guided_cost(graph, guided_types):
            # vertex a + cur_max stands for the copy of node a used by a guided
            # completion; a copy of a node with a type in guided_types that is shared
            # with at least three adjacent bags (parent and children) counts as one
            # extra vertex in that bag; the estimate is the maximum over all bags
            td = treedecomposition.from_graph(graph, solver = config.config["decos"], timeout = td_timeout)
            for t in td.bag_iter():
                for tp in t.children:
                    tp.parent = t
            cost = 0
            for t in td.bag_iter():
                cur_cost = len(t.vertices)
                occ_counter = { a : 0 for a in t.vertices }
                if t != td.get_root():
                    for a in t.vertices.intersection(t.parent.vertices):
                        occ_counter[a] += 1
                for tp in t.children:
                    for a in tp.vertices.intersection(t.vertices):
                        occ_counter[a] += 1
                for a, c in occ_counter.items():
                    if c >= 3 and a > cur_max and nodes[a - cur_max][0] in guided_types:
                        cur_cost += 1
                cost = max(cost, cur_cost)
            return cost

        # approximate final width when using both/adaptive strategy
        graph = nx.Graph()
        for a, (_, inputs) in nodes.items():
            graph.add_edges_from([ (a, v) for v in inputs ])
            graph.add_edges_from([ (a + cur_max, v) for v in inputs ])
        cost_both = guided_cost(graph, (OR, AND, GUESS))

        # approximate final width when using none strategy
        graph = nx.Graph()
        for a, (_, inputs) in nodes.items():
            graph.add_edges_from(clique_edges(inputs))
            graph.add_edges_from([ (a, v) for v in inputs ])
        td = treedecomposition.from_graph(graph, solver = config.config["decos"], timeout = td_timeout)
        cost_none = td.width

        # approximate final width when using or strategy:
        # AND and GUESS nodes stay unguided (their inputs form a clique)
        graph = nx.Graph()
        for a, (node_type, inputs) in nodes.items():
            unguided = node_type == AND or node_type == GUESS
            if unguided:
                graph.add_edges_from(clique_edges(inputs))
            graph.add_edges_from([ (a, v) for v in inputs ])
            if not unguided:
                graph.add_edges_from([ (a + cur_max, v) for v in inputs ])
        cost_or = guided_cost(graph, (OR,))

        logger.debug(f"Approximate expected treewidth using strategy none: {cost_none}")
        logger.debug(f"Approximate expected treewidth using strategy or: {cost_or}")
        logger.debug(f"Approximate expected treewidth using strategy both/adaptive: {cost_both}")
        logger.info("------------------------------------------------------------")

        if cost_none <= min(cost_both, cost_or) + 1:
            logger.info(f"Choosing Unguided Clark Completion")
            logger.info("------------------------------------------------------------")
            self.clark_completion()
        elif cost_or <= cost_both + 1:
            logger.info(f"Choosing OR-guided Clark Completion")
            logger.info("------------------------------------------------------------")
            self.td_guided_clark_completion()
        else:
            logger.info(f"Choosing completely guided Clark Completion")
            logger.info("------------------------------------------------------------")
            self.td_guided_both_clark_completion(adaptive=True, latest = True)
    
    def build_bdds(self):
        """Build a BDD for every atom and rule of the program.

        Uses `dd.cudd` with one BDD variable per guessed atom (named by
        `_internal_name`). The atom/rule dependency graph is traversed in
        topological order, so the program must be acyclic. A rule becomes the
        conjunction of its body literals (negative literals are negated, facts are
        true). An atom that is not guessed becomes the disjunction of the rules
        deriving it (false if there are none).

        Returns:
            :obj:`dict`: Maps guessed atoms, derived atoms and `Rule` objects to their BDDs.
        """
        from dd.cudd import BDD
        bdd = BDD()
        bdd.declare(*[ self._internal_name(v) for v in self._guess ])
        # set up the and/or graph; facts are kept so that their heads become true
        graph = nx.DiGraph()
        for r in self._program:
            for atom in r.head:
                graph.add_edge(r, atom)
            for atom in r.body:
                graph.add_edge(abs(atom), r)
        vertex_to_bdd = { v : bdd.var(self._internal_name(v)) for v in self._guess }

        def literal_bdd(lit):
            # BDD of a (possibly negative) body literal
            node = vertex_to_bdd[abs(lit)]
            return ~node if lit < 0 else node

        for cur in nx.topological_sort(graph):
            if isinstance(cur, Rule):
                new_bdd = bdd.true
                for b in cur.body:
                    new_bdd = new_bdd & literal_bdd(b)
                vertex_to_bdd[cur] = new_bdd
            elif cur not in self._guess:
                new_bdd = bdd.false
                for r, _ in graph.in_edges(nbunch=cur):
                    new_bdd = new_bdd | vertex_to_bdd[r]
                vertex_to_bdd[cur] = new_bdd
        return vertex_to_bdd

    def build_sdds(self):
        """Build an SDD for every atom and rule of the program.

        First a vtree over the guessed atoms is derived from a tree decomposition
        of the and/or graph of the program (solver and timeout from aspmc.config).
        Then the atom/rule dependency graph is traversed in topological order.
        Each rule becomes the conjunction of its body literals, and each atom that
        is not guessed becomes the disjunction of the rules deriving it.

        Returns:
            :obj:`dict`: Maps guessed atoms, negative body literals, derived atoms
                and `Rule` objects to their SDDs.
        """
        from aspmc.compile.vtree import TD_to_vtree
        from pysdd.sdd import Vtree, SddManager
        import tempfile
        # first generate a vtree for the program that is probably good
        OR = 0
        AND = 1
        GUESS = 3
        INPUT = 4
        nodes = { a : (OR, set()) for a in self._deriv }

        cur_max = self._max
        for a in self._exactlyOneOf:
            cur_max += 1
            nodes[cur_max] = (GUESS, set(abs(v) for v in a))

        for atom in self._guess:
            nodes[atom] = (INPUT, set())

        for r in self._program:
            cur_max += 1
            nodes[cur_max] = (AND, set(abs(v) for v in r.body))
            if len(r.head) != 0:
                nodes[abs(r.head[0])][1].add(cur_max)

        # set up the and/or graph
        graph = nx.Graph()
        for a, inputs in nodes.items():
            graph.add_edges_from([ (a, v) for v in inputs[1] ])
            
        td = treedecomposition.from_graph(graph, solver = config.config["decos"], timeout = str(float(config.config["decot"])))
        # keep only the guessed atoms in the decomposition, all of them in the root
        td.remove(set(range(1, cur_max + 1)).difference(self._guess))
        td.get_root().vertices.update(self._guess)
        my_vtree = TD_to_vtree(td)
        guesses = list(self._guess)
        # relabel the guessed atoms to 1..n; sdd_vars[i] below is assumed to be
        # the manager variable with label i + 1
        rev_mapping = { guesses[i] : i + 1 for i in range(len(self._guess)) }
        for node in my_vtree:
            if node.val is not None:
                assert(node.val in self._guess)
                node.val = rev_mapping[node.val]

        # the vtree is handed to pysdd via a temporary file; close the descriptor
        # returned by mkstemp (it would leak otherwise) and always remove the file
        fd, vtree_tmp = tempfile.mkstemp()
        os.close(fd)
        try:
            my_vtree.write(vtree_tmp)
            vtree = Vtree(filename=vtree_tmp)
            sdd = SddManager.from_vtree(vtree)
        finally:
            os.remove(vtree_tmp)
        sdd_vars = list(sdd.vars)
        vertex_to_sdd = { v : sdd_vars[i] for i,v in enumerate(guesses) }

        # set up the and/or graph
        graph = nx.DiGraph()
        for r in self._program:
            for atom in r.head:
                graph.add_edge(r, atom)
            for atom in r.body:
                graph.add_edge(abs(atom), r)

        # build the relevant sdds by traversing the graph in topological order
        ts = nx.topological_sort(graph)
        for cur in ts:
            if isinstance(cur, Rule):
                new_sdd = sdd.true()
                for b in cur.body:
                    if b < 0:
                        vertex_to_sdd[b] = ~vertex_to_sdd[-b]
                    new_sdd = new_sdd & vertex_to_sdd[b]
                vertex_to_sdd[cur] = new_sdd
            elif cur not in self._guess:
                ins = list(graph.in_edges(nbunch=cur))
                new_sdd = sdd.false()
                for r in ins:
                    new_sdd = new_sdd | vertex_to_sdd[r[0]]
                vertex_to_sdd[cur] = new_sdd

        return vertex_to_sdd

    def to_aig(self, path):
        """Write the program as an and-inverter graph in ASCII AIGER (`aag`) format.

        Every guessed atom is an input and every query from `get_queries()` an output.
        A rule is the conjunction of its body literals (a fact is constant true).
        An atom that is not guessed is the disjunction of the rules deriving it,
        encoded via De Morgan as a negated conjunction of negated literals; it is
        constant false if there are no such rules. The atom/rule dependency graph
        is traversed in topological order, so the program must be acyclic.

        Args:
            path (:obj:`string`): Path of the file the AIG is written to.

        Returns:
            None
        """
        varMap = { name : var for var, name in self._nameMap.items() }
        # AIGER literals: 2*i is variable i, 2*i + 1 its negation, 0/1 are constants
        FALSE = 0
        TRUE = 1
        guesses = list(self._guess)
        vertex_to_var = { v : 2*(i+1) for i, v in enumerate(guesses) }
        and_lines = []
        cur_idx = len(guesses)

        def add_and(left, right):
            # introduce a fresh and gate over the two literals, return its literal
            nonlocal cur_idx
            cur_idx += 1
            and_lines.append(f"{2*cur_idx} {left} {right}")
            return 2*cur_idx

        def literal(lit):
            # AIGER literal of a (possibly negative) body literal
            var = vertex_to_var[abs(lit)]
            return var ^ 1 if lit < 0 else var

        # facts (empty bodies) are kept so that their heads become true
        graph = nx.DiGraph()
        for r in self._program:
            for atom in r.head:
                graph.add_edge(r, atom)
            for atom in r.body:
                graph.add_edge(abs(atom), r)

        for cur in nx.topological_sort(graph):
            if isinstance(cur, Rule):
                if len(cur.body) == 0:
                    result = TRUE
                else:
                    result = literal(cur.body[0])
                    for b in cur.body[1:]:
                        result = add_and(result, literal(b))
                vertex_to_var[cur] = result
            elif cur not in self._guess:
                rules = [ r for r, _ in graph.in_edges(nbunch=cur) ]
                if len(rules) == 0:
                    vertex_to_var[cur] = FALSE
                else:
                    result = vertex_to_var[rules[0]] ^ 1
                    for r in rules[1:]:
                        result = add_and(result, vertex_to_var[r] ^ 1)
                    vertex_to_var[cur] = result ^ 1

        queries = self.get_queries()
        lines = [ f"aag {cur_idx} {len(guesses)} 0 {len(queries)} {len(and_lines)}" ]
        lines.extend(str(2*(i+1)) for i in range(len(guesses)))
        lines.extend(str(vertex_to_var[varMap[name]]) for name in queries)
        lines.extend(and_lines)
        with open(path, "w") as out_file:
            # one entry per line; empty sections must not leave blank lines behind
            out_file.write("\n".join(lines) + "\n")

    def _finalize_cnf(self):
        for l in self._copies.values():
            self._cnf.auxilliary.update(l)
        self._cnf.auxilliary.update(self._auxilliary)

    def encoding_stats(self):
        """Log statistics of a tree decomposition of the cnf's primal hypergraph.

        The hypergraph has the vertices `1..self._max` and one hyperedge per clause,
        consisting of the variables occurring in that clause.

        Returns:
            None        
        """
        clause_vars = [ { abs(lit) for lit in clause } for clause in self._cnf.clauses ]
        primal = Hypergraph()
        primal.add_nodes_from(range(1, self._max + 1))
        primal.add_edges_from(clause_vars)
        td = treedecomposition.from_hypergraph(primal)
        logger.info(f"Tree Decomposition #bags: {td.bags} CNF treewidth: {td.width} #vertices: {td.vertices}")

    def get_cnf(self):
        """Used to get the extended cnf corresponding to the program. 

        Only possible after having called `tpUnfold()` and one of the Clark completion methods.
        
        Returns:
            :obj:`aspmc.compile.cnf.CNF`: Returns the extended cnf of the program.        
        """
        return self._cnf

    def write_dimacs(self, stream, **kwargs):
        """Write the extended cnf corresponding to the program to a stream. 

        Only possible after having called `tpUnfold()` and one of the Clark completion methods.
        
        Args:
            stream (:obj:`stream`): The stream to write to. Must be binary.

        Returns:
            :obj:`aspmc.compile.cnf.CNF`: Returns the extended cnf of the program.        
        """
        # FIXME: this does not work anymore since auxilliary cnf vars do not have a name
        # if "debug" in kwargs:
        #     stream.write(f"p cnf {self._cnf.nr_vars} {len(self._cnf.clauses)}\n".encode())
        #     for c in self._cnf.clauses:
        #         stream.write((" ".join([("(not " if v < 0 else "(") + self._external_name(abs(v)) + ")" for v in c]) + " 0\n" ).encode())
        # else:
        self._cnf.to_stream(stream)


    def _prog_string(self, program):
        """Get a string representation of a part of the program. 

        Should be overwritten by subclasses.

        Args:
            program (:obj:`list`): List of rules that should be printed.

        Returns:
            :obj:`string`: A string representation of the rules in `program`.        
        """
        result = ""
        for r in self._exactlyOneOf:
            result += f"1{{{';'.join([ self._external_name(v) for v in r ])}}}1.\n"
        for g in self._guess:
            result += f"{{{self._external_name(g)}}}.\n"
        for r in program:
            result += ";".join([ self._external_name(v) for v in r.head ])
            if len(r.body) > 0:
                result += ":-"
                result += ",".join([("not " if v < 0 else "") + self._external_name(abs(v)) for v in r.body])
            result += ".\n"
        return result

    def write_prog(self, stream, spanning = False):
        """Write the (spanning) program to a stream.

        Args:
            stream (:obj:`stream`): The stream to write to. Must be binary.
            spanning (:obj:`bool`, optional): Whether the to write (case `False`) the actual program,
                possibly with weights and utilities and such or (case `True`) only the spanning program. 
                The spanning program corresponds to the underlying logical theory.
                Defaults to `False`.

        Returns:
            None     
        """
        if spanning:
            stream.write(Program._prog_string(self, self._program).encode())
        else:
            stream.write(self._prog_string(self._program).encode())

    def get_weights(self):
        """Get the weights of all the literals. 

        Should be overwritten by subclasses.

        Returns:
            :obj:`list`: A list of `weights` as numpy arrays.      
                The weight of literal `v` is in `weights[2*(v-1)]`, the one for `-v` is in `weights[2*(v-1)+1]`
        """
        return [ np.array([1.0]) for _ in range(self._max*2) ]

    def get_queries(self):
        """Get the queries (names not literals). 

        Should be overwritten by subclasses.

        Returns:
            :obj:`list`: A list of queries. 
                The empty list corresponds to asking for the overall weight of the program.
        """
        return []

