import numpy as np
import scipy.linalg as la
import numpy.linalg as linalg
import math

def edge_hashing(edge, n):
    """Encode a (possibly nested) edge as an order-independent integer key.

    A bare node id is returned unchanged (as a Python ``int``), and a one-element
    edge ``(x,)`` hashes to the same value as ``x``. An edge with two or more
    elements is encoded injectively:

    1. The child hashes are sorted in descending order.
    2. Each child hash is written in base ``n + 1``.
    3. The pieces are joined with the separator digit ``n + 1`` behind a leading
       marker digit 1, forming one base-``n + 2`` number.
    4. The result is offset by ``n``, so it can never equal a node id in
       ``range(n)``.

    Distinct multisets of children therefore always get distinct keys.

    Parameters
    ----------
    edge : int or iterable
        A node id, or an iterable whose items are node ids or nested edges.
    n : int
        Number of nodes; node ids are expected to lie in ``range(n)``.

    Returns
    -------
    int
        ``edge`` itself for a node id, the child's hash for a one-element edge,
        0 for an empty edge, otherwise an integer ``>= n``.

    Raises
    ------
    ValueError
        If an edge with two or more elements contains a negative child hash.
    """
    if isinstance(edge, (int, np.integer)):
        return int(edge)

    child_hashes = sorted((edge_hashing(e, n) for e in edge), reverse=True)
    if not child_hashes:
        return 0
    if len(child_hashes) == 1:
        # (x,) is identified with x so sub-edges given as bare ids match one-element edges.
        return child_hashes[0]

    width = max(int(n), 1)  # guard n <= 0 so the digit base is at least 2
    digit_base = width + 1  # digits of a child hash lie in [0, width]
    separator = width + 1   # never a child digit, so children split unambiguously
    base = width + 2
    code = 1  # leading marker digit: keeps leading zero digits significant
    for position, child in enumerate(child_hashes):
        if child < 0:
            raise ValueError(f"cannot encode negative child hash {child} in edge {edge!r}")
        if position:
            code = code * base + separator
        digits = []
        while True:
            child, digit = divmod(child, digit_base)
            digits.append(digit)
            if child == 0:
                break
        for digit in reversed(digits):
            code = code * base + digit
    return width + code


def create_edge_dict(connectionSet, n):
    """Index every edge both globally and within its size bucket.

    Parameters
    ----------
    connectionSet : iterable of edges
        Edges as (possibly nested) tuples of node ids. It is materialised once,
        so a one-shot iterator is accepted.
    n : int
        Number of size buckets (edge sizes 1..n); also passed to ``edge_hashing``.

    Returns
    -------
    dict
        ``{edge_hashing(edge, n): (global_index, bucket_index)}``.

        - ``global_index`` is the edge's position in ``connectionSet``.
        - ``bucket_index`` is its position when iterating the set of distinct
          edges that have the same number of elements.

        When several edges share a hash, the one seen last wins: input order for
        ``global_index``, bucket iteration order for ``bucket_index``.
    """
    edges = list(connectionSet)
    hashes = [edge_hashing(edge, n) for edge in edges]

    # A plain input-order pass gives a deterministic "last occurrence wins" for shared
    # hashes, instead of depending on the tie order of an unstable argsort.
    global_index = {h: i for i, h in enumerate(hashes)}

    # One bucket per edge size; a set keeps each distinct edge once.
    buckets = [set() for _ in range(n)]
    for edge in edges:
        # NOTE(review): an empty edge has len 0 and lands in the last bucket via index -1.
        buckets[len(edge) - 1].add(edge)

    bucket_index = {}
    for bucket in buckets:
        for j, edge in enumerate(bucket):
            bucket_index[edge_hashing(edge, n)] = j

    return {h: (global_index[h], bucket_index[h]) for h in hashes}
        
    
    
def array_of_edge_numbers(connectionSet, n):
    """Count edges by their number of elements.

    Returns
    -------
    tuple of numpy.ndarray
        ``(counts, cumulative)``, both float arrays of length ``n``.
        ``counts[k]`` is the number of edges with ``k + 1`` elements, and
        ``cumulative`` is its running sum.
        NOTE(review): an empty edge is counted in the last slot (index -1).
    """
    counts = np.zeros(n)
    for size in map(len, connectionSet):
        counts[size - 1] += 1
    return counts, counts.cumsum()
    
    

def powerset_size(i, n):
    """Return sum_{s=1}^{i} C(n, s).

    This is the number of non-empty subsets of an ``n``-element set that have at
    most ``i`` elements. It is 0 when ``i < 1``.
    """
    return sum(math.comb(n, size) for size in range(1, i + 1))

def indicens_matrix(connectionSet, rank = 1, n = 100, edge_dict = None, length_array = None, cum_array = None):
    """Build the 0/1 incidence matrix between sub-edges and the edges of one rank.

    Parameters
    ----------
    connectionSet : iterable of edges
        Edges as (possibly nested) tuples of node ids; an edge's rank is ``len(edge)``.
    rank : int
        Only edges with ``len(edge) == rank`` contribute columns.
    n : int
        Number of nodes; forwarded to ``edge_hashing``.
    edge_dict : dict, optional
        ``{hash: (global_index, bucket_index)}`` as returned by ``create_edge_dict``.
        Computed from ``connectionSet`` when omitted.
    length_array, cum_array : numpy.ndarray, optional
        Per-length edge counts and their cumulative sums as returned by
        ``array_of_edge_numbers``. Computed from ``connectionSet`` when omitted.

    Returns
    -------
    numpy.ndarray
        Shape ``(cum_array[rank - 1], length_array[rank - 1])``.
        Column ``b`` is the edge whose bucket index is ``b``. For rank > 1, a 1 at
        row ``r`` marks a sub-edge whose bucket index is ``r``. For rank 1, the
        diagonal entry indexed by the edge's own element is set instead.

    Raises
    ------
    KeyError
        If an edge or sub-edge hash is missing from ``edge_dict``.
    IndexError
        If a computed (row, column) position falls outside the matrix.
    """
    if edge_dict is None:
        edge_dict = create_edge_dict(connectionSet, n)
    if length_array is None or cum_array is None:
        default_length, default_cum = array_of_edge_numbers(connectionSet, n)
        if length_array is None:
            length_array = default_length
        if cum_array is None:
            cum_array = default_cum

    # NOTE(review): rows are sized by the count of all edges up to this rank, while row
    # positions are bucket indices of sub-edges (the original TODO flagged this size).
    incident_matrix = np.zeros((int(cum_array[rank-1]), int(length_array[rank-1])))
    for edge in connectionSet:
        if len(edge) != rank:
            continue
        _, b_ind = edge_dict[edge_hashing(edge, n)]
        if len(edge) == 1:
            # Rank-1 edges index the diagonal with their own element, not with b_ind.
            incident_matrix[edge, edge] = 1
            continue
        for sub_edge in edge:
            _, sub_edge_index = edge_dict[edge_hashing(sub_edge, n)]
            try:
                incident_matrix[sub_edge_index, b_ind] = 1
            except IndexError as err:
                raise IndexError(
                    f"incidence position ({sub_edge_index}, {b_ind}) for sub-edge {sub_edge!r} "
                    f"of edge {edge!r} is outside matrix of shape {incident_matrix.shape}"
                ) from err

    return incident_matrix

def node_to_edge_projection(connectionSet, n = 100, level = 2, edge_dict = None, length_array = None):
    """Build the node-by-edge projection matrix for edges with exactly ``level`` elements.

    Column ``c`` belongs to the edge whose bucket index (second entry of its
    ``edge_dict`` value) is ``c``. Every node reached inside that edge gets a weight:

    - The edge starts at ``1 / level``.
    - Bare node ids keep the weight of the edge that contains them.
    - Nested (non-int) sub-edges receive the parent's weight divided by the
      parent's length.

    Entries are assigned, not accumulated, so a node reached through several
    sub-edges keeps the last weight written.

    Parameters
    ----------
    connectionSet : iterable of edges
        Edges as (possibly nested) tuples of node ids.
    n : int
        Number of nodes (rows of the result); also forwarded to ``edge_hashing``.
    level : int
        Only edges with ``len(edge) == level`` become columns.
    edge_dict : dict, optional
        ``create_edge_dict(connectionSet, n)``; computed when omitted.
    length_array : numpy.ndarray, optional
        Per-length edge counts from ``array_of_edge_numbers``; computed when omitted.

    Returns
    -------
    numpy.ndarray
        Shape ``(n, length_array[level - 1])``.

    Raises
    ------
    KeyError
        If an edge hash is missing from ``edge_dict``.
    IndexError
        If a node id or column index falls outside the matrix.
    """
    if edge_dict is None:
        edge_dict = create_edge_dict(connectionSet, n)
    if length_array is None:
        length_array, _ = array_of_edge_numbers(connectionSet, n)

    # y is the number of edges with exactly `level` elements.
    y = length_array[level-1]
    projection_matrix = np.zeros((int(n), int(y)))

    def set_weight(node, edge_index, factor, context):
        # Write one entry; re-raise out-of-range positions with the offending edge attached.
        try:
            projection_matrix[node, edge_index] = factor
        except IndexError as err:
            raise IndexError(
                f"projection position ({node!r}, {edge_index}) for {context!r} is outside "
                f"matrix of shape {projection_matrix.shape}"
            ) from err

    def recursive_index(edge, edge_index, factor):
        if isinstance(edge, (int, np.integer)):
            set_weight(edge, edge_index, factor, edge)
            return
        if len(edge) == 1:
            # Validate that the single-element edge is known; its dict value is unused.
            key = edge_hashing(edge, n)
            if key not in edge_dict:
                raise KeyError(key)
            set_weight(edge[0], edge_index, factor, edge)
            return
        for subedge in edge:
            if isinstance(subedge, (int, np.integer)):
                temp = factor
            else:
                temp = factor * (1 / len(edge))
            recursive_index(subedge, edge_index, temp)

    for edge in connectionSet:
        if len(edge) == level:
            edge_index = edge_dict[edge_hashing(edge, n)][1]
            recursive_index(edge, edge_index, 1/level)

    return projection_matrix

def real_laplacian_matrix_zinc(connectionSet, n = 100):
    """Return the graph Laplacian ``D - A`` built from the 2-element edges only.

    ``A`` is a symmetric 0/1 adjacency matrix (repeated edges count once) and
    ``D`` is the diagonal matrix of its row sums. Edges of any other size are
    ignored.
    """
    adjacency = np.zeros((n, n))
    for edge in connectionSet:
        if len(edge) != 2:
            continue
        first, second = edge[0], edge[1]
        adjacency[first, second] = 1
        adjacency[second, first] = 1

    degrees = adjacency.sum(axis=1)
    return np.diag(degrees) - adjacency
    
def real_laplacian_matrix(connectionSet, n = 100):
    """Average the rank-wise operators ``P (I^T I) P^+`` over all edge sizes.

    For each rank r from 1 to the largest edge size, build:

    - ``I``, the incidence matrix from ``indicens_matrix``;
    - ``P``, the node-to-edge projection from ``node_to_edge_projection``.

    Then add ``(1 / max_depth) * (P @ (I.T @ I) @ pinv(P))`` to an ``n x n``
    accumulator. An empty ``connectionSet`` yields the zero matrix.
    """
    max_depth = max((len(edge) for edge in connectionSet), default=0)
    edge_dict = create_edge_dict(connectionSet, n)
    length_array, cum_array = array_of_edge_numbers(connectionSet, n)

    laplacian = np.zeros((n, n))
    for rank in range(1, max_depth + 1):
        incidence = indicens_matrix(connectionSet, rank, n, edge_dict, length_array, cum_array)
        projection = node_to_edge_projection(connectionSet, n, rank, edge_dict, length_array)
        rank_term = projection @ (incidence.T @ incidence) @ linalg.pinv(projection)
        laplacian += (1/max_depth) * rank_term

    return laplacian

def Laplacian_matrix(connectionSet, n = 100):
    """Build an ``n x n`` Laplacian-like matrix from pairwise co-membership in edges.

    For every edge ``e``:

    - Each unordered pair of positions ``(p, q)`` subtracts ``1 / len(e)`` from
      both ``[e[p], e[q]]`` and ``[e[q], e[p]]``.
    - Every member of ``e`` then adds 1 to its own diagonal entry.

    NOTE(review): for an edge of k distinct nodes each member's row gains
    ``1 - (k - 1) / k = 1 / k``, so rows do not sum to zero as in the standard
    unnormalised Laplacian. Confirm this normalisation is intended.
    """
    matrix = np.zeros((n, n))
    for edge in connectionSet:
        size = len(edge)
        for pos, u in enumerate(edge):
            for v in edge[pos + 1:]:
                matrix[u, v] -= 1 / size
                matrix[v, u] -= 1 / size
        for member in edge:
            matrix[member, member] += 1
    return matrix

def heat_kernel(L, t):
    """Return the heat kernel ``exp(-t * L)`` via SciPy's matrix exponential.

    Parameters
    ----------
    L : array_like
        Square matrix (``scipy.linalg.expm`` requires a square input).
    t : float
        Diffusion time.
    """
    exponent = (-t) * L
    return la.expm(exponent)

def eval_kernels_ZINC(Lls):
    """Compute per-node, per-kernel self-retention ratios, L2-normalised per node.

    For every kernel ``K`` in ``Lls`` and node ``i`` the raw score is
    ``K[i, i] / sum_j K[j, i]``. This is the value an impulse on ``i`` keeps on
    ``i``, relative to the total of ``K @ e_i``. Columns whose sum is exactly zero
    score 0.

    Each node's row of scores is then scaled to unit L2 norm. Any constant impulse
    magnitude cancels under this normalisation. A row that is entirely zero stays
    zero instead of becoming NaN.

    Parameters
    ----------
    Lls : sequence of array_like
        Square kernels, all of size ``N x N`` where ``N = Lls[0].shape[0]``.

    Returns
    -------
    numpy.ndarray
        Shape ``(N, len(Lls))``.
    """
    number_of_elements = Lls[0].shape[0]
    signatures = np.zeros((number_of_elements, len(Lls)))
    for ind, k in enumerate(Lls):
        k = np.asarray(k)
        retained = np.diagonal(k)   # K[i, i]: mass that stays on the impulse node
        total = k.sum(axis=0)       # sum of K @ e_i, i.e. the column sums
        has_mass = total != 0
        signatures[has_mass, ind] = retained[has_mass] / total[has_mass]

    norms = np.linalg.norm(signatures, axis=1, keepdims=True)
    # Leave all-zero rows at zero rather than dividing 0 by 0.
    return np.divide(signatures, norms, out=np.zeros_like(signatures), where=norms != 0)

def eval_kernels(Lls):
    """Compute per-node signatures from kernel diagonals, L2-normalised per node.

    The raw score of node ``i`` under kernel ``K`` is the quadratic form
    ``x.T @ K @ x`` for an impulse ``x`` on ``i``, which is proportional to
    ``K[i, i]``. The constant impulse magnitude cancels once each node's row is
    scaled to unit L2 norm. A row that is entirely zero stays zero instead of
    becoming NaN.

    Parameters
    ----------
    Lls : sequence of array_like
        Square kernels, all of size ``N x N`` where ``N = Lls[0].shape[0]``.

    Returns
    -------
    numpy.ndarray
        Shape ``(N, len(Lls))``.
    """
    number_of_elements = Lls[0].shape[0]
    signatures = np.zeros((number_of_elements, len(Lls)))
    for ind, k in enumerate(Lls):
        signatures[:, ind] = np.diagonal(np.asarray(k))

    norms = np.linalg.norm(signatures, axis=1, keepdims=True)
    # Leave all-zero rows at zero rather than dividing 0 by 0.
    return np.divide(signatures, norms, out=np.zeros_like(signatures), where=norms != 0)


def calc_diffusuion_pattern(Lls, source=0, magnitude=100000):
    """Apply every kernel to a single impulse and collect the responses.

    Parameters
    ----------
    Lls : sequence of array_like
        Kernels with ``N`` columns, where ``N = Lls[0].shape[0]``.
    source : int, optional
        Node that receives the impulse (default 0).
    magnitude : float, optional
        Impulse size (default 100000).

    Returns
    -------
    numpy.ndarray
        Shape ``(N, len(Lls))``; column ``j`` is ``Lls[j] @ impulse``.
    """
    number_of_elements = Lls[0].shape[0]
    impulse = np.zeros(number_of_elements)
    impulse[source] = magnitude
    responses = [k @ impulse for k in Lls]
    return np.stack(responses, axis=0).T