"""Datasets generated by Bayesian networks."""
import dataclasses
import random
from typing import Sequence

import numpy as np
import tensorflow as tf


@dataclasses.dataclass
class Node:
    """Placeholder dataclass that currently declares no fields.

    NOTE(review): presumably meant to describe one variable of a Bayesian
    network -- confirm once fields are added.
    """


@dataclasses.dataclass
class BayesianNetwork:
    """Placeholder dataclass that currently declares no fields.

    NOTE(review): presumably meant to hold a whole network (nodes plus their
    connections) -- confirm once fields are added.
    """


def _connect_sources_to_sinks(
    n_sources: int,
    n_sinks: int,
    n_intermediate_nodes: int,
    p_connection: float,
) -> Sequence[Sequence[int]]:
    """Randomly wire up a directed acyclic graph as adjacency lists.

    Nodes are numbered with the sources first (``0 .. n_sources - 1``),
    then the intermediate nodes, then the sinks. Every candidate edge
    ``src -> dst`` is included independently with probability
    ``p_connection``, drawn from the global ``random`` module state.

    Candidate edges always satisfy ``dst > src`` and ``dst >= n_sources``, so
    the node numbering is a topological order: sources never receive edges,
    sinks never emit edges, and no node is connected to itself.

    Args:
        n_sources: Number of source nodes (no incoming edges).
        n_sinks: Number of sink nodes (no outgoing edges).
        n_intermediate_nodes: Number of nodes between sources and sinks.
        p_connection: Probability of keeping each candidate edge.

    Returns:
        A list with one entry per node; entry ``i`` lists, in increasing
        order, the destination indices of the edges leaving node ``i``.
    """
    total_nodes = n_sources + n_sinks + n_intermediate_nodes

    # Topologically sorted by source node.
    connections = [[] for _ in range(total_nodes)]

    for src in range(n_sources + n_intermediate_nodes):
        # Start strictly after `src` so an intermediate node can never get an
        # edge to itself (a self-loop would make the graph cyclic), and never
        # before `n_sources` so sources are not used as destinations.
        first_dst = max(src + 1, n_sources)
        for dst in range(first_dst, total_nodes):
            # random.random() lies in [0.0, 1.0); a strict comparison yields
            # probability exactly `p_connection` (never for 0, always for 1).
            if random.random() < p_connection:
                connections[src].append(dst)

    return connections


def _remove_extraneous_sources(
    connections: Sequence[Sequence[int]],
    n_sources: int,
) -> Sequence[Sequence[int]]:
    """Drop nodes that cannot be reached from the sources and renumber.

    Args:
        connections: Adjacency lists; ``connections[i]`` holds the indices of
            the nodes that node ``i`` points to.
        n_sources: Nodes ``0 .. n_sources - 1`` are the starting points of the
            reachability search.

    Returns:
        Adjacency lists for the reachable nodes only, kept in their original
        relative order, with every index rewritten to the compacted numbering.
    """
    # Depth-first walk from every source to collect the reachable nodes.
    reachable = set()
    pending = list(range(n_sources))
    while pending:
        node = pending.pop()
        if node not in reachable:
            reachable.add(node)
            pending.extend(connections[node])

    # new_index[old] is the number of reachable nodes that precede `old`,
    # which is the position `old` takes once unreachable nodes are removed.
    new_index = []
    kept = 0
    for old in range(len(connections)):
        new_index.append(kept)
        if old in reachable:
            kept += 1

    pruned = []
    for old, targets in enumerate(connections):
        if old not in reachable:
            continue
        pruned.append([new_index[target] for target in targets])
    return pruned
