from enum import Enum

import torch
import networkx as nx

from circuits.utils import caching, logsumexp, logsum, log1mexp

class Node():
    """A node of a Boolean circuit; its kind is one of the members of ``Node.Types``."""
    class Types(Enum):
        # The six node kinds, numbered 1 to 6 in this order: the constants
        # false (F) and true (T), a variable, a negation, a conjunction and
        # a disjunction.
        F, T, VAR, NEG, AND, OR = range(1, 7)

    def __init__(self, type, children=None, id=None, var=None):
        """
        Builds a circuit node and validates its arity.

        Inputs:
            - type : a member of Node.Types.
            - children : a list of Node children (None means no children).
            - id : an optional identifier, stored as is.
            - var : the variable carried by a VAR node (mandatory for VAR nodes, ignored otherwise).

        Raises:
            - Exception : if type is not a Node.Types member, if a child is not a Node,
              if the number of children does not fit the type, or if a VAR node has no variable.
        """
        # isinstance rather than `type in Node.Types`: depending on the Python
        # version the containment test either raises TypeError for non-members
        # or accepts raw values such as 1, and neither gives the intended error.
        if not isinstance(type, Node.Types):
            raise Exception("The type of the node must be in : ", list(Node.Types))
        self.type = type
        self.id = id
        # The cache must exist before any @caching method (is_binary, is_leaf) is called.
        self.cache = {}
        self.unique = None
        self.cache["bit"] = False  # visit mark used by iter()

        if children is None:
            children = []
        elif not all(isinstance(c, Node) for c in children):
            raise Exception("Children nodes must be of Node type")

        if self.is_binary() and len(children) <= 1:
            raise Exception("AND/OR nodes must have at least two children")
        if self.is_neg() and len(children) != 1:
            raise Exception("NEG nodes must have exactly one children")
        # A leaf has no child at all (the previous `> 1` test let one child through).
        if self.is_leaf() and len(children) > 0:
            raise Exception("FALSE/TRUE/VAR nodes must have no children")
        self.children = children

        if self.is_var():
            if var is None:
                raise Exception("LEAF nodes must have a variable")
            self.var = var
        else:
            # Always define the attribute so that code reading `node.var` on any
            # node (e.g. to_DNNF) does not hit an AttributeError.
            self.var = None

        self.vtree = None

    # @classmethod
    # def evidence2term(cls, evidence, manager=None):
    #     if manager is None:
    #         add = Node
    #     else:
    #         add = manager.add
        
    #     children=[]
    #     for key, value in evidence.items():
    #         if value==1:
    #             children.append(add(type=Node.Types.VAR, var=key))
    #         if value==-1:
    #             children.append(add(type=Node.Types.NEG, children=[add(type=Node.Types.VAR, var=key)]))

    #     return add(type=Node.Types.AND, children=children)

    @classmethod
    def from_dimacs(cls, filename):
        """
        Builds a CNF circuit from a DIMACS file.

        Blank lines and comment lines (starting with "c") are skipped, a "%" line
        ends the parsing, and the "p cnf <num_vars> <num_clauses>" header may come
        after leading comments. Every other line is read as one clause, with an
        optional terminating "0".

        Inputs:
            - filename : path of the DIMACS file.

        Outputs :
            - node : the conjunction of the clauses. Since AND/OR nodes need at least
              two children, a unit clause is represented by its literal, an empty clause
              by F, a single clause is returned as is, and a file without clauses gives T.

        Raises:
            - Exception : if the header is missing or comes after a clause, or if a literal
              refers to a variable outside 1..num_vars.
        """
        # NOTE(review): variables are created with the 1-based DIMACS index, while
        # accepts()/fuzzy() index tensors with `var` and repr() prints `var+1`;
        # confirm which numbering callers expect.
        lit_nodes = None
        num_vars = 0
        clauses = []
        with open(filename) as f:
            for line in f:
                tokens = line.split()
                if not tokens or tokens[0].startswith("c"):
                    continue
                if tokens[0] == "%":
                    break
                if tokens[0] == "p":
                    _, _, num_vars, _ = tokens
                    num_vars = int(num_vars)
                    # lit_nodes[2*(v-1)] is the literal v, lit_nodes[2*v-1] its negation.
                    lit_nodes = []
                    for v in range(1, num_vars + 1):
                        lit = Node(type=Node.Types.VAR, var=v)
                        lit_nodes.append(lit)
                        lit_nodes.append(Node(type=Node.Types.NEG, children=[lit]))
                    continue
                if lit_nodes is None:
                    raise Exception("The DIMACS header must come before the clauses")

                if tokens[-1] == "0":
                    tokens = tokens[:-1]
                children = []
                for token in tokens:
                    lit = int(token)
                    v = abs(lit)
                    # Guard the index: v == 0 would silently wrap to the end of lit_nodes.
                    if not 1 <= v <= num_vars:
                        raise Exception("Literal out of range in DIMACS clause : ", token)
                    children.append(lit_nodes[2*v-1] if lit < 0 else lit_nodes[2*(v-1)])

                if len(children) > 1:
                    clauses.append(Node(type=Node.Types.OR, children=children))
                elif len(children) == 1:
                    clauses.append(children[0])
                else:
                    clauses.append(Node(type=Node.Types.F))

        if lit_nodes is None:
            raise Exception("The DIMACS header is missing")
        if not clauses:
            return Node(type=Node.Types.T)
        if len(clauses) == 1:
            return clauses[0]
        return Node(type=Node.Types.AND, children=clauses)

    # TYPES
    def is_true(self):
        """Tells whether this node is the constant T."""
        return self.type is Node.Types.T
    def is_false(self):
        """Tells whether this node is the constant F."""
        return self.type is Node.Types.F
    def is_var(self):
        """Tells whether this node is a variable."""
        return self.type is Node.Types.VAR
    def is_neg(self):
        """Tells whether this node is a negation."""
        return self.type is Node.Types.NEG
    def is_or(self):
        """Tells whether this node is a disjunction."""
        return self.type is Node.Types.OR
    def is_and(self):
        """Tells whether this node is a conjunction."""
        return self.type is Node.Types.AND

    @caching
    def is_leaf(self):
        """Tells whether this node is a leaf, i.e. T, F or a variable."""
        return self.is_true() or self.is_false() or self.is_var()

    @caching
    def is_binary(self):
        """Tells whether this node is an OR or an AND."""
        return self.is_or() or self.is_and()

    @caching
    def is_lit(self):
        """
        Tells whether this node is a literal: a variable or the negation of a variable.

        Outputs :
            - b : always a bool (the negation of a non-variable used to fall
              through and give None).
        """
        if self.is_var():
            return True
        return self.is_neg() and self.children[0].type is Node.Types.VAR

    @caching
    def is_decision(self):
        """Tells whether this node is an AND with exactly two children whose first child is a literal."""
        if not self.is_and() or len(self.children) != 2:
            return False
        return bool(self.children[0].is_lit())
    
    @caching
    def is_leftneg(self):
        """Tells whether this node is an AND with exactly two children whose first child is a negated variable."""
        if not self.is_and() or len(self.children) != 2:
            return False
        first = self.children[0]
        # Compare against the enum member: the former `type == "NEG"` test
        # compared an enum to a string and could never succeed.
        return first.is_neg() and first.children[0].type is Node.Types.VAR

    @caching
    def is_clause(self):
        """Tells whether this node is a clause: a literal, or an OR whose children are all clauses."""
        if self.is_lit():
            return True
        if self.is_or():
            # Every child must be checked (the former loop returned after the first one).
            return all(c.is_clause() for c in self.children)
        return False

    @caching
    def is_term(self):
        """Tells whether this node is a term: a literal, or an AND whose children are all terms."""
        if self.is_lit():
            return True
        if self.is_and():
            # Every child must be checked (the former loop returned after the first one).
            return all(c.is_term() for c in self.children)
        return False

    @caching
    def is_NNF(self):
        """Tells whether the circuit is in negation normal form, i.e. negations only sit on variables."""
        if self.is_leaf() or self.is_lit():
            return True
        if not self.is_binary():
            # a negation that is not a literal
            return False
        return all(child.is_NNF() for child in self.children)

    @caching
    def is_DNNF(self):
        """
        Tells whether the circuit is a DNNF: an NNF whose children are DNNFs and
        where the children of every AND node mention pairwise disjoint variables.
        """
        if self.is_lit():
            return True
        if not self.is_NNF():
            return False

        # From here on the node is T, F, AND or OR inside an NNF.
        must_be_disjoint = self.is_and()
        seen = []
        for child in self.children:
            if not child.is_DNNF():
                return False
            if must_be_disjoint:
                for v in child.vars():
                    if v in seen:
                        return False
                    seen.append(v)
        return True

    @caching
    def is_OBDD(self, order=None):
        """
        Tells whether the circuit is an OBDD for the given variable order.

        An accepted circuit is a constant, a literal on the first variable of the
        order, or an OR of two decision nodes (v AND high) and (NEG v AND low)
        where v is the first variable of the order and high/low are OBDDs on the
        remaining variables.

        Inputs:
            - order : list of variables; when None, the variables are taken in the
              order in which iter() meets them.

        Outputs :
            - b : True if the circuit has the shape described above, False otherwise.
        """
        if order is None:
            order = []
            for n in self.iter():
                if n.type is Node.Types.VAR and n.var not in order:
                    order.append(n.var)

        # Constants fit any order, including an exhausted one; testing them before
        # reading order[0] avoids an IndexError under the last decision node.
        if self.is_true() or self.is_false():
            return True
        if not order:
            return False

        v, rest = order[0], order[1:]

        # Variables are tested on their own: they used to be caught by the
        # is_leaf() shortcut and were accepted whatever the order.
        if self.is_var():
            return self.var == v
        if self.is_neg():
            child = self.children[0]
            return child.type is Node.Types.VAR and child.var == v
        if self.is_and() or not self.is_NNF() or len(self.children) != 2:
            return False

        # What remains is an OR in NNF with two children.
        left, right = self.children[0], self.children[1]
        if not (left.is_decision() and right.is_decision()):
            return False

        pos, neg = left.children[0], right.children[0]
        if pos.type is not Node.Types.VAR or pos.var != v:
            return False
        if neg.type is not Node.Types.NEG:
            return False
        if neg.children[0].type is not Node.Types.VAR or neg.children[0].var != v:
            return False

        high, low = left.children[1], right.children[1]
        allowed = set(rest)
        if not set(high.vars()).issubset(allowed):
            return False
        if not set(low.vars()).issubset(allowed):
            return False
        return high.is_OBDD(rest) and low.is_OBDD(rest)

    # Attributes

    def iter(self,first_call=True, post=True):
        """
        Generator of nodes, post or pre order.

        Each node is yielded once even when it is shared by several parents: a
        visit mark is kept in cache["bit"] during the traversal.

        Inputs:
            - first_call : True for the outermost call only; that call resets the
              visit marks when the traversal ends.
            - post : True for post-order (children first), False for pre-order.
        """
        try:
            if not self.cache["bit"]:
                self.cache["bit"] = True

                if not post:
                    yield self

                for c in self.children:
                    yield from c.iter(first_call=False, post=post)

                if post:
                    yield self
        finally:
            # The finally clause also runs when the consumer stops early and the
            # generator is closed, so stale marks cannot hide nodes from the
            # next traversal.
            if first_call:
                self.clear_cache(key="bit", iterate=False)

    def clear_cache(self, key, iterate=False):
        """
        Resets the cache entry `key` to None.

        Inputs:
            - key : name of the cache entry.
            - iterate : if True, the entry is reset on every node yielded by iter();
              otherwise the reset goes down recursively from this node and stops at
              nodes where the entry is missing or already None.
        """
        if iterate:
            for node in self.iter():
                node.cache[key] = None
            return

        if self.cache.get(key) is None:
            # entry missing or already cleared: nothing to do below this node
            return

        self.cache[key] = None
        for child in self.children:
            child.clear_cache(key, iterate=False)

    @caching
    def vars(self):
        """Returns the list of distinct variables of the circuit, in order of first appearance among the children."""
        if self.is_var():
            return [self.var]
        if self.is_false() or self.is_true():
            return []

        collected = []
        for child in self.children:
            for v in child.vars():
                if v not in collected:
                    collected.append(v)
        return collected

    @caching
    def height(self):
        """Returns the length of the longest path from this node down to a leaf (0 for a leaf)."""
        if self.is_leaf():
            return 0
        return 1 + max(child.height() for child in self.children)

    @caching
    def size(self, first_call=True):
        """
        Returns the size of the circuit, counted as one unit per child link plus
        the size of each child (0 for a leaf).

        Inputs:
            - first_call : True for the outermost call only; it resets the "size"
              cache entries before computing.
        """
        if first_call:
            self.clear_cache("size")

        if self.is_leaf():
            return 0

        total = 0
        for child in self.children:
            total += child.size(first_call=False) + 1
        return total

    @caching
    def string(self):
        """Returns a prefix-notation text of the circuit, with LaTeX-like operator names."""
        if self.is_true():
            return "T"
        if self.is_false():
            return "F"
        if self.is_var():
            return str(self.var)
        if self.is_neg():
            return r"\neg ({})".format(self.children[0].string())

        if self.is_and():
            template = r"\land ({})"
        elif self.is_or():
            template = r"\lor ({})"
        else:
            return None
        return template.format(','.join(str(child.string()) for child in self.children))

    # TRANSFORM
    @caching
    def to_negNNF(self):
        """
        Returns a circuit in NNF equivalent to the negation of self.

        Negations are pushed down with De Morgan's laws: constants are swapped,
        a variable gets a NEG parent, AND/OR are exchanged and their children negated.
        """
        if self.is_true():
            return Node(type=Node.Types.F)
        elif self.is_false():
            return Node(type=Node.Types.T)
        elif self.is_var():
            return Node(type=Node.Types.NEG, children=[self])
        elif self.is_neg():
            # The negation of NEG(x) is x, but x may still hold inner negations,
            # so it is normalised too instead of being returned untouched.
            return self.children[0].to_NNF()
        elif self.is_and():
            return Node(type=Node.Types.OR, children=[c.to_negNNF() for c in self.children])
        elif self.is_or():
            return Node(type=Node.Types.AND, children=[c.to_negNNF() for c in self.children])
    
    @caching
    def to_NNF(self):
        """Returns an equivalent circuit in NNF; self is returned when it is already in NNF."""
        if self.is_NNF():
            return self
        if self.is_neg():
            # push the negation down into the child
            return self.children[0].to_negNNF()
        if self.is_and() or self.is_or():
            converted = [child.to_NNF() for child in self.children]
            return Node(type=self.type, children=converted)

    @caching
    def to_DNNF(self):
        """
        Rebuilds the circuit with DNNF nodes (DNNF is defined outside this class).

        Outputs :
            - node : a DNNF node with the same type, id and variable, whose children are converted recursively.

        Raises:
            - Exception : if the circuit is not a DNNF.
        """
        if not self.is_DNNF():
            raise Exception("Must be a DNNF !")

        # Only VAR nodes are guaranteed to carry a `var` attribute; reading it
        # directly raised AttributeError on every other node.
        var = getattr(self, "var", None)
        if self.is_leaf():
            children = self.children
        else:
            children = [c.to_DNNF() for c in self.children]
        return DNNF(type=self.type, children=children, id=self.id, var=var)

    @caching
    def condition(self, evidence, first_call=True):
        """
        Conditions the circuit on evidence.

        Inputs:
            - evidence : a dictionary that maps each variable to either 1, 0, -1 if it is true, unspecified or false respectively.
            - first_call : True for the outermost call only; it resets the "condition" cache entries before computing.

        Outputs :
            - node : a circuit node that is equivalent to self | evidence. The result is not simplified (see trim).
        """
        # Same pattern as size(): the reset used to sit after the returns, where
        # it could only be reached by the branch that returned nothing.
        # NOTE(review): how this interacts with @caching is not visible here.
        if first_call:
            self.clear_cache("condition")

        if self.is_true() or self.is_false():
            return self

        if self.is_var():
            value = evidence[self.var]
            if value == 1:
                return Node(type=Node.Types.T)
            elif value == -1:
                return Node(type=Node.Types.F)
            return self

        if self.is_lit():
            # negated variable: the constants are swapped
            value = evidence[self.children[0].var]
            if value == 1:
                return Node(type=Node.Types.F)
            elif value == -1:
                return Node(type=Node.Types.T)
            # unspecified variable: the literal is kept (this case used to return None)
            return self

        new_children = [c.condition(evidence, first_call=False) for c in self.children]
        return Node(type=self.type, children=new_children)


    # COMPUTE QUERIES

    @caching
    def accepts(self, states):
        """
        Evaluates the circuit on a batch of complete states.

        Inputs:
            - states : a torch tensor of states of shape (batch_size x num_variables); a nonzero entry means the variable is true.

        Outputs :
            - a : a torch boolean tensor of shape (batch_size) telling which states satisfy the circuit.
        """
        #TODO:Add evidence
        bs = states.shape[0]

        if self.is_true():
            # one value per state, on the same device as the input
            return torch.ones(bs, dtype=torch.bool, device=states.device)
        elif self.is_false():
            return torch.zeros(bs, dtype=torch.bool, device=states.device)
        elif self.is_var():
            # cast so that every branch returns booleans that can be stacked together
            return states[:, self.var].bool()
        elif self.is_neg():
            return torch.logical_not(self.children[0].accepts(states))
        elif self.is_and():
            return torch.all(torch.stack([c.accepts(states) for c in self.children], dim=1), dim=1)
        elif self.is_or():
            return torch.any(torch.stack([c.accepts(states) for c in self.children], dim=1), dim=1)

    @caching
    def sat(self):
        """
        Tells whether the circuit is satisfiable.

        Outputs :
            - b : True if some assignment satisfies the circuit.

        Raises:
            - Exception : if the node is not a leaf and the circuit is not a DNNF.
        """
        #TODO:Add evidence
        if self.is_true() or self.is_var():
            return True
        if self.is_false():
            return False

        if not self.is_DNNF():
            raise Exception("Only implemented for DNNFs !")

        if self.is_neg():
            # Test the enum member: the former comparison with the string "TRUE"
            # was true for every child.
            return not self.children[0].is_true()
        if self.is_or():
            return any(c.sat() for c in self.children)
        # AND node: the DNNF check above guarantees that the children share no
        # variable, so they can be satisfied independently.
        return all(c.sat() for c in self.children)

    @caching
    def fuzzy(self, probs):
        """
        Performs fuzzy evaluation.

        Constants give 1 and 0, a variable gives its probability, NEG gives
        1 - x, AND the product of its children and OR the sum of its children.

        Inputs:
            - probs : a torch tensor of probabilities of shape (batch_size x num_variables)
        
        Outputs :
            - f : a torch tensor of shape (batch_size) representing for the fuzzy score of the circuit under probs.
        """
        bs = probs.shape[0]

        if self.is_true():
            # new_ones/new_zeros keep the dtype and device of probs, so the
            # constants can be stacked with the other branches (e.g. on GPU).
            return probs.new_ones(bs)
        elif self.is_false():
            return probs.new_zeros(bs)
        elif self.is_var():
            return probs[:, self.var]
        elif self.is_neg():
            return 1 - self.children[0].fuzzy(probs)
        elif self.is_and():
            return torch.prod(torch.stack([c.fuzzy(probs) for c in self.children], dim=1), dim=1)
        elif self.is_or():
            # NOTE(review): a plain sum can exceed 1 when the children overlap; confirm this is intended.
            return torch.sum(torch.stack([c.fuzzy(probs) for c in self.children], dim=1), dim=1)

    @caching
    def trim(self, F=None, T=None):
        """
        Simplifies the circuit by propagating the constants T and F.

        Inputs:
            - F : the node returned whenever a sub-circuit reduces to false (a fresh F node by default).
            - T : the node returned whenever a sub-circuit reduces to true (a fresh T node by default).

        Outputs :
            - node : an equivalent circuit where constants only remain if the whole circuit is constant.
              An AND/OR left with one child is replaced by that child.

        Raises:
            - Exception : if the node type is not handled.
        """
        if F is None:
            F = Node(type=Node.Types.F)
        if T is None:
            T = Node(type=Node.Types.T)

        if self.is_leaf():
            return self

        if self.is_neg():
            # Trim the child first: it may only become constant after trimming,
            # which the former test on the untrimmed child missed.
            child = self.children[0].trim(F, T)
            if child.type is Node.Types.T:
                return F
            elif child.type is Node.Types.F:
                return T
            return Node(type=self.type, children=[child])

        if self.is_and():
            absorbing_type, absorbing = Node.Types.F, F
            neutral_type, neutral = Node.Types.T, T
        elif self.is_or():
            absorbing_type, absorbing = Node.Types.T, T
            neutral_type, neutral = Node.Types.F, F
        else:
            raise Exception("The type of the node must be in : ", list(Node.Types))

        children = [c.trim(F, T) for c in self.children]
        # One absorbing child decides the whole node.
        if any(c.type is absorbing_type for c in children):
            return absorbing

        # Neutral children can be dropped.
        children = [c for c in children if c.type is not neutral_type]
        if len(children) > 1:
            return Node(type=self.type, children=children)
        elif len(children) == 1:
            return children[0]
        return neutral

    def set_ids(self, post=True):
        """Numbers the nodes from 0 in the order given by iter(post=post)."""
        next_id = 0
        for node in self.iter(post=post):
            node.id = next_id
            next_id += 1

    def repr(self, ids=True, children=None):
        """
        Returns a one-line text description of the node: a letter for the type,
        then the variable (plus one) or the ids of the children, then the node id.

        Inputs:
            - ids : if False, the last field (the id of the node) is dropped.
            - children : nodes to use in place of self.children (defaults to self.children).
        """
        kids = self.children if children is None else children

        if self.is_false():
            st = 'F %d' % self.id
        elif self.is_true():
            st = 'T %d' % self.id
        elif self.is_var():
            st = 'V %d %d' % (self.var+1, self.id)
        elif self.is_neg():
            st = 'N %d %d' % (kids[0].id, self.id)
        elif self.is_or():
            st = 'O %s %d' % (" ".join('%d' % k.id for k in kids), self.id)
        elif self.is_and():
            st = 'A %s %d' % (" ".join('%d' % k.id for k in kids), self.id)

        if not ids:
            # drop the trailing id field
            return " ".join(st.split()[:-1])
        return st

    def reduction(self):
        """
        Merges structurally identical sub-circuits.

        Nodes are visited children first; each node gets a `unique` attribute
        pointing to the representative of its structure, where the structure is
        the id-free repr() computed on the representatives of its children.

        Outputs :
            - node : the representative of this node.
        """
        representatives = {}
        self.set_ids()
        # The traversal is materialised first because nodes are created in the loop.
        for node in list(self.iter()):
            if node.is_leaf():
                merged = None
                signature = node.repr(ids=False)
            else:
                merged = [c.unique for c in node.children]
                signature = node.repr(ids=False, children=merged)

            if signature in representatives:
                node.unique = representatives[signature]
                continue

            if merged is None:
                node.unique = node
            else:
                node.unique = Node(type=node.type, id=node.id, children=merged)
            representatives[signature] = node.unique

        return self.unique


    def circuit(self):
        """Placeholder that is not implemented yet: it always returns None."""
        #TODO
        return None

    def squeeze(self):
        """
        Flattens nested nodes of the same type: the children of an AND (resp. OR)
        child of an AND (resp. OR) node are attached directly to the parent.

        Outputs :
            - node : a new AND/OR node with the flattened children; a node that is
              neither AND nor OR is returned unchanged. Children of another type
              are kept as they are.
        """
        if not self.is_binary():
            # Nothing to flatten (this case used to fail on an unbound local).
            return self

        new_children = []
        for c in self.children:
            if c.type == self.type:
                new_children.extend(c.squeeze().children)
            else:
                new_children.append(c)

        return Node(type=self.type, children=new_children)

    def primal_graph(self):
        """
        Builds the primal graph of the circuit: one vertex per variable, and an
        edge between two variables that appear in the same child of the
        (squeezed) top AND node. If the node is not an AND, all its variables
        are linked together.

        Outputs :
            - G : a networkx undirected graph.
        """
        import networkx as nx
        graph = nx.Graph()
        graph.add_nodes_from(self.vars())

        if self.is_and():
            scopes = [child.vars() for child in self.squeeze().children]
        else:
            scopes = [self.vars()]

        # every scope becomes a clique
        for scope in scopes:
            graph.add_edges_from([(u, v) for u in scope for v in scope if u != v])

        return graph

