import sys
sys.path.append("..")

import torch
import networkx as nx

from circuits.node import Node
from circuits.utils import caching, logsumexp, logsum, log1mexp, ind_topk

class DNNF(Node):
    """Decomposable Negation Normal Form circuit node.

    Every child must itself be a ``DNNF`` and the constructed node must pass
    ``Node.is_DNNF()``; both are checked in ``__init__``.
    """

    def __init__(self, type, children=None, id=None, var=None):
        """Build a DNNF node and validate it.

        Args:
            type: node type, forwarded to ``Node.__init__``.
            children: optional iterable of child ``DNNF`` nodes.
            id: optional node identifier, forwarded to ``Node.__init__``.
            var: variable index for variable leaves, forwarded to ``Node.__init__``.

        Raises:
            TypeError: if a child is not a ``DNNF``.
            ValueError: if the resulting node is not a DNNF.
            (Both subclass ``Exception``, so existing ``except Exception``
            handlers keep working.)
        """
        if children:
            for c in children:
                if not isinstance(c, DNNF):
                    raise TypeError("Children must be DNNFs !")

        super().__init__(type=type, children=children, id=id, var=var)

        if not self.is_DNNF():
            raise ValueError("Must be a DNNF !")

    @classmethod
    def from_node(cls, node):
        """Convert a generic ``Node`` and, recursively, its sub-circuit to ``DNNF``.

        Children that are already ``DNNF`` instances are reused unchanged.
        Only variable leaves carry ``var`` over.
        """
        if node.is_leaf():
            if node.is_var():
                return DNNF(type=node.type, children=node.children, id=node.id, var=node.var)
            return DNNF(type=node.type, children=node.children, id=node.id)
        children = [c if isinstance(c, DNNF) else DNNF.from_node(c) for c in node.children]
        return DNNF(type=node.type, children=children, id=node.id)

    @caching
    def trim(self, F=None, T=None):
        """Return a simplified circuit with constant sub-circuits replaced.

        An unsatisfiable circuit (``not self.sat()``) becomes ``F``. A circuit
        whose negation (``to_negNNF()``) is a DNNF and unsatisfiable becomes
        ``T``. Otherwise a new ``Node`` of the same type is built from the
        trimmed children. The result is passed through ``trim(F, T)`` once
        more before being returned. Memoisation is delegated to the
        ``caching`` decorator (semantics defined elsewhere — confirm).

        Args:
            F: node used for false sub-circuits; a fresh ``Node.Types.F`` by default.
            T: node used for true sub-circuits; a fresh ``Node.Types.T`` by default.
        """
        if F is None:
            F = Node(type=Node.Types.F)
        if T is None:
            T = Node(type=Node.Types.T)

        if not self.sat():
            node = F
        else:
            negation = self.to_negNNF()
            # The is_DNNF() guard presumably exists because sat() is only
            # trusted on DNNFs — confirm.
            if negation.is_DNNF() and not negation.sat():
                node = T
            else:
                # NOTE(review): for a leaf that is neither F nor T this iterates
                # self.children and drops self.var — confirm that is intended.
                node = Node(type=self.type, children=[c.trim(F, T) for c in self.children])

        return node.trim(F, T)

    @staticmethod
    def _mpe_and(node):
        """Merge the cached ``"mpe"`` results of an AND node's children.

        States are summed (children of a decomposable AND are expected to fix
        disjoint variables) and probabilities are multiplied.
        """
        statesc = torch.stack([c.cache["mpe"][0] for c in node.children], dim=2)
        pc = torch.stack([c.cache["mpe"][1] for c in node.children], dim=1)
        return torch.sum(statesc, dim=2), torch.prod(pc, dim=1)

    @staticmethod
    def _mpe_or(children, mask, default, literals):
        """Choose, per batch row, the most probable completion among OR children.

        For each child, columns selected by ``mask`` that the child leaves at 0
        (unassigned) are filled with ``default`` (the independently most likely
        value). The child's probability is multiplied by the matching
        ``literals`` probabilities. The candidate with the largest probability
        wins.

        Returns:
            ``(states, p)`` of the winning candidate for every batch row.
        """
        cand_states, cand_p = [], []
        for c in children:
            c_states, c_p = c.cache["mpe"]
            free = torch.logical_and(mask, c_states == 0)
            cand_states.append(torch.where(free, default, c_states))
            cand_p.append(torch.mul(c_p, torch.prod(torch.where(free, literals, 1), dim=1)))

        statesc = torch.stack(cand_states, dim=2)
        pc = torch.stack(cand_p, dim=1)

        p, idx = torch.max(pc, dim=1)
        idx = idx.unsqueeze(1).unsqueeze(2).expand(-1, statesc.shape[1], 1)
        states = torch.gather(statesc, dim=2, index=idx).squeeze(2)
        return states, p

    def mpe(self, probs, first_call=True, boolean=False):
        """Most probable assignment via a max-product pass over the circuit.

        Args:
            probs: tensor indexed as ``probs[:, var]``; each row holds the
                probability of every variable being true (assumed shape
                ``(batch, n_vars)`` — TODO confirm).
            first_call: if True, the per-node ``"mpe"`` cache is cleared before returning.
            boolean: if True, states are returned as the boolean ``states >= 0``.

        Returns:
            ``(states, p)``. ``states`` has one column per variable, holding 1
            (true), -1 (false) or 0 (left unassigned). ``p`` holds one
            probability per batch row.
        """
        bs = probs.shape[0]
        default = torch.where(probs.ge(0.5), 1, -1)
        literals = torch.where(probs.ge(0.5), probs, 1 - probs)

        # Each node reads its children's cache, so self.iter() is presumably
        # children-before-parents — confirm.
        for node in self.iter():
            if node.is_true():
                p = torch.ones(bs, dtype=torch.float64)
                states = torch.zeros_like(probs, dtype=torch.int8)
            elif node.is_false():
                p = torch.zeros(bs, dtype=torch.float64)
                states = torch.zeros_like(probs, dtype=torch.int8)
            elif node.is_var():
                states = torch.zeros_like(probs, dtype=torch.int8)
                states[:, node.var] = 1
                p = probs[:, node.var]
            elif node.is_neg():
                statesc, pc = node.children[0].cache["mpe"]
                states, p = statesc.mul(-1), 1 - pc
            elif node.is_and():
                states, p = self._mpe_and(node)
            elif node.is_or():
                live = [c for c in node.children if not c.is_false()]
                if live:
                    mask = torch.zeros_like(probs, dtype=torch.bool)
                    cols = list(node.vars())
                    if cols:
                        mask[:, cols] = True
                    states, p = self._mpe_or(live, mask, default, literals)
                else:
                    # Every child is False, so the disjunction is False as well.
                    p = torch.zeros(bs, dtype=torch.float64)
                    states = torch.zeros_like(probs, dtype=torch.int8)

            node.cache["mpe"] = (states, p)

        states, p = self.cache["mpe"]

        if boolean:
            states = states.ge(0)

        if first_call:
            self.clear_cache("mpe")
        return states, p

    def mar_mpe(self, probs, mar, first_call=True, boolean=False):
        """Like ``mpe``, but only the variables listed in ``mar`` are tracked.

        Leaves of variables outside ``mar`` contribute probability 1 and no
        state (presumably they are marginalised out — confirm).

        Args:
            probs: tensor indexed as ``probs[:, var]`` (assumed
                ``(batch, n_vars)`` — TODO confirm).
            mar: iterable of integer variable indices to keep. Column ``j`` of
                the returned states corresponds to ``mar[j]``.
            first_call: if True, the per-node ``"mpe"`` cache is cleared before returning.
            boolean: if True, states are returned as the boolean ``states >= 0``.

        Returns:
            ``(states, p)``, where ``states`` has ``len(mar)`` columns in ``mar`` order.
        """
        mar = [int(v) for v in mar]
        bs = probs.shape[0]
        m = len(mar)
        var2mar = {v: j for j, v in enumerate(mar)}
        # Child states live in mar-column space, so the per-variable
        # defaults/literals must be restricted and reordered the same way.
        mar_probs = probs[:, mar]
        default = torch.where(mar_probs.ge(0.5), 1, -1)
        literals = torch.where(mar_probs.ge(0.5), mar_probs, 1 - mar_probs)

        for node in self.iter():
            if node.is_true():
                p = torch.ones(bs, dtype=torch.float64)
                states = torch.zeros((bs, m), dtype=torch.int8)
            elif node.is_false():
                p = torch.zeros(bs, dtype=torch.float64)
                states = torch.zeros((bs, m), dtype=torch.int8)
            elif node.is_var():
                states = torch.zeros((bs, m), dtype=torch.int8)
                if node.var in var2mar:
                    states[:, var2mar[node.var]] = 1
                    p = probs[:, node.var]
                else:
                    p = torch.ones(bs, dtype=torch.float64)
            elif node.is_neg():
                statesc, pc = node.children[0].cache["mpe"]
                if node.children[0].var in var2mar:
                    states, p = statesc.mul(-1), 1 - pc
                else:
                    states, p = statesc, pc
            elif node.is_and():
                states, p = self._mpe_and(node)
            elif node.is_or():
                live = [c for c in node.children if not c.is_false()]
                if live:
                    mask = torch.zeros((bs, m), dtype=torch.bool)
                    cols = sorted({var2mar[v] for v in node.vars() if v in var2mar})
                    if cols:
                        mask[:, cols] = True
                    states, p = self._mpe_or(live, mask, default, literals)
                else:
                    # Every child is False, so the disjunction is False as well.
                    p = torch.zeros(bs, dtype=torch.float64)
                    states = torch.zeros((bs, m), dtype=torch.int8)

            node.cache["mpe"] = (states, p)

        states, p = self.cache["mpe"]

        if boolean:
            states = states.ge(0)

        if first_call:
            self.clear_cache("mpe")
        return states, p

    @staticmethod
    def _topk_product(states_a, p_a, states_b, p_b):
        """Pair every candidate of ``(states_a, p_a)`` with every one of ``(states_b, p_b)``.

        States carry candidates on dim 2 and probabilities on dim 1. The pair
        ``(i, j)`` is placed at index ``i * K_b + j``. Paired states are added
        and paired probabilities multiplied.
        """
        states = torch.add(
            states_a.repeat_interleave(states_b.shape[2], dim=2),
            states_b.tile((1, 1, states_a.shape[2])),
        )
        p = torch.mul(
            p_a.repeat_interleave(p_b.shape[1], dim=1),
            p_b.tile((1, p_a.shape[1])),
        )
        return states, p

    def topk(self, probs, k=2, first_call=True, boolean=False):
        """Keep up to ``k`` highest-probability candidate assignments per node.

        Args:
            probs: 2-D tensor ``(bs, n)`` of per-variable probabilities of being true.
            k: maximum number of candidates kept at each node.
            first_call: if True, the per-node ``"topk"`` cache is cleared before returning.
            boolean: if True, states are returned as the boolean ``states >= 0``.

        Returns:
            ``(states, p)`` with ``states`` of shape ``(bs, n, K)`` (values
            1 / -1 / 0) and ``p`` of shape ``(bs, K)``, where ``K <= k`` once
            truncation applies.
        """
        bs, n = probs.shape

        for node in self.iter():
            if node.is_true():
                p = torch.ones((bs, 1), dtype=torch.float64)
                states = torch.zeros_like(probs, dtype=torch.int8).unsqueeze(-1)
            elif node.is_false():
                p = torch.zeros((bs, 1), dtype=torch.float64)
                states = torch.zeros_like(probs, dtype=torch.int8).unsqueeze(-1)
            elif node.is_var():
                states = torch.zeros_like(probs, dtype=torch.int8).unsqueeze(-1)
                states[:, node.var, :] = 1
                p = probs[:, node.var].unsqueeze(-1)
            elif node.is_neg():
                statesc, pc = node.children[0].cache["topk"]
                states, p = statesc.mul(-1), 1 - pc
            else:
                if node.is_and():
                    statesc, pc = node.children[0].cache["topk"]
                    for c in node.children[1:]:
                        statesn, pn = c.cache["topk"]
                        statesc, pc = self._topk_product(statesc, pc, statesn, pn)

                elif node.is_or():
                    # TODO: delete non-unique states in statesc
                    node_vars = set(node.vars())
                    cand_states, cand_p = [], []
                    for c in node.children:
                        c_states, c_p = c.cache["topk"]
                        # Variables this child does not mention get their own
                        # independent top-k completions.
                        tofix = list(node_vars - set(c.vars()))
                        if tofix:
                            fix_states, fix_p = ind_topk(probs, k=k, variables=tofix)
                            c_states, c_p = self._topk_product(c_states, c_p, fix_states, fix_p)
                        cand_states.append(c_states)
                        cand_p.append(c_p)
                    statesc = torch.cat(cand_states, dim=2)
                    pc = torch.cat(cand_p, dim=1)

                if pc.shape[1] > k:
                    p, idx = torch.topk(pc, k=k, dim=1)
                    idx = idx.unsqueeze(1).expand(-1, n, k)
                    states = torch.gather(statesc, dim=2, index=idx)
                else:
                    p, states = pc, statesc

            node.cache["topk"] = (states, p)

        states, p = self.cache["topk"]

        if boolean:
            states = states.ge(0)

        if first_call:
            self.clear_cache("topk")
        return states, p