import numpy as np
from scipy.sparse import lil_matrix
from scipy.sparse.csgraph import connected_components
from tqdm.auto import tqdm
from itertools import combinations


def retrieve_circuits(concepts, semantic_check, verbose=False):
    '''
    Given a set of units and their
    meanings it returns semantically
    consistent circuits.

    A graph is built whose nodes are
    (module, unit_id, meaning) triples.
    An undirected edge links a node of
    one module to a node of the next
    module (in the iteration order of
    ``concepts``) whenever
    ``semantic_check`` accepts their
    meanings. Every connected component
    with more than one node is a circuit.

    Parameters
    ----------
    concepts: dict of list
        Dictionary mapping the module
        name to a list, indexed by unit
        id, of that unit's meanings. The
        element at index 1 of each meaning
        is used as the node's meaning.
        NOTE(review): this was previously
        documented as (Concept, IOU) tuples,
        in which case index 1 would be the
        IOU rather than the concept —
        confirm against callers.
    semantic_check: Callable
        function to determine if
        two meanings are semantically
        related; called as
        ``semantic_check(a, b)`` with ``a``
        from a module and ``b`` from the
        module that follows it
    verbose: bool, optional
        Prints to stdout the
        computation steps (the tqdm
        progress bar is shown regardless)

    Returns
    -------
    circuits: list of circuits
        One entry per circuit, ordered by
        connected-component label. Each
        circuit is the list of its
        participants, expressed as a triple
        (module, unit_id, meaning), in node
        insertion order.
    '''

    # List of (module, unit, meaning) nodes. ``start_module[k]`` is the
    # index of the first node of the k-th module; the last entry is the
    # total number of nodes, so module k spans
    # start_module[k]:start_module[k + 1].
    pair_index = []
    start_module = []
    for module, units in concepts.items():
        start_module.append(len(pair_index))
        for unit, meanings in enumerate(units):
            for meaning in meanings:
                pair_index.append((module, unit, meaning[1]))
    start_module.append(len(pair_index))
    n_pairs = len(pair_index)

    if verbose:
        print(n_pairs, 'total nodes')

    if n_pairs == 0:
        # No nodes means no circuits; avoids building a 0x0 graph.
        return []

    # Initialize the edges of the graph
    arcs = lil_matrix((n_pairs, n_pairs), dtype=np.int8)

    # Progress total = number of semantic_check calls, i.e. the sum of
    # |module k| * |module k + 1| over adjacent module pairs.
    sizes = [end - start for start, end in zip(start_module, start_module[1:])]
    total = sum(a * b for a, b in zip(sizes, sizes[1:]))

    # Context manager guarantees the bar is closed even if
    # semantic_check raises.
    with tqdm(total=total) as pbar:
        for k in range(len(start_module) - 2):
            a_start, b_start, b_end = start_module[k:k + 3]
            b_nodes = pair_index[b_start:b_end]
            for a_idx in range(a_start, b_start):
                a_concept = pair_index[a_idx][2]
                for b_offset, b_node in enumerate(b_nodes):
                    if semantic_check(a_concept, b_node[2]):
                        # Use the node's own position instead of
                        # list.index(): that is O(n) per lookup and returns
                        # the first equal tuple, which drops edges for a
                        # unit listing the same meaning twice.
                        arcs[a_idx, b_start + b_offset] = 1
                pbar.update(len(b_nodes))

    if verbose:
        print('Edges added')
        print('Computing connected components')

    n_components, labels = connected_components(csgraph=arcs,
                                                directed=False,
                                                return_labels=True)

    # Group nodes by component label in one pass (ascending node index
    # within each group), then keep only components with >1 node.
    groups = [[] for _ in range(n_components)]
    for node_idx, label in enumerate(labels):
        groups[label].append(pair_index[node_idx])

    return [group for group in groups if len(group) > 1]


def circuit_unique_meanings(circuit):
    '''
    Collect the distinct meanings
    appearing in a circuit.

    Parameters
    ----------
    circuit: iterable of tuples
        Circuit expressed as tuples
        (module, unit_id, meaning)

    Returns
    -------
    meanings: set
        Set of the third element of
        every circuit member.
    '''
    return {member[2] for member in circuit}


def circuit_unique_units(circuit):
    '''
    Collect the distinct units
    participating in a circuit.

    Parameters
    ----------
    circuit: iterable of tuples
        Circuit expressed as tuples
        (module, unit_id, meaning)

    Returns
    -------
    units: set of tuples
        Set of (module, unit_id) pairs;
        a unit appearing with several
        meanings is counted once.
    '''
    return {(member[0], member[1]) for member in circuit}


def circuit_coherence(circuit, sim, unique=False):
    '''
    Return the semantic coherence
    of a circuit.

    Parameters
    ----------
    circuit: list of tuples
        Circuit expressed as a list
        of tuples (module, unit_id, meaning)
    sim: Callable
        Similarity function to determine
        the semantic similarity between
        two concepts; called as sim(i, j)
        with i preceding j in circuit order
    unique: bool, optional
        If True it computes the similarity
        only between unique meanings
        (first occurrence kept).
        Defaults to false.

    Returns
    -------
    coherence: float
        Average of the similarities
        between all unordered pairs of
        circuit members. NaN when fewer
        than two meanings are available
        (e.g. a single-member circuit, or
        unique=True with one distinct meaning).
    '''
    meanings = [node[2] for node in circuit]

    if unique:
        # dict.fromkeys de-duplicates while preserving first-occurrence
        # order, so the (i, j) argument order given to ``sim`` is
        # deterministic (a set's iteration order can vary between runs,
        # which matters for a non-symmetric ``sim``).
        meanings = list(dict.fromkeys(meanings))

    similarities = [sim(i, j) for i, j in combinations(meanings, r=2)]

    if not similarities:
        # np.average([]) yields NaN plus RuntimeWarnings; return the same
        # NaN value explicitly and quietly.
        return np.float64(np.nan)

    return np.average(similarities)
