import os
import sys
sys.setrecursionlimit(10000)
sys.path.append("..")
sys.path.append(os.path.join(sys.path[0], '..' , 'SPL', 'grids'))
sys.path.append(os.path.join(sys.path[0], '..' ,'SPL','grids', 'pypsdd'))

import networkx as nx
from circuits.node import Node
from circuits.ddnnf import dDNNF
from pypsdd import Vtree

def topological_sort_edges(G):
    """List the edges of DAG ``G`` grouped by tail, tails in topological order.

    Every out-edge of a vertex appears after every out-edge of the vertices
    that precede it in ``nx.topological_sort(G)``.

    Args:
        G: a ``networkx.DiGraph`` with no directed cycles.

    Returns:
        A list of ``(tail, head)`` edge tuples.

    Raises:
        Exception: if ``G`` is not a ``networkx.DiGraph`` or is not acyclic.
    """
    if not isinstance(G, nx.DiGraph):
        raise Exception("Only implemented for networkx DiGraph !")
    if not nx.is_directed_acyclic_graph(G):
        raise Exception("Only implemented for Directed Acyclic Graphs !")

    return [edge for tail in nx.topological_sort(G) for edge in G.out_edges(tail)]


def AcyclicSimplePath(D):
    """Compile the s-t paths of a DAG into dDNNF circuits, one edge at a time.

    Works on a copy of ``D`` (``D`` itself is not modified):

    * all source vertices are merged into a single source ``s``;
    * all sink vertices are merged into a single sink ``t``.

    Each edge ``E[i]`` of the merged graph, taken in the order produced by
    ``topological_sort_edges``, is represented by circuit variable ``var=i``
    (0-based). Following the recurrence below, ``nodes[i][v]`` is a circuit
    over variables ``0..i``. It is satisfied exactly when the selected edges
    among ``E[0..i]`` form a path from ``s`` to ``v``; for ``v == s`` that is
    the empty selection.

    NOTE(review): circuit variables are 0-based here, while
    ``right_linear_vtree`` numbers variables from 1. Confirm that the SDD
    export accounts for this offset.

    Args:
        D: a ``networkx.DiGraph`` that is acyclic and has at least one node.

    Returns:
        ``(nodes, G, E)``, where:

        * ``nodes`` maps each edge index ``i`` to a dict ``{vertex: dDNNF}``;
        * ``G`` is the merged graph;
        * ``E`` is the list of its edges in the order used for variable
          numbering.

    Raises:
        Exception: if ``D`` is not a ``networkx.DiGraph`` or is not acyclic.
        ValueError: if ``D`` has no nodes.
    """
    if not(isinstance(D, nx.DiGraph)):
        raise Exception("Only implemented for networkx DiGraph !")
    if not(nx.is_directed_acyclic_graph(D)):
        raise Exception("Only implemented for Directed Acyclic Graphs !")
    if D.number_of_nodes() == 0:
        raise ValueError("AcyclicSimplePath needs a graph with at least one node")

    G = D.copy()

    # Merge all source vertices into s = sources[0]: re-attach every out-edge
    # of another source to s, then drop the other sources.
    sources = [v for v in G.nodes() if G.in_degree(v) == 0]
    s = sources[0]
    if len(sources) > 1:
        for v in sources[1:]:
            G.add_edges_from([(s, w) for w in G.successors(v)])
        # Removing a node also removes its incident edges.
        G.remove_nodes_from(sources[1:])

    # Sinks are computed only after the source merge. Otherwise an isolated
    # vertex that was just removed as an extra source could still be chosen
    # as t or looked up below.
    targets = [v for v in G.nodes() if G.out_degree(v) == 0]
    t = targets[0]

    # Merge all sink vertices into t: re-attach every in-edge of another sink
    # to t, then drop the other sinks.
    if len(targets) > 1:
        for v in targets[1:]:
            G.add_edges_from([(u, t) for u in G.predecessors(v)])
        G.remove_nodes_from(targets[1:])

    V = list(nx.topological_sort(G))
    E = list(topological_sort_edges(G))

    nodes = {i: {} for i in range(len(E))}

    # Shared constant circuits. T is not referenced below; it is still
    # constructed here as before.
    T = dDNNF(type=Node.Types.T)
    F = dDNNF(type=Node.Types.F)

    for (i, e) in enumerate(E):
        # Edge E[i] is represented by circuit variable i.
        pos = dDNNF(type=Node.Types.VAR, var=i)
        neg = dDNNF(type=Node.Types.NEG, children=[pos])
        for v in V:
            if i == 0:
                # Base case: only E[0] has been decided. E[0] should leave s,
                # since s is the unique source and so comes first in the
                # topological order.
                if v == s:
                    # Empty path s -> s: E[0] must not be selected.
                    nodes[0][v] = neg
                elif e == (s, v):
                    # One-edge path s -> v via E[0].
                    nodes[0][v] = pos
                else:
                    # No path to v can be formed from E[0] alone.
                    nodes[0][v] = F
            elif e[1] == v:
                # E[i] enters v. Either select E[i] and extend a path that
                # reaches its tail e[0], or skip E[i] and keep a path that
                # already reaches v.
                if nodes[i-1][e[0]] == F and nodes[i-1][v] == F:
                    nodes[i][v] = F
                else:
                    left = dDNNF(type=Node.Types.AND, children=[pos, nodes[i-1][e[0]]])
                    right = dDNNF(type=Node.Types.AND, children=[neg, nodes[i-1][v]])
                    nodes[i][v] = dDNNF(type=Node.Types.OR, children=[left, right])
            else:
                # E[i] does not enter v, so a path ending at v must skip E[i].
                if nodes[i-1][v] == F:
                    nodes[i][v] = F
                else:
                    # AND(pos, F) is unsatisfiable. Presumably it is kept so
                    # the OR branches on variable i in both children.
                    left = dDNNF(type=Node.Types.AND, children=[pos, F])
                    right = dDNNF(type=Node.Types.AND, children=[neg, nodes[i-1][v]])
                    nodes[i][v] = dDNNF(type=Node.Types.OR, children=[left, right])

    return nodes, G, E

def grid(n):
    """Return the directed ``n x n`` grid graph.

    Vertices are the pairs ``(i, j)`` with ``0 <= i, j < n``. Each vertex has
    an edge to ``(i+1, j)`` and to ``(i, j+1)`` when those vertices exist.
    The graph is therefore a DAG whose only source is ``(0, 0)`` and whose
    only sink is ``(n-1, n-1)``. For ``n <= 0`` the graph is empty.
    """
    g = nx.DiGraph()
    for i in range(n):
        for j in range(n):
            # Add the vertex explicitly so that n == 1 yields the single
            # vertex (0, 0). For n > 1 every vertex would also be created by
            # an edge below, and this call does not change node insertion
            # order.
            g.add_node((i, j))
            if i < n - 1:
                g.add_edge((i, j), (i + 1, j))
            if j < n - 1:
                g.add_edge((i, j), (i, j + 1))

    return g

def right_linear_vtree(n):
    """Build a right-linear vtree over the variables ``1..n``.

    The result is
    ``internal(leaf(n), internal(leaf(n-1), ... internal(leaf(2), leaf(1))))``:
    every internal node has a leaf as its left child, and the right spine
    ends in the leaf for variable 1. For ``n == 1`` the vtree is the single
    leaf for variable 1.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError("right_linear_vtree needs at least one variable, got n={0}".format(n))

    # Build bottom-up from leaf 1. Nodes are created in the same order as the
    # previous implementation: leaf(1), leaf(2), internal, leaf(3), internal, ...
    vtree = Vtree.leaf_node(var=1)
    for var in range(2, n + 1):
        vtree = Vtree.internal_node(left=Vtree.leaf_node(var=var), right=vtree)

    return vtree


# n=2
# nodes, G, order = AcyclicSimplePath(grid(n))
# N = nodes[len(order)-1][(n-1, n-1)]
# vtree=right_linear_vtree(len(order))
# for (i, v) in enumerate(vtree.__iter__()):
#     v.id=i
# vtree.save(filename="{0}x{0}.vtree".format(n))
# N.sdd_save(filename="{0}x{0}.sdd".format(n), vtree=vtree)