import numpy as np

from typing import List, Tuple, Dict, Any


def get_inference_order(
    graph: Tuple[Tuple[Any, Tuple[Any, ...]], ...]
) -> List[Any]:
    """From a graph, derives an inference order so that
    at the time we calculate the value corresponding to an element all
    values for elements it depends upon have already been calculated

    The graph is scanned repeatedly in the order given; during each scan
    every element whose parents are all already scheduled is appended to
    the result.

    Parameters
    ----------
    graph : Tuple[Tuple[Any, Tuple[Any, ...]], ...]
        describes the graph as (element, parents_of_element)

    Returns
    -------
    List[Any]
        order in which to process the elements

    Raises
    ------
    KeyError
        if a parent is referenced that is not itself an element of the graph
    ValueError
        if the remaining elements can never be scheduled, i.e. the graph
        contains a cycle (including an element that is its own parent)
    """
    is_calculated = {element: False for element, _ in graph}
    inference_order = []
    # Compare against the number of *unique* elements so that an element
    # listed more than once cannot keep the loop from terminating.
    while len(inference_order) < len(is_calculated):
        made_progress = False
        for element, parents_of_element in graph:
            if is_calculated[element]:
                continue
            if all(is_calculated[parent] for parent in parents_of_element):
                inference_order.append(element)
                is_calculated[element] = True
                made_progress = True
        if not made_progress:
            # A full scan that schedules nothing will never schedule
            # anything: fail loudly instead of spinning forever.
            unresolved = [
                element
                for element, done in is_calculated.items()
                if not done
            ]
            raise ValueError(
                "graph contains a dependency cycle; "
                f"unresolved elements: {unresolved!r}"
            )

    return inference_order


def _ground_name(rv: str, index: Tuple[int, ...]) -> str:
    """Name of the ground instance of ``rv`` at ``index`` ("x" -> "x0_1")."""
    return rv + "_".join(str(i) for i in index)


def _ground_dependency_names(
    dependency_rv: str,
    dependency_plates: Tuple[str, ...],
    index_per_plate: Dict[str, int],
    plate_cardinalities: Dict[str, int]
) -> List[str]:
    """Ground names of ``dependency_rv`` seen from one ground dependent.

    Plates shared with the dependent are pinned to the dependent's index;
    plates the dependent does not live in are expanded over all values.
    """
    free_plates = [
        plate for plate in dependency_plates
        if plate not in index_per_plate
    ]
    free_shape = tuple(plate_cardinalities[plate] for plate in free_plates)
    names = []
    # np.ndindex() with an empty shape yields exactly one empty index, so a
    # dependency without free plates produces exactly one name.
    for free_index in np.ndindex(*free_shape):
        full_index = {**index_per_plate, **dict(zip(free_plates, free_index))}
        names.append(
            _ground_name(
                dependency_rv,
                # Index in the dependency's own plate order so the name
                # matches the one generated for the dependency itself.
                tuple(full_index[plate] for plate in dependency_plates)
            )
        )
    return names


def ground_template(
    template_graph: Tuple[Tuple[str, Tuple[str, ...]], ...],
    plates_per_rv: Dict[str, List[str]],
    plate_cardinalities: Dict[str, int]
) -> Tuple[
    Tuple[Tuple[str, Tuple[str, ...]], ...],
    Dict[str, Dict[str, Any]]
]:
    """Unrolls a template graph with plates into a ground graph

    Every random variable that lives in one or more plates is replaced by
    one ground variable per index combination, named by appending the
    indices joined with "_" to the template name (e.g. "x0_1"). Variables
    without plates keep their name.

    Dependencies are resolved per ground variable: for plates shared by a
    variable and its dependency the same index is used, and for plates
    that only the dependency lives in, all instances of the dependency
    along that plate become parents.

    Parameters
    ----------
    template_graph : Tuple[Tuple[str, Tuple[str, ...]], ...]
        describes the template as (rv, dependencies_of_rv)
    plates_per_rv : Dict[str, List[str]]
        names of the plates each random variable lives in; needs an entry
        for every rv of the template graph, a dependency without an entry
        is treated as having no plates
    plate_cardinalities : Dict[str, int]
        number of repetitions of each plate

    Returns
    -------
    Tuple[Tuple[str, Tuple[str, ...]], ...]
        the ground graph as (ground_rv, ground_dependencies)
    Dict[str, Dict[str, Any]]
        for every ground rv its "template" name and its "index" within
        the plates (list of ints, or None for a rv without plates)
    """
    ground_graph = []
    template_per_rv = {}
    for rv, dependencies in template_graph:
        rv_plates = plates_per_rv[rv]
        is_plated = len(rv_plates) > 0
        shape = tuple(plate_cardinalities[plate] for plate in rv_plates)
        # For a rv without plates the shape is empty and np.ndindex yields a
        # single empty index, so the rv is emitted once under its own name.
        for index in np.ndindex(*shape):
            index_per_plate = dict(zip(rv_plates, index))
            ground_dependencies = []
            for dependency_rv in dependencies:
                ground_dependencies.extend(
                    _ground_dependency_names(
                        dependency_rv,
                        tuple(plates_per_rv.get(dependency_rv, ())),
                        index_per_plate,
                        plate_cardinalities
                    )
                )
            ground_rv = _ground_name(rv, index)
            ground_graph.append(
                (ground_rv, tuple(ground_dependencies))
            )
            template_per_rv[ground_rv] = {
                "template": rv,
                "index": [*index] if is_plated else None
            }
    return tuple(ground_graph), template_per_rv
