import sys
sys.path.append("..")

import torch
import networkx as nx

from circuits.node import Node
from circuits.dnnf import DNNF
from circuits.utils import caching_reset, caching, logsumexp, logsum, log1mexp

def graph2circuit(G, T, F, pos, neg, root=None, internal_nodes=None):
    """
    Recursively convert a circuit graph into a dDNNF circuit.

    Inputs:
        - G : a networkx directed graph whose nodes carry a "type" attribute ("t", "f", "o" or "a") and whose edges carry a "lits" attribute (a sequence of signed literals convertible with int())
        - T, F : the circuit nodes used for "t" and "f" graph nodes
        - pos, neg : dicts mapping a variable number to its positive / negative literal node
        - root : the graph node to convert; defaults to the first node without incoming edge
        - internal_nodes : dict (graph node -> circuit node) of already converted nodes, completed in place

    Outputs :
        - node : the circuit node built for root
        - internal_nodes : the completed dict of converted nodes

    Raises:
        - ValueError if a node has a type other than "t", "f", "o" or "a"
    """
    if root is None:
        sources = [v for v in G.nodes() if G.in_degree(v)==0]
        root=sources[0]

    if internal_nodes is None:
        internal_nodes={}

    # The attribute maps are read once here instead of at every recursive call.
    types=nx.get_node_attributes(G, "type")
    lits=nx.get_edge_attributes(G, "lits")
    multigraph=G.is_multigraph()

    def build(current):
        kind=types[current]

        if kind=="t":
            internal_nodes[current] = T
            return T

        if kind=="f":
            internal_nodes[current] = F
            return F

        if kind not in ("o", "a"):
            raise ValueError("Unknown node type %r at node %r" % (kind, current))

        # In a multigraph, edge attributes are keyed by (u, v, key): ask for the keys
        # so that parallel edges each use their own literals.
        if multigraph:
            out_edges=G.out_edges(current, keys=True)
        else:
            out_edges=G.out_edges(current)

        children=[]
        for edge in out_edges:
            target=edge[1]
            if target in internal_nodes:
                child=internal_nodes[target]
            else:
                child=build(target)

            # Literals on the outgoing edge of an OR node are conjoined with the child.
            # Edge literals are not read for AND nodes.
            if kind=="o" and len(lits[edge])>0:
                conjuncts=[]
                for lit in lits[edge]:
                    value=int(lit)
                    if value>0:
                        conjuncts.append(pos[value])
                    else:
                        conjuncts.append(neg[abs(value)])
                conjuncts.append(child)
                children.append(dDNNF(type=Node.Types.AND, children=conjuncts))
            else:
                children.append(child)

        if len(children)==1:
            print("Only one children at : ", current)

        if kind=="o":
            node = dDNNF(type=Node.Types.OR, children=children)
        else:
            node = dDNNF(type=Node.Types.AND, children=children)

        internal_nodes[current] = node
        return node

    return build(root), internal_nodes

class dDNNF(DNNF):
    def __init__(self, type, children=None, id=None, var=None):
        """
        Create a node of the given type.

        Raises an Exception if one of the children is not a dDNNF instance,
        or if the node built by the parent constructor fails the is_DNNF() check.
        """
        # Foreign node classes are rejected before calling the parent constructor.
        if children and any(not isinstance(child, dDNNF) for child in children):
            raise Exception("Children must be dDNNFs !")

        super().__init__(type=type, children=children, id=id, var=var)

        if self.is_DNNF():
            return
        raise Exception("Must be a DNNF !")

    @classmethod
    def from_node(cls, node):
        """
        Rebuild a node (and recursively its children) as dDNNF instances.

        Children that already are dDNNF instances are reused as they are; type and id are copied,
        and var is copied as well for variable leaves.
        """
        if not node.is_leaf():
            converted=[]
            for child in node.children:
                if isinstance(child, dDNNF):
                    converted.append(child)
                else:
                    converted.append(dDNNF.from_node(child))
            return dDNNF(type=node.type, children=converted, id=node.id)

        # Leaves: only variable leaves forward their var attribute.
        if node.is_var():
            return dDNNF(type=node.type, children=node.children, id=node.id, var=node.var)
        return dDNNF(type=node.type, children=node.children, id=node.id)

    @classmethod
    def from_file(cls, filepath):
        """
        Build a circuit from a text description of its nodes and edges.

        Each non-empty line is split on whitespace:
            - if the first token is numeric, the line is an edge "<source> <target> <lit> ... <last>" ; the literals are the tokens between the target and the last token
            - otherwise the line declares a node "<type> <id> ..."
        (presumably the output format of the d4 compiler -- TODO confirm)

        Inputs:
            - filepath : path of the file to read

        Outputs :
            - node : the dDNNF node returned by graph2circuit for the graph read from the file
        """
        g = nx.MultiDiGraph()
        literals = set()
        with open(filepath) as f:
            for line in f:
                e = line.split()
                if not e:
                    # Blank lines carry no information.
                    continue
                if e[0].isnumeric():
                    g.add_edge(e[0], e[1], lits=e[2:-1])
                    literals.update(e[2:-1])
                else:
                    g.add_node(e[1], type=e[0])

        T = dDNNF(type=Node.Types.T)
        F = dDNNF(type=Node.Types.F)

        # One VAR node per variable (0-based index) and its negation, keyed by the 1-based variable number.
        pos={}
        neg={}
        for lit in literals:
            var=abs(int(lit))
            if var not in pos:
                pos[var] = dDNNF(type=Node.Types.VAR, var=var-1)
                neg[var] = dDNNF(type=Node.Types.NEG, children=[pos[var]])

        def find_passthrough():
            # A pass-through node has exactly one outgoing edge and that edge carries no literal.
            for u in g.nodes():
                out = list(g.out_edges(u, data="lits"))
                if len(out)==1 and len(out[0][2])==0:
                    return u, out[0][1]
            return None

        # Bypass pass-through nodes one at a time. The literals are read from the graph itself
        # (not from a snapshot taken before the loop) and every redirected edge keeps the
        # literals of the incoming edge it replaces.
        found = find_passthrough()
        while found is not None:
            u, v = found
            for parent, _, edge_lits in list(g.in_edges(u, data="lits")):
                g.add_edge(parent, v, lits=edge_lits)
            g.remove_node(u)
            found = find_passthrough()

        node, _ = graph2circuit(g, T, F, pos, neg)
        return node

    @classmethod
    def evidence2term(cls, evidence, lits):
        """
        Build the term (conjunction of literals) described by an evidence dict.

        Inputs:
            - evidence : a dict mapping a variable key to a sign ; the literal used for an entry is key*value
            - lits : a dict (literal -> node), completed in place with the literal nodes that are missing

        Outputs :
            - the literal node itself for a single entry, an AND node over the literal nodes for several entries, a true node when evidence is empty
        """
        # NOTE(review): the sign is carried by key*value, so a key equal to 0 cannot be negated ;
        # keys are presumably 1-based variable numbers while var=key is passed to the VAR node -- confirm the indexing.
        children=[]
        for key, value in evidence.items():
            lit=key*value
            if lit not in lits:
                # Only create what is missing so that an existing node of the other polarity is kept and shared.
                if key not in lits:
                    lits[key] = dDNNF(type=Node.Types.VAR, var=key)
                if -key not in lits:
                    lits[-key] = dDNNF(type=Node.Types.NEG, children=[lits[key]])

            children.append(lits[lit])

        if not children:
            # The empty conjunction is true.
            return dDNNF(type=Node.Types.T)
        if len(children)>1:
            return dDNNF(type=Node.Types.AND, children=children)
        return children[0]

    # @caching
    # def log_pqe(self, logprobs, first_call=True):
    #     #TODO:Add evidence
    #     """
    #     Performs PQE in log space.

    #     Inputs:
    #         - logprobs : a torch tensor of logprobs of shape (batch_size x num_variables)
        
    #     Outputs :
    #         - logp : a torch tensor of shape (batch_size) representing for the log-probability of the circuit under the independent multi-label distribution parameterized by probs.
    #     """
    #     if first_call:
    #         self.clear_cache(key="log_pqe")

    #     bs=logprobs.shape[0]

    #     if self.is_true():
    #         logp = torch.zeros(bs, dtype=torch.float64)
    #     elif self.is_false():
    #         logp = torch.full((bs,), -300, dtype=torch.float64)
    #     elif self.is_var():
    #         logp = logprobs[:,self.var]
    #     elif self.is_neg():
    #         logp = log1mexp(self.children[0].log_pqe(logprobs, first_call=False))
    #     elif self.is_and():
    #         logp = torch.sum(torch.stack([c.log_pqe(logprobs, first_call=False) for c in self.children], dim=1), dim=1)
    #     elif self.is_or():
    #         logp = logsumexp(torch.stack([c.log_pqe(logprobs, first_call=False) for c in self.children], dim=1), dim=1)
    #         # if torch.any(torch.isinf(logp)) and self.sat():
    #         #     print(self.string())
    #         #     raise Exception("Problem with probabilities")

    #     return logp


    def log_pqe(self, logprobs):
        #TODO:Add evidence
        """
        Performs PQE in log space.

        Inputs:
            - logprobs : a torch tensor of logprobs of shape (batch_size x num_variables)

        Outputs :
            - logp : a torch tensor of shape (batch_size) representing the log-probability of the circuit under the independent multi-label distribution parameterized by logprobs.

        Raises:
            - RuntimeError if the values of the children of an AND node cannot be stacked
            - ValueError if a node is of none of the handled types
        """
        bs=logprobs.shape[0]
        # Constants are created on the device of the input so that they can be stacked with it.
        device=logprobs.device

        try:
            # Every node reads the "log_pqe" cache entry of its children, so self.iter()
            # is assumed to yield children before their parents -- TODO confirm.
            for node in self.iter():
                if node.is_true():
                    logp = torch.zeros(bs, dtype=torch.float64, device=device)
                elif node.is_false():
                    # A false node gets the finite value -300 rather than -inf.
                    logp = torch.full((bs,), -300, dtype=torch.float64, device=device)
                elif node.is_var():
                    logp = logprobs[:,node.var]
                elif node.is_neg():
                    logp = log1mexp(node.children[0].cache["log_pqe"])
                elif node.is_and():
                    terms = [c.cache["log_pqe"] for c in node.children]
                    try:
                        logp = torch.sum(torch.stack(terms, dim=1), dim=1)
                    except (RuntimeError, TypeError) as err:
                        raise RuntimeError(
                            "Cannot combine the children of an AND node. Types were : %s. Shapes were : %s"
                            % ([c.type for c in node.children], [getattr(t, "shape", None) for t in terms])
                        ) from err
                elif node.is_or():
                    logp = logsumexp(torch.stack([c.cache["log_pqe"] for c in node.children], dim=1), dim=1)
                else:
                    raise ValueError("Unsupported node type : %r" % (node.type,))

                node.cache["log_pqe"]=logp

            return self.cache["log_pqe"]
        finally:
            # The intermediate values are dropped even when the evaluation fails.
            self.clear_cache("log_pqe")
        

    @caching
    def pqe(self, probs):
        #TODO:Add evidence
        """
        Performs PQE.

        Inputs:
            - probs : a torch tensor of probabilities of shape (batch_size x num_variables)

        Outputs :
            - p : a torch tensor of shape (batch_size) representing the probability of the circuit under the independent multi-label distribution parameterized by probs.

        Raises:
            - Exception if a NEG, AND or OR node fails the is_dDNNF() check
        """
        bs=probs.shape[0]

        # Constants are created on the device of probs so that they can be stacked with the other children.
        if self.is_true():
            return torch.ones(bs, device=probs.device)
        elif self.is_false():
            return torch.zeros(bs, device=probs.device)
        elif self.is_var():
            return probs[:,self.var]

        else:
            if not(self.is_dDNNF()):
                raise Exception("Only implemented for dDNNFs !")

            if self.is_neg():
                return 1 - self.children[0].pqe(probs)

            elif self.is_and():
                # Product over the children of an AND node.
                return torch.prod(torch.stack([c.pqe(probs) for c in self.children], dim=1), dim=1)

            elif self.is_or():
                # Plain sum over the children of an OR node.
                return torch.sum(torch.stack([c.pqe(probs) for c in self.children], dim=1), dim=1)


    def distribute_evidence(self, e):
        """
        Distribute evidence by setting the "evidence" entry of the cache in the nodes of the circuit.

        Inputs:
            - e : a (batch_size x nb_vars) tensor that represents evidence with 1, 0, -1 values
        """

        if torch.all(e==0):
            # No evidence at all : the same tensor is stored everywhere below this node.
            self.cache["evidence"]=e
            if not(self.is_leaf()):
                for c in self.children:
                    c.distribute_evidence(e)

        else:
            # The predicates must be called : a bound method is always truthy, which would
            # send every node into the first branch.
            if self.is_true():
                self.cache["evidence"]=e
            elif self.is_false():
                self.cache["evidence"]=torch.zeros_like(e)
            elif self.is_var():
                # NOTE(review): mirrors the NEG branch (the literal keeps the evidence it receives) ;
                # confirm that this is the intended handling of positive literals.
                self.cache["evidence"]=e
            elif self.is_neg():
                self.cache["evidence"]=e
                self.children[0].distribute_evidence(torch.zeros_like(e))
            elif self.is_and():
                # The AND node keeps the evidence on the variables outside of self.vars() ...
                mask = torch.ones_like(e)
                mask[:, self.vars()]=0
                self.cache["evidence"]=torch.mul(mask, e)
                # ... and each child receives the evidence restricted to its own variables.
                for c in self.children:
                    mask = torch.zeros_like(e)
                    mask[:, c.vars()]=1
                    c.distribute_evidence(torch.mul(mask, e))
            elif self.is_or():
                # The OR node keeps nothing, every child receives the whole evidence.
                self.cache["evidence"]=torch.zeros_like(e)
                for c in self.children:
                    c.distribute_evidence(e)

    def log_pqe_evidence(self, logprobs, e):
        """
        Performs PQE in log space with evidence.

        Inputs:
            - logprobs : a torch tensor of logprobs of shape (batch_size x num_variables)
            - e : a (batch_size x nb_vars) tensor that represents evidence with 1, 0, -1 values ; a single row is repeated for the whole batch

        Outputs :
            - logp : a torch tensor of shape (batch_size) holding the value computed at the root, evidence terms included.

        Raises:
            - Exception if the batch size of e is neither 1 nor the batch size of logprobs
            - RuntimeError if the values of the children of an AND node cannot be stacked
            - ValueError if a node is of none of the handled types
        """
        bs=logprobs.shape[0]
        if e.shape[0]==1:
            # Repeat the single evidence row for every element of the batch.
            e = torch.tile(e, (bs, 1))

        if e.shape[0]!=bs:
            raise Exception("Evidences must have the same batch size than logprobs")

        # Constants are created on the device of the input so that they can be stacked with it.
        device=logprobs.device
        # Log-probability of the negated variables : identical for every node, computed once.
        log_not=log1mexp(logprobs)

        try:
            self.distribute_evidence(e)

            # Every node reads the "log_pqe" cache entry of its children, so self.iter()
            # is assumed to yield children before their parents -- TODO confirm.
            for node in self.iter():
                if node.is_true():
                    logp = torch.zeros(bs, dtype=torch.float64, device=device)
                elif node.is_false():
                    # A false node gets the finite value -300 rather than -inf.
                    logp = torch.full((bs,), -300, dtype=torch.float64, device=device)
                elif node.is_var():
                    logp = logprobs[:,node.var]
                elif node.is_neg():
                    logp = log1mexp(node.children[0].cache["log_pqe"])
                elif node.is_and():
                    terms = [c.cache["log_pqe"] for c in node.children]
                    try:
                        logp = torch.sum(torch.stack(terms, dim=1), dim=1)
                    except (RuntimeError, TypeError) as err:
                        raise RuntimeError(
                            "Cannot combine the children of an AND node. Types were : %s. Shapes were : %s"
                            % ([c.type for c in node.children], [getattr(t, "shape", None) for t in terms])
                        ) from err
                elif node.is_or():
                    logp = logsumexp(torch.stack([c.cache["log_pqe"] for c in node.children], dim=1), dim=1)
                else:
                    raise ValueError("Unsupported node type : %r" % (node.type,))

                # Add the log probability of distributed evidence at that node
                evidence=node.cache["evidence"]
                evidence_logp=torch.sum(torch.where(evidence==1, logprobs, 0)+torch.where(evidence==-1, log_not, 0), dim=1)
                logp = torch.add(logp, evidence_logp)
                node.cache["log_pqe"]=logp

            return self.cache["log_pqe"]
        finally:
            # The intermediate values are dropped even when the evaluation fails.
            self.clear_cache(keys=["log_pqe", "evidence"])


    # SDD
    def elements(self):
        """
        Return the list of (first child, second child) pairs of the children of this OR node.

        Raises an Exception if the circuit is not an OBDD or if the node is not an OR node.
        """
        if not self.is_OBDD():
            raise Exception("Only implemented for OBDDs !")
        if not self.is_or():
            raise Exception("Only implemented for OR nodes !")

        pairs=[]
        for child in self.children:
            pairs.append((child.children[0], child.children[1]))
        return pairs

    def sdd_iter(self,first_call=True, post=True):
        """
        Generator over the sdd nodes, in post-order (post=True) or pre-order (post=False).

        The "bit" cache entry marks the nodes that were already visited, so that a node is
        yielded once ; the top-level call (first_call=True) clears it when the iteration ends.
        """
        if not self.is_OBDD():
            raise Exception("Only implemented for OBDDs !")

        already_seen = self.cache["bit"]
        if not already_seen:
            self.cache["bit"] = True

            if not post:
                yield self

            if self.is_or():
                for prime, sub in self.elements():
                    yield from prime.sdd_iter(first_call=False, post=post)
                    yield from sub.sdd_iter(first_call=False, post=post)

            if post:
                yield self

        if first_call:
            self.clear_cache(key="bit")

    def set_sdd_ids(self):
        """Number the sdd nodes in pre-order, starting at 0, and return the last id that was assigned."""
        for index, sdd_node in enumerate(self.sdd_iter(post=False)):
            sdd_node.id = index
            last_id = index
        return last_id

    def set_vtree(self, vtree):
        """
        Attach vtree to this node if it has none yet ; for an OR node, the first member of each
        element then receives vtree.left and the second member vtree.right.
        """
        if self.vtree is not None:
            # A node that already has a vtree is left untouched, and so is everything below it.
            return

        self.vtree = vtree
        if not self.is_or():
            return

        for prime, sub in self.elements():
            prime.set_vtree(vtree.left)
            sub.set_vtree(vtree.right)

    def check_vtree(self, vtree):
        """
        Check this node against vtree and return a boolean.

        - OR node : the first member of every element must check against vtree.left and the second against vtree.right
        - literal : vtree must be a leaf holding the first variable of the node
        - true / false nodes : always accepted
        - any other node : rejected
        """
        if self.is_or():
            return all(
                prime.check_vtree(vtree.left) and sub.check_vtree(vtree.right)
                for prime, sub in self.elements()
            )

        if self.is_lit():
            if not vtree.is_leaf():
                return False
            return bool(self.vars()[0]==vtree.var)

        return bool(self.is_false() or self.is_true())


    def sdd_repr(self):
        """
        Return the line describing this node in an sdd file : "F id", "T id",
        "L id vtree-id literal" (1-based, negative for a negation) or
        "D id vtree-id number-of-elements {prime-id sub-id}*".
        """
        if self.is_false():
            line = 'F %d' % self.id
        elif self.is_true():
            line = 'T %d' % self.id
        elif self.is_var():
            line = 'L %d %d %d' % (self.id,self.vtree.id,self.var+1)
        elif self.is_neg():
            # The literal of a negation is minus the 1-based number of the negated variable.
            line = 'L %d %d %d' % (self.id,self.vtree.id,-(self.children[0].var+1))
        elif self.is_or():
            pairs = self.elements()
            pair_ids = " ".join( '%d %d' % (prime.id,sub.id) for prime,sub in pairs )
            line = 'D %d %d %d %s' % (self.id,self.vtree.id,len(pairs),pair_ids)
        return line

    def sdd_save(self, filename, vtree):
        """
        Write the circuit to filename : a comment header, the line "sdd <number of nodes>",
        then one line per node as produced by sdd_repr, in the default order of sdd_iter.

        Ids are assigned with set_sdd_ids and vtree is attached with set_vtree before writing.
        """
        last_id=self.set_sdd_ids()
        self.set_vtree(vtree)

        header_lines = [
            "c ids of sdd nodes start at 0",
            "c sdd nodes appear bottom-up, children before parents",
            "c",
            "c file syntax:",
            "c sdd count-of-sdd-nodes",
            "c F id-of-false-sdd-node",
            "c T id-of-true-sdd-node",
            "c L id-of-literal-sdd-node id-of-vtree literal",
            "c D id-of-decomposition-sdd-node id-of-vtree number-of-elements {id-of-prime id-of-sub}*",
            "c",
        ]
        header = "".join(text + "\n" for text in header_lines)

        with open(filename,'w') as f:
            f.write(header)
            # Ids start at 0, so the number of nodes is the last id plus one.
            f.write('sdd %d\n' % int(last_id+1))
            f.writelines('%s\n' % n.sdd_repr() for n in self.sdd_iter())