import numpy as np
import networkx as nx
from itertools import compress
import copy

from .junctionTree import JunctionTree
from .utils import *

# References for this class can be found at:
# https://github.com/ronghanghu/mhex_graph
# https://github.com/kylemin/HEX-graph
# https://github.com/yuweitu/MultiLabel_Classification


def compute_exclusive(Ehs, Ehd):
    """
    Compute the sparse and dense exclusion matrices implied by a hierarchy,
    assuming classes are mutually exclusive whenever possible.

    Two distinct classes i and j are (densely) exclusive when neither is an
    ancestor of the other and they share no descendant. An exclusion (i, j) is
    kept in the sparse matrix only when it is not implied by another exclusion,
    i.e. when no parent of i is exclusive with j and i is not exclusive with
    any parent of j.

    Inputs:
        - Ehs : sparse hierarchy matrix, Ehs[p, i] is True if p is a parent of i.
        - Ehd : dense hierarchy matrix, Ehd[a, i] is True if a is an ancestor of i.

    Outputs:
        - (Ees, Eed) : symmetric boolean (numV, numV) sparse and dense exclusion
          matrices.
    """
    Ehs = np.asarray(Ehs, dtype=bool)
    Ehd = np.asarray(Ehd, dtype=bool)

    # share_desc[i, j]: i and j have at least one common (strict) descendant.
    desc = Ehd.astype(np.intp)
    share_desc = (desc @ desc.T) > 0
    # related[i, j]: one of i, j is an ancestor of the other.
    related = Ehd | Ehd.T

    Eed = ~(share_desc | related)
    np.fill_diagonal(Eed, False)

    # If an ancestor a of i excludes j, then the parent of i lying on the path
    # to a also excludes j (its descendants are a subset of a's). Checking
    # direct parents is therefore enough to detect every implied exclusion.
    parents = Ehs.astype(np.intp)
    dense = Eed.astype(np.intp)
    implied = ((parents.T @ dense) > 0) | ((dense @ parents) > 0)

    Ees = Eed & ~implied
    return Ees, Eed


class HEXGraph(object):
    """
    Hierarchy and EXclusion (HEX) graph over ``numV`` classes.

    The graph is stored as four boolean adjacency matrices:
        - Ehs / Ehd : sparse and dense hierarchy matrices, [i, j] is True if
          class j is included in class i (Ehd also holds the inclusions
          obtained by transitivity).
        - Ees / Eed : sparse and dense symmetric exclusion matrices, [i, j] is
          True if classes i and j are mutually exclusive (Eed also holds the
          exclusions implied through the hierarchy).
    """

    def __init__(self, Eh=None, Ee=None, T=None, Ls=None, L=None, file_path=None, nodenames=None, assume_exclusive=True):
        """
        Initialize the HEXGraph given the hierarchical and exclusion matrices
        (or a directed tree, a lattice, or files previously written by ``save``).

        The first available source is used, in this order: (Eh, Ee), T, Ls,
        file_path.

        Inputs:
            - Eh : hierarchical matrix, Eh[i, j] = 1 if j is included in i.
            - Ee : exclusion matrix, Ee[i, j] = 1 if i and j are mutually exclusive.
            - T : directed graph describing the tree hierarchy of classes,
                ie. there's an edge (i, j) if j is included in i.
                (classes are assumed mutually exclusive whenever possible)
            - Ls : directed graph describing the lattice hierarchy of classes,
                ie. there's an edge (i, j) if j is included in i.
                (classes are assumed mutually exclusive whenever possible)
            - L : a larger lattice (Ls must be included into L) to extract exclusion relations from.
            - file_path : path prefix of the ".Ehs", ".Ees", ".Ehd" and ".Eed" files to load.
            - nodenames : names given to the nodes representing classes.
            - assume_exclusive : only used with Ls, see ``initFromLattice``.

        Raises:
            - ValueError if Eh / Ee have inconsistent sizes or if no input is given.
        """
        self.nodenames = nodenames

        self.sparsifiedDensified = False
        self.Ehs = None
        self.Ees = None
        self.Ehd = None
        self.Eed = None

        self.jt = None

        if Eh is not None and Ee is not None:
            self.numV = Eh.shape[0]
            if Eh.shape[1] != self.numV:
                raise ValueError('Invalid hierarchical graph size')
            self.Eh = Eh  # adjacency matrix of the hierarchical graph
            if Ee.shape[0] != self.numV or Ee.shape[1] != self.numV:
                raise ValueError('Invalid exclusion graph size')
            self.Ee = Ee  # adjacency matrix of the exclusion graph
        elif T is not None:
            self.initFromTree(T)
        elif Ls is not None:
            self.initFromLattice(Ls, L, assume_exclusive)
        elif file_path is not None:
            self.loadFromFile(file_path)
        else:
            raise ValueError('No input given')

        if not self.sparsifiedDensified:
            self.buildSparseDense()

        # Column k marks class k together with every class that includes it.
        self.states = np.logical_or(self.Ehd, np.eye(self.numV))
        # True for classes that include no other class.
        self.leaves_mask = np.logical_not(np.any(self.Ehd, axis=1))

    def checkConsistency(self):
        """
        Check that the raw matrices ``Eh`` and ``Ee`` describe a valid HEX graph.

        Outputs:
            - True if Eh and Ee are boolean square matrices of the same size
              without self loops, Ee is symmetric, Eh has no directed loop, and
              no exclusion edge links two classes among a class and its
              ancestors; False otherwise.
        """
        # Check types of adjacency matrices
        if self.Eh.dtype != bool or self.Ee.dtype != bool:
            return False
        # Check sizes of adjacency matrices
        size = self.Eh.shape[0]
        if self.Eh.shape[1] != size or self.Ee.shape[0] != size or self.Ee.shape[1] != size:
            return False
        # Check for self loops
        if np.trace(self.Eh) > 0 or np.trace(self.Ee) > 0:
            return False
        # Check that Ee is symmetric
        if not np.array_equal(self.Ee, np.transpose(self.Ee)):
            return False

        # Build the asymptotic transition matrix of Eh
        aEh = computeAsymptoticTransitionMatrix(self.Eh, self.numV)
        # A non-zero diagonal is taken to mean that Eh has a directed loop
        if np.trace(aEh) > 0:
            return False
        # NOTE: connectivity of Eh is intentionally not checked.

        # ancestors[i] masks class i and its ancestors (assumes aEh[j, i] is
        # non-zero when i is reachable from j - TODO confirm in utils).
        ancestors = np.asarray(np.logical_or(np.transpose(aEh), np.eye(self.numV, dtype=bool)))
        for i in range(self.numV):
            mask = ancestors[i]
            # np.ix_ selects every (row, col) pair of the mask; indexing with two
            # boolean masks directly would pair them element-wise (diagonal only).
            if np.any(self.Ee[np.ix_(mask, mask)]):
                return False

        return True

    def initFromLattice(self, Ls, L=None, assume_exclusive=True):
        """
        Creates the HEXgraph (ie. Ehs, Ehd, Ees, Eed) from a lattice that represents hierarchical relations

        Inputs:
            - Ls : a lattice such that an edge (i, j) exists if class i subsumes class j.
            - L : a larger lattice (Ls must be included into L) to extract exclusion relations from.
            - assume_exclusive : a bool that determines if classes are supposed to be exclusive of one another by default (ie. they will be exclusive if they share no subclass).

        Outputs :
            - (Ehs, Ees, Ehd, Eed), where Ehs and Ehd are the matrices before
              conversion to dense numpy arrays (the attributes hold the numpy versions).
        """
        # Init number of nodes
        self.numV = len(Ls.nodes)
        nodelist = list(Ls.nodes)

        # We assume that the directed graph given is already sparse
        Ehs = nx.adjacency_matrix(Ls).astype(bool)
        sparse_eye = coo_array(np.eye(self.numV, dtype=bool))
        Ehd = computeSparseAsymptoticTransitionMatrix(Ehs.copy(), sparse_eye)

        self.Ehs = np.array(Ehs.todense())
        self.Ehd = np.array(Ehd.todense())

        if assume_exclusive:
            if L is None:
                Ees, Eed = compute_exclusive(self.Ehs, self.Ehd)
            else:
                # Nodes of Ls come first so the top-left block of L's matrices is Ls.
                for n in L.nodes:
                    if n not in nodelist:
                        nodelist.append(n)
                LEhs = nx.adjacency_matrix(L, nodelist).astype(bool)
                sparse_eye = coo_array(np.eye(len(nodelist), dtype=bool))
                LEhd = computeSparseAsymptoticTransitionMatrix(LEhs.copy(), sparse_eye)
                LEhs = np.array(LEhs.todense())
                LEhd = np.array(LEhd.todense())
                LEes, LEed = compute_exclusive(LEhs, LEhd)
                Ees = LEes[0:self.numV, 0:self.numV]
                Eed = LEed[0:self.numV, 0:self.numV]
        else:
            Ees = np.zeros((self.numV, self.numV), dtype=bool)
            Eed = np.zeros((self.numV, self.numV), dtype=bool)

        self.sparsifiedDensified = True
        self.Ees = Ees
        self.Eed = Eed

        return Ehs, Ees, Ehd, Eed

    def initFromTree(self, T):
        """
        Build the raw matrices Eh and Ee from a directed tree T (edge (i, j) if j
        is included in i); siblings (co-parents in the reversed tree) become
        mutually exclusive.

        Outputs:
            - (transposed Eh, Ee) as returned by the underlying sparse operations.
        """
        M = nx.moral_graph(T.reverse())

        # Get the adjacency matrices of both the hierarchical graph and its moral graph
        Eh = nx.adjacency_matrix(T).astype(bool)
        Em = nx.adjacency_matrix(M).astype(bool)

        # Compute the undirected version of the hierarchical graph
        Eht = np.transpose(Eh)
        Eu = (Eh + Eht).astype(bool)

        # Compute the exclusion graph from the moral graph minus the undirected graph
        Ee = (Em - Eu).astype(bool)

        self.numV = len(T.nodes)
        if self.nodenames is None:
            self.nodenames = list(T.nodes)

        self.Eh = Eh.todense()
        self.Ee = Ee.todense()
        return Eht, Ee

    def loadFromFile(self, file_path):
        """
        Load Ehs, Ees, Ehd and Eed from the text files written by ``save``.

        Inputs:
            - file_path : path prefix; the files file_path + ".Ehs", ".Ees",
              ".Ehd" and ".Eed" are read.
        """
        # savetxt writes the matrices as floats: cast back to bool so a loaded
        # graph has the same dtype as one built in memory. ndmin=2 keeps
        # single-node graphs two-dimensional.
        self.Ehs = np.loadtxt(file_path + ".Ehs", ndmin=2).astype(bool)
        self.Ees = np.loadtxt(file_path + ".Ees", ndmin=2).astype(bool)
        self.Ehd = np.loadtxt(file_path + ".Ehd", ndmin=2).astype(bool)
        self.Eed = np.loadtxt(file_path + ".Eed", ndmin=2).astype(bool)
        self.numV = self.Ehs.shape[0]
        self.sparsifiedDensified = True

    @staticmethod
    def _transitive_closure(adj):
        """Return reach[i, j] = True when j is reachable from i by a path of length >= 1."""
        reach = np.array(adj, dtype=bool)
        for k in range(reach.shape[0]):
            # Boolean Floyd-Warshall step: i -> k -> j gives i -> j.
            reach |= reach[:, k, None] & reach[None, k, :]
        return reach

    def buildSparseDense(self):
        """
        Compute the sparse (redundant edges removed) and dense (implied edges
        added) versions of Eh and Ee, and store them in Ehs, Ehd, Ees, Eed.

        Outputs:
            - (Ehs, Ees, Ehd, Eed) boolean numpy arrays.
        """
        n = self.numV
        Eh = np.array(self.Eh, dtype=bool)
        Ee = np.array(self.Ee, dtype=bool)

        print("hierchical edges...")
        # A direct edge i -> j is redundant (removed from the sparse graph) when
        # j is also reachable through a path of length >= 2; every such j is
        # added to the dense graph. Uses a closure instead of path enumeration,
        # which is exponential on lattices.
        reach = self._transitive_closure(Eh)
        indirect = (Eh.astype(np.intp) @ reach.astype(np.intp)) > 0
        Ehs = Eh & ~indirect
        Ehd = Eh | indirect

        print("... then exclusion edges !")
        Ees = Ee.copy()
        Eed = Ee.copy()
        # ancestors[v]: indices of the dense ancestors of class v
        ancestors = [np.flatnonzero(Ehd[:, v]) for v in range(n)]
        for i in range(n - 1):
            anc_i = ancestors[i]
            cl_i = np.append(anc_i, i)
            for j in range(i + 1, n):  # Ee is symmetric: only visit i < j
                # Hierarchically related classes are never exclusive candidates.
                if Ehd[j, i] or Ehd[i, j]:
                    continue
                anc_j = ancestors[j]
                cl_j = np.append(anc_j, j)
                # (i, j) is implied when one class (or an ancestor of it)
                # excludes a strict ancestor of the other.
                if Ee[np.ix_(cl_i, anc_j)].any() or Ee[np.ix_(anc_i, cl_j)].any():
                    if Ee[i, j]:
                        # Redundant edge: sparsify
                        Ees[i, j] = False
                        Ees[j, i] = False
                    else:
                        # Implied edge: densify
                        Eed[i, j] = True
                        Eed[j, i] = True

        self.sparsifiedDensified = True
        self.Ehs = Ehs
        self.Ees = Ees
        self.Ehd = Ehd
        self.Eed = Eed

        return Ehs, Ees, Ehd, Eed

    def buildJT(self):
        """Build the junction tree from the sparse matrices (computing them first if needed)."""
        if not self.sparsifiedDensified:
            self.buildSparseDense()
        self.jt = JunctionTree(self.Ehs, self.Ees, self.numV)

    def listStatesSpace(self, cliques=None):
        """
        List the admissible states of each clique.

        Inputs:
            - cliques : sequence of node-index collections; defaults to the
              cliques of the junction tree (built if needed).

        Outputs:
            - object array whose i-th entry is the output of listCliqueStates
              for the i-th clique.
        """
        if cliques is None:
            if self.jt is None:
                self.buildJT()
            cliques = self.jt.cliques

        # Sized from the cliques themselves so no junction tree is required here.
        statesSpace = np.empty(len(cliques), dtype=object)
        for i, clique in enumerate(cliques):
            size = len(clique)
            unfixed = np.ones(size, dtype=bool)
            init_states = np.zeros((1, size), dtype=bool)
            Chd = self.Ehd[np.ix_(clique, clique)]
            Ced = self.Eed[np.ix_(clique, clique)]
            statesSpace[i] = listCliqueStates(Chd, Ced, unfixed, init_states)

        return statesSpace

    def recordSumProduct(self, cliques=None):
        """
        Precompute, for every pair of neighbouring cliques, which states agree
        on their shared variables.

        sp[i][j][k] holds the indices of the states of clique i that match, on
        the variables shared with the j-th neighbour of clique i, the k-th state
        of that neighbour. When the two cliques share no variable, every state
        of clique i is listed.

        Inputs:
            - cliques : cliques aligned with the junction tree; defaults to its cliques.
        """
        # The neighbourhood structure always comes from the junction tree.
        if self.jt is None:
            self.buildJT()
        if cliques is None:
            cliques = self.jt.cliques

        statesSpace = self.listStatesSpace(cliques)
        sp = np.empty(self.jt.numC, dtype=object)
        for i, c in enumerate(cliques):
            nei = self.jt.cliqNei[i]
            c_states = statesSpace[i]
            statesNei = np.empty(len(nei), dtype=object)
            # For each clique in the neighborhood...
            for j, cn_index in enumerate(nei):
                cn = cliques[cn_index]
                # compute the intersection (or separator) of the cliques
                inter_var, inter_c_ind, inter_cn_ind = np.intersect1d(c, cn, return_indices=True)
                if len(inter_var) > 0:
                    c_inter_states = c_states[:, inter_c_ind]
                    cn_inter_states = statesSpace[cn_index][:, inter_cn_ind]
                    num_cn_states = cn_inter_states.shape[0]

                    statesNei[j] = np.empty(num_cn_states, dtype=object)
                    for s in range(num_cn_states):
                        # states of c equal to the s-th cn state on every shared variable
                        linkedStates = np.all(c_inter_states == cn_inter_states[s], axis=1)
                        statesNei[j][s] = np.nonzero(linkedStates)[0]
                else:
                    num_c_states = c_states.shape[0]
                    num_cn_states = statesSpace[cn_index].shape[0]

                    statesNei[j] = np.empty(num_cn_states, dtype=object)
                    for s in range(num_cn_states):
                        # all states are linked
                        statesNei[j][s] = np.arange(num_c_states, dtype=int)

            sp[i] = statesNei

        assert self.checkSP(sp)
        return sp

    def findDescents(self, i, min_dergee=1):
        """
        Return the descendants of i reachable in Eh by a path of at least
        ``min_dergee`` edges. The list may contain duplicates (one entry per path).
        """
        desc = []
        for j in range(self.numV):
            if self.Eh[i, j]:
                if min_dergee == 1:
                    desc.append(j)
                    desc += self.findDescents(j, 1)
                else:
                    desc += self.findDescents(j, min_dergee - 1)

        return desc

    def findAncestors(self, i, min_dergee=1):
        """
        Return the ancestors of i reachable in Eh by a path of at least
        ``min_dergee`` edges. The list may contain duplicates (one entry per path).
        """
        ancest = []
        for h in range(self.numV):
            if self.Eh[h, i]:
                if min_dergee == 1:
                    ancest.append(h)
                    ancest += self.findAncestors(h, 1)
                else:
                    ancest += self.findAncestors(h, min_dergee - 1)

        return ancest

    def getStates(self, idx_nodes):
        """Return the boolean state columns (class plus including classes) for idx_nodes."""
        return self.states[:, idx_nodes].astype('bool')

    def getTM(self):
        """Return the dense hierarchy matrix with the identity added."""
        id_mat = np.identity(self.numV)
        return np.logical_or(self.Ehd, id_mat)

    def checkValidity(self, states, reduce=True, max_batch=1024):
        """
        Check that states satisfy the HEX constraints: no two active classes
        are (densely) exclusive and every active class has all its dense
        ancestors active.

        Inputs:
            - states : torch tensor of shape (num_states, numV), one state per row.
            - reduce : if True return a single bool, otherwise one bool per state.
            - max_batch : maximum number of states processed at once (bounds the
              (batch, numV, numV) intermediate tensors).

        Outputs:
            - bool if reduce, else a boolean tensor of shape (num_states,).
        """
        import torch  # used only here; keeps the module importable without torch

        if max_batch < 1:
            raise ValueError('max_batch must be a positive integer')

        num_states = states.shape[0]
        if num_states > max_batch:
            # Ceiling division so every chunk has at most max_batch rows; a floor
            # division returned the full tensor for max_batch < n < 2*max_batch,
            # recursing forever.
            nb_split = -(-num_states // max_batch)
            chunks = torch.tensor_split(states, nb_split, dim=0)
            results = [self.checkValidity(chunk, reduce=reduce, max_batch=max_batch) for chunk in chunks]
            if reduce:
                return torch.all(torch.tensor(results)).item()
            return torch.cat(results)

        s = states.unsqueeze(2).to(torch.float)
        # overlap_mat[b, i, j] != 0 when classes i and j are both active in state b
        overlap_mat = torch.matmul(s, s.transpose(1, 2))
        exclusive = torch.from_numpy(self.Eed)
        checkE = torch.all(torch.flatten(torch.logical_not(torch.logical_and(exclusive, overlap_mat)), start_dim=1), dim=1)

        # tEhd[i, j] is True when j is a dense ancestor of i
        tEhd = torch.from_numpy(self.Ehd).transpose(0, 1)
        ancestors_active = torch.all(torch.eq(tEhd, torch.logical_and(tEhd, s.transpose(1, 2))), dim=2)
        checkH = torch.all(torch.logical_or(torch.logical_not(states), ancestors_active), dim=1)

        valid = torch.logical_and(checkE, checkH)
        if reduce:
            return torch.all(valid).item()
        return valid

    def checkSP(self, sp=None, verbose=True):
        """
        Checks that the sumproduct is non-empty for every pair of adjacent nodes and each state of that pair.

        Inputs:
            - sp : a pre-computed sum-product to be checked (computed if None)
            - verbose : print every empty entry found

        Outputs :
            - valid : True if no value is missing and False otherwise
        """
        valid = True
        if sp is None:
            sp = self.recordSumProduct()

        for i in range(len(sp)):
            for j in range(sp[i].shape[0]):
                num_states = len(sp[i][j])
                if num_states == 0:
                    valid = False
                    if verbose:
                        print("Empty state space clique n°{}, neighbor n°{}".format(i, j))
                for k in range(num_states):
                    if len(sp[i][j][k]) == 0:
                        valid = False
                        if verbose:
                            print("Empty sum-product clique n°{}, neighbor n°{}, state n°{}".format(i, j, k))

        return valid

    def save(self, file_path):
        """Write Ehs, Ees, Ehd and Eed as text files named file_path + ".Ehs", etc."""
        np.savetxt(file_path + ".Ehs", np.array(self.Ehs))
        np.savetxt(file_path + ".Ees", np.array(self.Ees))
        np.savetxt(file_path + ".Ehd", np.array(self.Ehd))
        np.savetxt(file_path + ".Eed", np.array(self.Eed))