import numpy as np
import networkx as nx
from itertools import compress
import copy

from .utils import *

class JunctionTree():
    """Junction tree built from an undirected adjacency given as two matrices.

    The graph adjacency is ``Ehs + Ehs.T + Ees``. It is triangulated with a
    min-neighbour elimination order. Its maximal cliques are connected by a
    spanning tree (via ``kruskal`` from ``.utils``) whose edge weights are
    clique-intersection sizes.

    Attributes set by the constructor:
        numV:         number of variables (inferred from ``Ehs`` if not given).
        widthJT:      largest neighbourhood size seen during elimination.
        cliques:      list of maximal cliques (lists of variable indices).
        assignedVar:  for each clique, the variables assigned to it (each
                      variable goes to the first clique that contains it).
        numC:         number of cliques.
        cliqParents:  parent clique index per clique (-1 for the root, clique 0).
        cliqChildren: list of child clique indices per clique.
        order:        clique visit order as returned by ``visitTree``, reversed.
        cliqNei:      per clique, its children followed by its parent (if any).
    """

    def __init__(self, Ehs, Ees, numV=None):
        # Previously numV=None crashed later in minNei; infer it from the matrix.
        if numV is None:
            numV = np.shape(Ehs)[0]

        self.numV = numV
        self.widthJT = numV

        self.cliques = None
        self.assignedVar = None
        self.buildJunctionGraph(Ehs, Ees)

        self.numC = len(self.cliques)

        self.cliqParents = None
        self.cliqChildren = None
        self.cliqNei = None
        self.order = None
        self.buildJunctionTree()

    def minNei(self, Ea):
        """Return an elimination sequence chosen by the min-neighbour heuristic.

        At each step, every unvisited node whose current degree (column sum of
        ``Ea``) equals the minimum over unvisited nodes is eliminated. Tied
        nodes are appended in ascending index order.

        Note: no fill-in edges are added. Eliminated nodes are simply
        disconnected.

        Args:
            Ea: square adjacency matrix. It is MODIFIED IN PLACE: the rows and
                columns of eliminated nodes are cleared. Pass a copy.

        Returns:
            list of node indices, of length ``self.numV``.
        """
        unvisited = np.ones(self.numV, dtype=bool)
        elimSeq = []
        while len(elimSeq) < self.numV:
            numNei = np.sum(Ea, axis=0)  # current number of neighbours per node
            minNode = np.min(numNei[unvisited])
            # Unvisited nodes that attain the minimum degree.
            minMask = np.logical_and(numNei == minNode, unvisited)
            # Disconnect them from the graph.
            Ea[minMask, :] = False
            Ea[:, minMask] = False
            elimSeq += np.where(minMask)[0].tolist()
            unvisited = np.logical_and(unvisited, np.logical_not(minMask))

        return elimSeq

    def minFill(self, Ea):
        """Return an elimination sequence chosen by the min-fill heuristic.

        At each step, the unvisited node whose elimination would add the fewest
        edges among its current neighbours is chosen (ties go to the lowest
        index). That node is then eliminated: its neighbours are connected
        pairwise and the node is disconnected.

        Args:
            Ea: square adjacency matrix (nonzero = edge). It is not modified; a
                boolean copy is used internally.

        Returns:
            list of node indices, a permutation of ``range(self.numV)``.
        """
        adj = np.array(Ea, dtype=bool)  # private copy: the caller's matrix is untouched
        np.fill_diagonal(adj, False)  # self-loops are irrelevant to fill-in
        unvisited = np.ones(self.numV, dtype=bool)
        elimSeq = []
        for _ in range(self.numV):
            best = -1
            bestFill = None
            for i in np.flatnonzero(unvisited):
                nei = np.flatnonzero(adj[i])
                k = len(nei)
                # Ordered pairs of distinct neighbours minus existing ordered
                # edges among them: missing edges (counted twice when adj is
                # symmetric, which does not affect the comparison).
                nbFill = k * (k - 1) - int(np.count_nonzero(adj[np.ix_(nei, nei)]))
                if bestFill is None or nbFill < bestFill:
                    best = int(i)
                    bestFill = nbFill

            # Eliminate `best`: make its neighbourhood a clique, then disconnect it.
            nei = np.flatnonzero(adj[best])
            adj[np.ix_(nei, nei)] = np.logical_not(np.eye(len(nei), dtype=bool))
            adj[best, :] = False
            adj[:, best] = False

            elimSeq.append(best)
            unvisited[best] = False

        return elimSeq

    def buildJunctionGraph(self, Ehs, Ees):
        """Triangulate the graph and record its maximal cliques.

        Builds the adjacency ``Ehs + Ehs.T + Ees`` as a boolean matrix with an
        empty diagonal. Variables are eliminated in ``minNei`` order; each
        elimination records the clique (neighbours + variable) and connects
        the neighbours pairwise. The recorded cliques are reduced with
        ``findMaxCliques`` (from ``.utils``; presumably it keeps only the
        maximal ones, so confirm there). Each variable is then assigned to the
        first maximal clique that contains it.

        Sets ``self.cliques``, ``self.numC``, ``self.assignedVar`` and
        ``self.widthJT``.

        Returns:
            (maxCliques, assignedVar)
        """
        # Boolean adjacency: overlapping entries of Ehs/Ehs.T/Ees count once
        # in the degree sums used by minNei.
        Ea = np.asarray(Ehs + np.transpose(Ehs) + Ees, dtype=bool)
        # A self-loop would make v its own neighbour and duplicate it in its clique.
        np.fill_diagonal(Ea, False)

        elimSeq = self.minNei(Ea.copy())

        # Eliminate nodes to get the elimination cliques.
        cliques = []
        widthJT = 0
        for v in elimSeq:
            vNei = np.where(Ea[v])[0].tolist()  # current neighbours of v
            numVN = len(vNei)

            cliques.append(vNei + [v])  # neighbourhood plus v

            widthJT = max(widthJT, numVN)
            # Complete the clique among v's neighbours (fill-in), then remove v.
            Ea[np.ix_(vNei, vNei)] = np.logical_not(np.eye(numVN, dtype=bool))
            Ea[:, v] = False
            Ea[v, :] = False

        maxCliques = findMaxCliques(cliques)
        nMC = len(maxCliques)

        # For each variable, the indices of the maximal cliques containing it.
        var2cliq = [[] for _ in range(self.numV)]
        for ci, c in enumerate(maxCliques):
            for v in c:
                var2cliq[v].append(ci)

        # Assign each variable to the first clique that contains it.
        assignedVar = [[] for _ in range(nMC)]
        for var, cliqs in enumerate(var2cliq):
            assignedVar[cliqs[0]].append(var)

        self.cliques = maxCliques
        self.numC = nMC
        self.assignedVar = assignedVar
        self.widthJT = widthJT

        return maxCliques, assignedVar

    def buildJunctionTree(self):
        """Connect the cliques into a tree and record its parent/child structure.

        Every clique pair becomes a candidate edge ``(c1, c2, |c1 & c2|)``.
        These go to ``kruskal`` (from ``.utils``; presumably a maximum-weight
        spanning tree, so confirm there). The tree is traversed from clique 0
        with ``visitTree``.

        Sets ``self.cliqParents``, ``self.cliqChildren``, ``self.order`` and
        ``self.cliqNei``.

        Returns:
            (cliqParents, cliqChildren, order)
        """
        # Build all candidate edges in one pass instead of re-stacking the array
        # on every pair (which recopied it each time).
        cliqSets = [set(c) for c in self.cliques]
        rows = [
            (c1, c2, len(cliqSets[c1] & cliqSets[c2]))
            for c1 in range(self.numC - 1)
            for c2 in range(c1 + 1, self.numC)
        ]
        edgesJG = np.array(rows, dtype=np.int64).reshape(-1, 3)

        Ej, _ = kruskal(edgesJG, self.numC)

        # Parent clique index per clique; -1 marks the root.
        cliqParents = -1 * np.ones(self.numC, dtype=np.int64)
        # Child clique indices per clique (empty for leaves).
        childvar2cliq = [[] for _ in range(self.numC)]
        order = []

        # Clique 0 is used as the root.
        cliqParents, cliqChildren, order = visitTree(Ej, 0, -1, cliqParents, childvar2cliq, order)
        assert len(order) == self.numC
        order.reverse()

        self.cliqParents = cliqParents
        self.cliqChildren = cliqChildren
        self.order = order

        self.cliqNei = [self.getCliqNei(i) for i in range(self.numC)]

        return cliqParents, cliqChildren, order

    def checkRIP(self):
        """Check the running intersection property along ``reversed(self.order)``.

        For every clique after the first, its intersection with the union of
        all previously visited cliques must equal its intersection with some
        single previously visited clique.

        Returns:
            True if the property holds, False otherwise.
        """
        history = set()  # union of the variables of visited cliques
        visited = []
        for i in reversed(self.order):
            c1 = set(self.cliques[i])
            if visited:
                histInter = c1 & history
                if not any(histInter == (c1 & set(self.cliques[j])) for j in visited):
                    return False
            visited.append(i)
            history.update(c1)

        return True

    def getCliqNei(self, i):
        """Return a new list of clique ``i``'s neighbours: its children, then its parent.

        The root (parent -1) gets only its children. A fresh list is returned
        so the result never aliases ``self.cliqChildren[i]``.
        """
        cliqNei = list(self.cliqChildren[i])
        if self.cliqParents[i] != -1:
            cliqNei.append(self.cliqParents[i])
        return cliqNei