import numpy as np
import pdb
import random
import copy
from scipy.special import binom

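'''
Miscellaneous utilities for random graphs: sampling random edge sets,
building and permuting adjacency matrices, computing vertex degrees,
and LiNGAM-style (causal) vertex orderings.
'''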

def random_edges(vertices, numedges = 20, mode = 'undirected', io_order=None):
    """Sample a random graph on ``vertices`` by adding edges one at a time.

    The vertices are used directly as row/column indices of the adjacency
    matrix, so they are expected to be the integers 0..n-1.

    Parameters
    ----------
    vertices : sequence of int
        Vertex labels (a list, tuple or 1-D array).
    numedges : int
        Requested number of edges. Capped at n choose 2 for 'undirected'
        and 'lingam', and at 2 * (n choose 2) for 'directed'.
    mode : {'undirected', 'directed', 'lingam'}
        'undirected': every sampled edge is stored in both directions.
        'directed':   a single directed edge v_out -> v_in is stored.
        'lingam':     only edges with v_in > v_out are kept, so the
                      adjacency matrix is strictly upper triangular.
    io_order : None or 'reverse'
        If 'reverse', the returned adjacency matrix is transposed.

    Returns
    -------
    edgelist : list of [v_out, v_in]
        Sampled edges; in 'undirected' mode each edge appears twice,
        once per direction.
    degree_list : np.ndarray
        Per-vertex edge counts: out-degree in 'directed'/'lingam' mode,
        degree (both endpoints counted) in 'undirected' mode.
    Adj_matrix : np.ndarray of shape (n, n)
        Adj_matrix[v_out, v_in] == 1 for every stored edge.

    Raises
    ------
    NotImplementedError
        If ``mode`` is not one of the values above.
    """
    # Accept any sequence (e.g. np.ndarray); list.remove is needed below.
    vertices = list(vertices)
    n = len(vertices)
    edgelist = []
    degree_list = np.zeros(n)
    Adj_matrix = np.zeros([n, n])
    # Vertices that may still be drawn as the source of a new edge.
    non_full_degree_list = copy.copy(vertices)
    isClique = False
    edgecount = 0
    # Compare mode with == (not `is`): string identity depends on interning.
    if mode == 'undirected' or mode == 'lingam':
        numedges = np.min([binom(n, 2), numedges])
    elif mode == 'directed':
        numedges = np.min([2 * binom(n, 2), numedges])
    else:
        raise NotImplementedError
    stop = False
    while edgecount < numedges and not isClique and not stop:
        v_in_candidates = copy.copy(vertices)
        v_out = np.random.choice(non_full_degree_list)
        vertices_filled = np.where(Adj_matrix[v_out] == 1)[0].tolist()
        v_in_candidates.remove(v_out)
        v_in_candidates = list(set(v_in_candidates) - set(vertices_filled))
        if len(v_in_candidates) > 0:
            v_in = np.random.choice(v_in_candidates)
            Adj_matrix[v_out, v_in] = 1
            edgelist.append([v_out, v_in])
            degree_list[v_out] += 1
            edgecount += 1
            if mode == 'undirected':
                Adj_matrix[v_in, v_out] = 1
                edgelist.append([v_in, v_out])
                degree_list[v_in] += 1
                # v_in's degree grows too; retire it once saturated, or it
                # could later be drawn as v_out with no candidates left and
                # end the sampling before numedges is reached.
                if degree_list[v_in] == n - 1 and v_in in non_full_degree_list:
                    non_full_degree_list.remove(v_in)
            elif mode == 'lingam':
                # Keep only "forward" edges (v_in > v_out); undo the others.
                if v_in <= v_out:
                    Adj_matrix[v_out, v_in] = 0
                    edgelist.pop(-1)
                    degree_list[v_out] -= 1
                    edgecount -= 1

            if degree_list[v_out] == n - 1:
                non_full_degree_list.remove(v_out)
            if len(non_full_degree_list) == 0:
                isClique = True
        else:
            # Safety net: saturated vertices are retired above, so this is
            # not normally reached, but it guarantees termination.
            stop = True
    if io_order == 'reverse':
        Adj_matrix = np.transpose(Adj_matrix)

    return edgelist, degree_list, Adj_matrix


def enter_edges(edgelist, adj_matrix, handshake=False):
    """Write a 1 into ``adj_matrix`` for every edge of ``edgelist``, in place.

    Parameters
    ----------
    edgelist : iterable of [u, v]
        Edges u -> v; u and v are used as row / column indices.
    adj_matrix : np.ndarray
        Matrix to fill; it is modified in place and also returned.
    handshake : bool
        If True, the reverse entry [v, u] is set as well (symmetric fill).
    """
    for edge in edgelist:
        src, dst = edge[0], edge[1]
        adj_matrix[src, dst] = 1
        if handshake:
            adj_matrix[dst, src] = 1
    return adj_matrix


def enter_coeffs(coeffs, adj_matrix):
    """Overwrite the non-zero entries of ``adj_matrix`` with ``coeffs``, in place.

    Non-zero entries are filled in row-major order (row 0 left to right,
    then row 1, ...), so ``coeffs[i]`` goes to the i-th non-zero entry.

    Parameters
    ----------
    coeffs : sequence of numbers
        One value per non-zero entry of ``adj_matrix``. Extra values are
        ignored (a warning is printed).
    adj_matrix : np.ndarray
        2-D matrix whose non-zero pattern marks the edges; modified in place.

    Returns
    -------
    np.ndarray
        ``adj_matrix`` itself.

    Raises
    ------
    IndexError
        If there are fewer coefficients than non-zero entries. The matrix
        is left untouched in that case.
    """
    # np.nonzero yields indices in row-major (C) order, i.e. the same order
    # as a row-by-row, column-by-column scan.
    rows, cols = np.nonzero(adj_matrix)
    numedges = len(rows)
    if len(coeffs) != numedges:
        print("number of coeffs must be same as the number of edges")
    if len(coeffs) < numedges:
        # Fail before writing anything instead of half-filling the matrix.
        raise IndexError(
            "got %d coeffs for %d edges" % (len(coeffs), numedges))
    adj_matrix[rows, cols] = np.asarray(coeffs)[:numedges]
    return adj_matrix


'''
Examples for compute_degree (``gu`` is the alias this module is imported
under). The relative order of vertices with equal degree follows
np.argsort and may vary.

edgeset = np.array([[1,0], [2,0], [1,3], [0,3]])
vertexset = np.array(range(5))
dvals, dkeys = gu.compute_degree(vertexset, edgeset)
print(dvals, dkeys)
[1 2 1 0 0] [0 1 2 3 4]

dvals, dkeys = gu.compute_degree(vertexset, edgeset, sort=True)
print(dvals, dkeys)

[0 0 1 1 2] [3 4 0 2 1]

# handshake=True counts every edge at both endpoints (total degree)
dvals, dkeys = gu.compute_degree(vertexset, edgeset, sort=True, handshake=True)
print(dvals, dkeys)
[0 1 2 2 3] [4 2 3 1 0]
'''

def compute_degree(vertexset, edgeset, sort=False,
                   mode='out_degree', handshake=False,
                   sortorder=None):
    """Compute per-vertex degrees of a graph given as an edge list.

    Parameters
    ----------
    vertexset : iterable
        All vertices; those incident to no counted edge get degree 0.
    edgeset : sequence of [u, v] edges (u -> v). May be empty.
    sort : bool
        If True, return the results sorted by degree (see ``sortorder``).
    mode : {'out_degree', 'in_degree'}
        Count each edge at its source (u) or at its target (v).
    handshake : bool
        If True, every edge also counts towards its other endpoint, so the
        result is the total (in + out) degree regardless of ``mode``.
    sortorder : None, 'reverse' or 'descend'
        None sorts ascending; 'reverse' / 'descend' sort descending.

    Returns
    -------
    (degree_values, vertex_keys) : tuple of np.ndarray
        Matching arrays of degrees and the vertices they belong to. When
        unsorted, vertices incident to a counted edge come first, followed
        by the remaining vertices of ``vertexset``.

    Raises
    ------
    NotImplementedError
        For an unknown ``mode`` or ``sortorder``.
    """
    # Validate arguments before doing any work.
    if mode == 'out_degree':
        parity = 0
    elif mode == 'in_degree':
        parity = 1
    else:
        raise NotImplementedError(mode)
    if sortorder is None:
        order = 1
    elif sortorder == 'reverse' or sortorder == 'descend':
        order = -1
    else:
        raise NotImplementedError(sortorder)
    other = 1 - parity  # index of the opposite endpoint of an edge

    # np.concatenate raises on an empty edge list, so guard it.
    adj_vertexset = set(np.concatenate(edgeset)) if len(edgeset) > 0 else set()
    degrees = {}
    # Looping vertex-by-vertex (rather than one pass over the edges) fixes
    # the key order of the result, which is visible when sort=False and
    # decides tie order when sort=True.
    for vertex in adj_vertexset:
        for edge in edgeset:
            if edge[parity] == vertex:
                degrees[vertex] = degrees.get(vertex, 0) + 1
                if handshake:
                    degrees[edge[other]] = degrees.get(edge[other], 0) + 1

    for vertex in vertexset:
        if vertex not in degrees:
            degrees[vertex] = 0

    values = np.array(list(degrees.values()))
    keys = np.array(list(degrees.keys()))
    if sort:
        sorting_order = np.argsort(order * values)
    else:
        sorting_order = np.arange(len(values))
    return values[sorting_order], keys[sorting_order]


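'''
in_vertices_analysis: map every vertex to the list of its in-neighbours
(the sources u of all edges [u, v] that point into v).
'''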
def in_vertices_analysis(vertexset, edgeset):
    """Return a dict mapping each vertex to the list of its in-neighbours.

    For every edge ``[u, v]`` in ``edgeset`` (in order), ``u`` is appended
    to the entry of ``v``. Vertices without incoming edges map to ``[]``.
    An edge target missing from ``vertexset`` raises KeyError.
    """
    parents_of = {vertex: [] for vertex in vertexset}
    for edge in edgeset:
        parents_of[edge[1]].append(edge[0])
    return parents_of


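'''
out_vertices_analysis: map every vertex to the list of its out-neighbours
(the targets v of all edges [u, v] that leave u).
'''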
def out_vertices_analysis(vertexset, edgeset):
    """Return a dict mapping each vertex to the list of its out-neighbours.

    For every edge ``[u, v]`` in ``edgeset`` (in order), ``v`` is appended
    to the entry of ``u``. Vertices without outgoing edges map to ``[]``.
    An edge source missing from ``vertexset`` raises KeyError.
    """
    children_of = {vertex: [] for vertex in vertexset}
    for edge in edgeset:
        children_of[edge[0]].append(edge[1])
    return children_of


'''
Helper function for lingamsort.
Insert a node into a partial ordering so that it comes after all of its
parents that are already placed (or at the front if none are placed yet).
'''
def lingaminsert(node, parents, nodearray):
    """Insert ``node`` into ``nodearray`` right after its last-placed parent.

    Parameters
    ----------
    node : int
        Vertex to insert.
    parents : iterable
        Parents of ``node``; those not yet present in ``nodearray`` are
        ignored.
    nodearray : array-like
        Current (partial) ordering; a list or a 1-D ndarray.

    Returns
    -------
    np.ndarray
        New ordering: ``node`` is placed immediately after the right-most
        parent already in ``nodearray``, or prepended if none is present.
    """
    # Lists are converted so that `nodearray == parent` compares elementwise
    # (comparing a whole list to a scalar is just False).
    nodearray = np.asarray(nodearray)
    if nodearray.size == 0:
        # Build from node alone so the dtype follows node; concatenating
        # with an empty (float) array would promote the result to float.
        return np.array([node])

    parent_positions = []
    for parent in parents:
        hits = np.flatnonzero(nodearray == parent)
        if hits.size > 0:
            parent_positions.append(hits[0])  # first occurrence of parent
    if not parent_positions:
        return np.concatenate([[node], nodearray])
    return np.insert(nodearray, max(parent_positions) + 1, node)


'''
Do the LingamSort:

Return an ordering (permutation) of the vertices 0..n-1 in which the parent
of every edge [parent, child] is meant to appear before the child, so that
permuting the adjacency matrix with it gives a triangular matrix.
'''
def lingamsort(vertexset, edgeset):
    """Return a causal (topological) ordering of the vertices 0..n-1.

    Each edge ``[parent, child]`` is a directed edge parent -> child. In the
    result every parent precedes all of its children, so permuting the
    adjacency matrix with it yields a triangular matrix.

    Uses Kahn's algorithm, always emitting the smallest ready vertex index
    next, so the ordering is deterministic.

    If the graph has a cycle, no such ordering exists. Vertices that lie on
    or downstream of a cycle are then appended in ascending index order, so
    a full permutation is still returned.

    Parameters
    ----------
    vertexset : sized collection
        Only its length n is used; the vertices are the indices 0..n-1.
    edgeset : iterable of [parent, child] index pairs.

    Returns
    -------
    np.ndarray of int
        A permutation of 0..n-1.
    """
    import heapq

    n = len(vertexset)
    children = [[] for _ in range(n)]
    n_parents = [0] * n  # number of not-yet-emitted parents per vertex
    for edge in edgeset:
        parent, child = int(edge[0]), int(edge[1])
        children[parent].append(child)
        n_parents[child] += 1

    # Min-heap of vertices whose parents have all been emitted.
    ready = [v for v in range(n) if n_parents[v] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        vertex = heapq.heappop(ready)
        order.append(vertex)
        for child in children[vertex]:
            n_parents[child] -= 1
            if n_parents[child] == 0:
                heapq.heappush(ready, child)

    if len(order) < n:
        # Cyclic input: keep the result a full permutation.
        placed = set(order)
        order.extend(v for v in range(n) if v not in placed)

    return np.array(order, dtype=int)


'''
Helper function: swap the entries at positions from_loc and to_loc of
targetarray in place and return it.
'''
def transposition(from_loc, to_loc, targetarray):
    """Swap ``targetarray[from_loc]`` and ``targetarray[to_loc]`` in place.

    Works for lists and for NumPy arrays of any rank (swapping whole rows
    of a 2-D array). Returns ``targetarray`` itself.
    """
    if isinstance(targetarray, np.ndarray):
        # Indexing a row of an ndarray returns a *view*; swapping through
        # views would copy one row over both. Fancy indexing on the right
        # makes a copy, so this swaps correctly.
        targetarray[[from_loc, to_loc]] = targetarray[[to_loc, from_loc]]
    else:
        targetarray[from_loc], targetarray[to_loc] = (
            targetarray[to_loc], targetarray[from_loc])
    return targetarray
'''
Create a permutation mapping keys to sorted keys: entry k of the result is
the position of keys[k] within sorted_keys.
>gu.make_permutation([1,3,5], [5,3,1])
>array([2, 1, 0])
'''
def make_permutation(keys, sorted_keys):
    """For every element of ``keys`` return its index in ``sorted_keys``.

    The first match is used when ``sorted_keys`` contains duplicates. An
    element of ``keys`` that is absent from ``sorted_keys`` raises IndexError.
    """
    lookup = np.array(sorted_keys)
    positions = [np.where(lookup == key)[0][0] for key in np.array(keys)]
    return np.array(positions)

'''
Apply the permutation. Example for mode='adj' (rows and columns are both
reordered):
mat = np.array(range(9)).reshape((3,3))
permute = np.array([0, 2, 1])
gu.apply_permutation(permute, mat)
array([[0, 2, 1],
       [6, 8, 7],
       [3, 5, 4]])
'''

def apply_permutation(permutation, target, mode='adj'):
    """Reorder ``target`` according to ``permutation`` (``p`` below).

    Modes
    -----
    'adj'        : return ``target[:, p][p, :]``, i.e.
                   ``result[i, j] == target[p[i], p[j]]``.
    'edgelist'   : relabel every edge ``[u, v]`` to ``[p[u], p[v]]`` in
                   place (edges must be mutable, e.g. lists).
    'vertexlist' : return ``target[p]``.

    NOTE: 'edgelist' maps old label u to p[u], whereas 'adj' and
    'vertexlist' move the old element p[i] to new position i. To relabel
    edges consistently with an 'adj' permutation p, pass its inverse
    (``np.argsort(p)``) here.

    Raises
    ------
    NotImplementedError
        For an unknown ``mode``. Previously an unknown mode silently
        returned ``target`` unchanged.
    """
    if mode == 'adj':
        return target[:, permutation][permutation, :]
    if mode == 'edgelist':
        for edge in target:
            edge[0] = permutation[edge[0]]
            edge[1] = permutation[edge[1]]
        return target
    if mode == 'vertexlist':
        return target[permutation]
    raise NotImplementedError(mode)


'''
Make adjacency matrix for user-specified sort-order.
Use Lingamsort for Lingam


edgeset = np.array([[1,0], [2,0], [1,3], [0,3]])
vertexset = np.array(range(5))
gu.make_adj(vertexset, edgeset, mode='degreesort')
(array([0, 0, 1, 1, 2]),
 array([3, 4, 0, 2, 1]),
 array([[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.],
        [ 1.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  0.,  0.],
        [ 1.,  0.,  1.,  0.,  0.]]))


edgelist = [[1,0], [3,0], [4,0], [2,1], [4,1], [3,4]]
vertices = list(range(5))
deg, idx, mat = gu.make_adj(vertices, edgelist, mode='lingamsort', io_order='reverse')
print(mat); print(deg); print(idx)
[[ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.]
 [ 1.  0.  1.  0.  0.]
 [ 0.  1.  1.  1.  0.]]
0
[2 3 4 1 0]
'''

def make_adj(vertexset, edgeset, handshake=False,
             mode='vanilla', sortorder=None, io_order=None,
             seed = None):
    """Build an adjacency matrix from an edge list and reorder its vertices.

    Parameters
    ----------
    vertexset : sized collection of the vertices 0..n-1.
    edgeset : iterable of [u, v] edges (u -> v).
    handshake : bool
        If True, every edge is also entered in the reverse direction, and
        in 'degreesort' mode degrees count both endpoints.
    mode : {'vanilla', 'degreesort', 'lingamsort'}
        'vanilla'   : keep the original vertex order.
        'degreesort': order vertices by degree (compute_degree, sorted).
        'lingamsort': order vertices with lingamsort.
    sortorder : None, 'reverse' or 'descend'
        Passed to compute_degree in 'degreesort' mode.
    io_order : None or 'reverse'
        If 'reverse', the permuted matrix is transposed.
    seed : int or None
        If given, np.random.seed(seed) is called (a global side effect kept
        for backward compatibility).

    Returns
    -------
    (val_array, key_array, adj_matrix)
        val_array is the sorted degree array in 'degreesort' mode and 0
        otherwise. key_array is the vertex permutation that was applied.
        adj_matrix is the permuted (and possibly transposed) matrix.

    Raises
    ------
    NotImplementedError
        For an unknown ``mode``.
    """
    adj_matrix = np.zeros(shape=(len(vertexset), len(vertexset)))
    adj_matrix = enter_edges(edgeset, adj_matrix, handshake=handshake)

    if seed is not None:
        np.random.seed(seed)

    # Compute the permutation. (The edges are already entered above, so
    # 'vanilla' needs no further work.)
    if mode == 'vanilla':
        val_array = 0
        key_array = list(range(len(vertexset)))
    elif mode == 'degreesort':
        val_array, key_array = compute_degree(
            vertexset, edgeset, sort=True,
            handshake=handshake, sortorder=sortorder)
    elif mode == 'lingamsort':
        val_array = 0
        key_array = lingamsort(vertexset, edgeset)
    else:
        raise NotImplementedError

    adj_matrix = apply_permutation(key_array, adj_matrix)

    if io_order == 'reverse':
        adj_matrix = np.transpose(adj_matrix)

    return val_array, key_array, adj_matrix


def random_partitions(vertices, numgroups):
    """Assign every vertex to one of ``numgroups`` groups uniformly at random.

    Group labels are drawn with ``np.random.choice``, so the result follows
    NumPy's global RNG state. The label vector is printed. Returns a one-hot
    float matrix of shape (len(vertices), numgroups) whose entry [i, g] is
    1.0 iff vertex i was assigned to group g.
    """
    labels = np.random.choice(range(numgroups), size=len(vertices))
    group_ids = np.arange(numgroups)
    one_hot = (labels[:, np.newaxis] == group_ids).astype(float)
    print(labels)
    return one_hot



def is_lower_triangular(adj_matrix, verbose=False):
    """Return True iff every entry on or above the main diagonal is zero.

    This is the *strict* lower-triangular test: a non-zero diagonal entry
    also makes the result False. Only the first ``len(adj_matrix)`` columns
    of each row are inspected. If ``verbose`` is True, a message is printed
    when the test fails.
    """
    size = len(adj_matrix)
    # Build the full list (no short-circuit) so every entry is visited.
    upper_nonzero = [adj_matrix[row][col] != 0
                     for row in range(size)
                     for col in range(row, size)]
    lower = not any(upper_nonzero)
    if verbose and not lower:
        print('NOT lower triangular!')
    return lower

