import numpy as np
import torch
from itertools import compress
from scipy.sparse import coo_array

# SYMBOLIC UTILS

def computeSparseAsymptoticTransitionMatrix(m, id=None):

    """
    Compute the sparse transitive closure of a graph from its sparse adjacency
    matrix, ie. there is an edge from a to b in tm
    if there is a directed path (of length >= 1) from a to b in m.

    The closure is reached by repeating m <- (id + m) @ m = m + m @ m in
    boolean algebra until no entry changes. Each step doubles the longest path
    length covered, so only O(log n) sparse products are needed. A diagonal
    entry (a, a) is set only when a lies on a directed cycle.

    Inputs:
        - m : sparse adjacency matrix of the graph. Any non-zero entry is an
              edge; the matrix is cast to bool so that integer / weighted
              inputs also converge.
        - id : sparse identity matrix of the same size as m (built when None)

    Outputs :
        - tm : sparse boolean transitive closure of m
    """

    if id is None:
        s = m.shape[0]
        # Build the identity directly in COO form: no dense s x s intermediate.
        diag = np.arange(s)
        id = coo_array((np.ones(s, dtype=bool), (diag, diag)), shape=(s, s))

    # Work in boolean algebra: with integer/float entries the products count
    # paths, the values keep growing and the fixed point is never reached.
    id = id.astype(bool, copy=False)
    m = m.astype(bool, copy=False)

    while True:
        tm = (id + m) @ m  # == m + m @ m, a superset of m
        # Stop once the product did not add any new edge.
        if (tm != m).nnz == 0:
            return m
        m = tm

def computeAsymptoticTransitionMatrix(m, s=None):

    """
    Compute the transitive closure of a graph from its dense adjacency matrix,
    ie. there is an edge from a to b in tm
    if there is a directed path (of length >= 1) from a to b in m.

    The closure is reached by repeating m <- (I + m) @ m = m + m @ m in
    boolean algebra until the matrix stops changing; each step doubles the
    longest path length covered. A diagonal entry (a, a) is set only when a
    lies on a directed cycle.

    Inputs:
        - m : adjacency matrix of the graph (array-like). Any non-zero entry
              is an edge; it is cast to bool so integer / weighted inputs
              also converge.
        - s : size of the identity matrix used in the update; defaults to
              m.shape[0]

    Outputs :
        - tm : boolean transitive closure of m
    """

    # Boolean algebra keeps the iteration monotone and guarantees convergence;
    # with integer/float entries the products count paths and never stabilise.
    m = np.asarray(m, dtype=bool)
    if s is None:
        s = m.shape[0]
    eye = np.eye(s, dtype=bool)

    while True:
        tm = np.matmul(np.add(eye, m), m)  # == m | (m @ m)
        if np.array_equal(m, tm):
            return m
        m = tm

def findMaxCliques(cliques):

    """
    Keep only the cliques that are not contained in another clique of the list.

    Containment is tested node by node: every node of the shorter clique must
    appear in the longer one. When two cliques are equal, only the one that
    comes first is kept. Kept cliques stay in their input order.

    Inputs:
        - cliques : indexable sequence of cliques
                (ie. lists of ints corresponding to nodes of a graph)

    Outputs :
        - maxCliques : list of the maximal cliques of `cliques`
    """

    def is_subset(small, big):
        # True when every node of `small` occurs in `big`.
        return all(node in big for node in small)

    total = len(cliques)
    keep = [True] * total

    for a in range(total):
        if not keep[a]:
            continue
        first = cliques[a]
        first_size = len(first)
        for b in range(a + 1, total):
            if not keep[a]:
                break  # `first` is already dominated, stop comparing it
            if not keep[b]:
                continue  # already dominated by some other clique
            second = cliques[b]
            if first_size < len(second):
                # `first` is the shorter one: is it swallowed by `second`?
                keep[a] = not is_subset(first, second)
            else:
                # `second` is the shorter (or equal) one.
                keep[b] = not is_subset(second, first)

    return list(compress(cliques, keep))

def createsCycle(korif, akmi):

    """
    Test whether adding an edge would close a cycle, updating the vertex
    clusters (connected components) when it does not.

    Cluster id 0 means "vertex not attached to any edge yet". When both
    endpoints are free, they form a new cluster numbered max(korif) + 1. When
    one endpoint is free, it joins the other's cluster. Two distinct clusters
    are merged into the smaller id. Endpoints already in the same cluster
    mean the edge closes a cycle.

    Inputs:
        - korif : mutable sequence of length num_vertices : cluster id of each
                  vertex (modified in place)
        - akmi : 2-tuple : the edge we test for

    Outputs :
        - korif : the updated clusters (same object as the input)
        - True if the edge creates a cycle, False otherwise
    """

    fresh_cluster = max(korif) + 1
    u, v = akmi[0], akmi[1]
    cu, cv = korif[u], korif[v]

    # Both endpoints are unassigned: open a new cluster for the pair.
    if cu == 0 and cv == 0:
        korif[u] = fresh_cluster
        korif[v] = fresh_cluster
        return korif, False

    # Exactly one endpoint is unassigned: attach it to the other one.
    if cu == 0:
        korif[u] = cv
        return korif, False
    if cv == 0:
        korif[v] = cu
        return korif, False

    # Same cluster already: the edge would close a cycle.
    if cu == cv:
        return korif, True

    # Different clusters: relabel the larger id into the smaller one.
    target, source = min(cu, cv), max(cu, cv)
    for idx in range(len(korif)):
        if korif[idx] == source:
            korif[idx] = target
    return korif, False

def kruskal(PV, numV):

    """
    Kruskal algorithm for finding a maximum spanning tree (forest when the
    graph is disconnected).
    @params:
    PV is an n x 3 matrix. The 1st and 2nd entries of each row are the two
    vertices of an edge and the 3rd is the edge's weight. A numpy array has a
    single dtype, so vertex ids are stored as floats whenever the weights
    are floats; they are cast to int before being used as indices.
    numV is the number of vertices.
    @returns:
    Et is the symmetric boolean adjacency matrix of the maximum spanning tree
    W is the maximum spanning tree's weight (sum of the kept edges' weights)
    """

    Et = np.zeros((numV, numV), dtype=bool)
    num_edge = PV.shape[0]
    if num_edge == 0:
        W = 0
        return Et, W

    # Sort by ascending weight; the loop walks it backwards, i.e. heaviest first.
    PV = PV[np.argsort(PV[:, 2])]
    # korif[v] is the cluster id of vertex v (0 = not attached yet), see createsCycle.
    korif = np.zeros(numV, dtype=np.int64)
    insert_vec = np.ones(num_edge, dtype=bool)
    for i in reversed(range(num_edge)):
        # Vertex columns share PV's dtype (float when weights are float):
        # cast so they are valid array indices.
        u, v = int(PV[i, 0]), int(PV[i, 1])
        korif, c = createsCycle(korif, (u, v))

        if c:
            insert_vec[i] = False  # do not insert the edge if it closes a cycle
        else:
            # Record the edge in the (undirected) tree's adjacency matrix
            Et[u, v] = True
            Et[v, u] = True

    # Calculate maximum spanning tree's weight
    W = np.sum(PV[insert_vec, 2])

    return Et, W

def visitTree(E, node, father, cliqParents, cliqChildren, order):

    """
    Make a depth-first (pre-order) search of the tree to build the visit order,
    filling in cliqParents and cliqChildren along the way.

    An explicit stack is used instead of recursion, so trees deeper than
    Python's recursion limit are handled. Nodes are visited, and the
    containers are written, in the same order a recursive DFS would use.

    Inputs:
        - E : adjacency matrix of the junction tree; the non-zero entries of
              row E[i, :] are the neighbours of i
        - node : node where the search starts
        - father : neighbour of `node` that must not be treated as its child
                   (for the root, presumably a non-node marker chosen by the
                   caller — confirm)
        - cliqParents : container indexed by node; cliqParents[c] is set to
                        the parent of every visited node c except `node`
        - cliqChildren : container indexed by node; cliqChildren[n] is set to
                         the list of children of every visited node that has any
        - order : list to which visited nodes are appended in DFS pre-order

    Outputs :
        - cliqParents, cliqChildren, order : the same objects, updated

    Raises:
        - ValueError if a node is reached twice, i.e. the part of E reachable
          from `node` is not a tree
    """

    # Each entry: (node to visit, its parent, whether to record that parent).
    stack = [(node, father, False)]
    seen = set()
    while stack:
        current, parent, record_parent = stack.pop()
        if current in seen:
            raise ValueError("E is not a tree: node reached twice during the search")
        seen.add(current)

        if record_parent:
            cliqParents[current] = parent
        order.append(current)

        nei = np.nonzero(E[current, :])[0]
        children = nei[nei != parent]
        if children.size > 0:
            cliqChildren[current] = children.tolist()
            # Push in reverse so the first child is popped (visited) first.
            for c in children[::-1]:
                stack.append((c, current, True))

    return cliqParents, cliqChildren, order

def listCliqueStates(Chd, Ced, unfixed, states):

    """
    Recursively enumerate every assignment of the unfixed variables that is
    consistent with the constraints encoded in Chd and Ced.

    The first unfixed variable `node` is branched on:
      - node = 0 : node and every unfixed j with Chd[node, j] are set to False;
      - node = 1 : node and every unfixed i with Chd[i, node] are set to True,
                   and every unfixed j with Ced[node, j] is set to False.
    Variables assigned in a branch become fixed and the remaining unfixed ones
    are enumerated recursively. Only the direct row/column of `node` is
    propagated, so Chd is presumably expected to be transitively closed —
    confirm with callers.

    Inputs:
        - Chd : square boolean matrix; Chd[a, b] means "b = 1 forces a = 1" and
                "a = 0 forces b = 0" (presumably an ancestor/descendant relation)
        - Ced : square boolean matrix; Ced[a, b] means "a = 1 forces b = 0"
                (presumably mutual exclusion)
        - unfixed : 1-D boolean mask of the variables still to be assigned
        - states : 2-D array of partial assignments (one per row); the columns
                   of the variables assigned here are overwritten

    Outputs :
        - 2-D array stacking all completed assignments, rows of the node = 0
          branch first; `states` itself is returned when nothing is unfixed
    """

    if not np.any(unfixed):
        return states

    # Branch on the first variable that is still free.
    node = np.argmax(unfixed)

    # Branch node = 0: node and its unfixed Chd-successors become False.
    s0 = states.copy()
    setTo0 = np.logical_and(Chd[node, :], unfixed)
    setTo0[node] = True
    s0[:, setTo0] = False
    s0 = listCliqueStates(Chd, Ced, np.logical_and(unfixed, np.logical_not(setTo0)), s0)

    # Branch node = 1: its unfixed Chd-predecessors become True and its
    # unfixed Ced-partners become False.
    s1 = states.copy()
    setTo0 = np.logical_and(Ced[node], unfixed)
    setTo1 = np.logical_and(Chd[:, node], unfixed)
    setTo1[node] = True
    setToValue = np.logical_or(setTo0, setTo1)

    s1[:, setTo0] = False
    # Applied after the False assignment, so True wins on any overlap.
    s1[:, setTo1] = True
    s1 = listCliqueStates(Chd, Ced, np.logical_and(unfixed, np.logical_not(setToValue)), s1)

    return np.concatenate((s0, s1), axis=0)

def computeIndependantPartition(scores):
    """
    Compute prod_i (1 + scores_i) over every element of `scores`.

    Presumably this is the partition function of independent binary
    variables, where scores_i is the unnormalised weight of variable i being
    on (e.g. exp(logit_i)) — confirm with callers.

    Inputs:
        - scores : torch tensor or array-like of scores
    Outputs :
        - 0-d torch tensor holding the product (on the device of `scores`)
    """
    # Stay in torch end to end: np.prod on a Tensor forwards numpy-only
    # keyword arguments to Tensor.prod and cannot handle GPU tensors.
    scores = torch.as_tensor(scores)
    local_norm = torch.add(scores, 1)
    return torch.prod(local_norm)

def to_sympy(Eh, Ee, output_filename='ImgNet.sympy'):
    """
    Write the hierarchy / exclusion constraints as clauses in a text file.

    The first line is "shape [n]". Each unordered pair (a, b) with a < b is
    then visited once, and in this order:
      - Eh[a, b] emits "Xa | ~Xb"  (Xb implies Xa),
      - Eh[b, a] emits "Xb | ~Xa"  (Xa implies Xb),
      - Ee[a, b] emits "~Xa | ~Xb" (Xa and Xb not both true).
    Only the upper triangle of Ee is read, so Ee is presumably symmetric.

    Inputs:
        - Eh : square matrix of hierarchy edges
        - Ee : square matrix of exclusion edges
        - output_filename : path of the file to (over)write
    """
    size = Eh.shape[0]

    def clauses():
        # Yield one clause line per constraint, pair by pair.
        for a in range(size):
            for b in range(a + 1, size):
                if Eh[a, b]:
                    yield f'X{a} | ~X{b}\n'
                if Eh[b, a]:
                    yield f'X{b} | ~X{a}\n'
                if Ee[a, b]:
                    yield f'~X{a} | ~X{b}\n'

    with open(output_filename, 'w') as handle:
        handle.write(f'shape [{size}]\n')
        for line in clauses():
            handle.write(line)


# SPARSE UTILS

def tile_and_cat(t, bs):
    """
    Stack `bs` copies of `t` along dim 0 and append a column holding, for
    every row, the index of the copy it belongs to.

    Inputs:
        - t : 1-D tensor of shape (k,) or 2-D tensor of shape (r, k)
        - bs : number of copies
    Outputs :
        - tensor of shape (bs * r, k + 1) (r = 1 for a 1-D t); the rows of
          copy b are contiguous and their last column equals b
    """
    tiled = torch.tile(t, dims=(bs, 1))
    # torch.tile repeats t block-wise along dim 0, so row i comes from copy
    # i // rows_per_copy (rows_per_copy == 1 for 1-D or single-row t).
    rows_per_copy = tiled.shape[0] // bs if bs > 0 else 0
    # Create the index on t's device so the concatenation also works on GPU.
    batch_idx = torch.arange(start=0, end=bs, device=tiled.device)
    batch_idx = batch_idx.repeat_interleave(rows_per_copy).unsqueeze(dim=1)
    return torch.cat((tiled, batch_idx), dim=1)