"""Generate a dataset of various types of synthetic graphs + captions"""

import logging
import os
import pickle as pkl
import random

import networkx as nx

from collections import Counter
from functools import partial
from itertools import product
from operator import itemgetter


from build_datasets.synthetic.synthetic_utils import (
    default_order,
    bfs_node_order,
    dfs_node_order,
    make_captions,
    clique_ring_graph,
    clique_size_noise,
    starlike_tree,
    banana_tree,
    firecracker_graph,
    sunlet_graph,
    helm_graph,
    fan_graph,
)


# Module-wide logger. Importing this module also calls `basicConfig`, which
# configures the root logger at INFO level (a no-op if the root logger already
# has handlers attached).
_LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


###############################################################################
#####                               UTILS                                 #####
###############################################################################


def reordered_and_directed(graph, order_fn):
    """Apply a node ordering and rebuild the result as a directed graph.

    Args:
      graph: Input networkx graph.
      order_fn: Callable taking a graph and returning it with nodes re-ordered.

    Returns:
      Tuple `(digraph, adjacency)`: the `nx.DiGraph` built from the re-ordered
      graph's dense adjacency matrix, and that matrix itself.
    """
    adjacency = nx.to_numpy_array(order_fn(graph))
    directed = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
    return directed, adjacency


def is_graph_ok(graph, node_min, node_max):
    """Return True when `graph` is acceptable for the dataset.

    A graph is accepted if all of the following hold:
      * its density is at most 0.5,
      * it has between `node_min` and `node_max` nodes (inclusive),
      * its undirected version has exactly one connected component.
    """
    density_cap = 0.5
    checks = (
        nx.density(graph) <= density_cap,
        node_min <= len(graph) <= node_max,
        nx.number_connected_components(nx.to_undirected(graph)) == 1,
    )
    return all(checks)


def limit_num_items(max_items, graphs, texts):
    """Randomly down-sample paired graph/text lists to at most `max_items`.

    Args:
      max_items: Maximum number of items to keep; 0 means keep everything.
      graphs: Sequence of graphs.
      texts: Sequence of captions, where `texts[i]` describes `graphs[i]`.

    Returns:
      `(graphs, texts)`. If no sampling is needed, the inputs are returned
      unchanged. Otherwise two new lists of length `max_items` are returned,
      aligned index-for-index and kept in their original relative order.
    """
    assert len(graphs) == len(texts)
    num_items = len(graphs)

    if num_items <= max_items or max_items == 0:
        return graphs, texts

    # Sample indices (not items) so both lists are filtered identically, and
    # sort them to preserve the original relative order. A comprehension is
    # used instead of `itemgetter(*indices)` because the latter returns a bare
    # element rather than a 1-tuple when `max_items == 1`.
    indices = sorted(random.sample(range(num_items), max_items))
    graphs_sampled = [graphs[i] for i in indices]
    texts_sampled = [texts[i] for i in indices]

    assert len(graphs_sampled) == len(texts_sampled) == max_items

    return graphs_sampled, texts_sampled


###############################################################################
#####                              DATASET                                #####
###############################################################################


def _to_digraph(graph):
    """Rebuild `graph` as an `nx.DiGraph` via its dense adjacency matrix.

    Nodes are relabelled 0..n-1 following `graph`'s node iteration order.
    """
    return nx.from_numpy_array(nx.to_numpy_array(graph), create_using=nx.DiGraph)


def _reorder_then_check(graph, order_fn, node_min, node_max):
    """Re-order and convert to DiGraph, then validate. Returns None if rejected."""
    graph, _ = reordered_and_directed(graph, order_fn)
    return graph if is_graph_ok(graph, node_min, node_max) else None


def _direct_then_reorder_then_check(graph, order_fn, node_min, node_max):
    """Convert to DiGraph, then re-order, then validate. Returns None if rejected.

    Used for Watts-Strogatz graphs, where `reordered_and_directed` was
    occasionally causing issues.
    """
    graph = order_fn(_to_digraph(graph))
    return graph if is_graph_ok(graph, node_min, node_max) else None


def _check_then_reorder(graph, order_fn, node_min, node_max):
    """Validate the raw graph first, then re-order and convert to DiGraph.

    Returns None if the raw graph is rejected.
    """
    if not is_graph_ok(graph, node_min, node_max):
        return None
    return _to_digraph(order_fn(graph))


def _size_range(graphs):
    """Format the min-max node count of `graphs`, or "n/a" when there are none."""
    if not graphs:
        return "n/a"
    sizes = [len(g) for g in graphs]
    return f"{min(sizes)}-{max(sizes)}"


def _collect_class(
    graph_class, label, candidates, prepare, *, node_min, node_max, num_texts, max_each_class
):
    """Filter and caption one class of candidate graphs, then down-sample it.

    Args:
      graph_class: Class name passed to `make_captions`.
      label: Human-readable name used in the per-class log line.
      candidates: Iterable of `(raw_graph, order_fn, caption_kwargs)`. It is
          consumed lazily, so any random draws made while building a candidate
          happen in the same order as the checks and captions.
      prepare: One of the `_*check*` pipelines. It returns the final graph, or
          None to reject the candidate.
      node_min, node_max, num_texts, max_each_class: See `make_synthetic_dataset`.

    Returns:
      `(class_graphs, class_texts, num_valid)`, where `num_valid` counts the
      accepted graphs before down-sampling.
    """
    class_graphs, class_texts = [], []
    for raw_graph, order_fn, caption_kwargs in candidates:
        graph = prepare(raw_graph, order_fn, node_min, node_max)
        if graph is None:
            continue
        text = make_captions(graph, graph_class, num=num_texts, **caption_kwargs)
        class_graphs.append(graph)
        class_texts.append(text)

    num_valid = len(class_graphs)
    class_graphs, class_texts = limit_num_items(max_each_class, class_graphs, class_texts)

    # Pad "Num <label>:" to 18 chars so the per-class log lines stay aligned
    prefix = f"Num {label}:"
    _LOG.info(f"{prefix:<18}{len(class_graphs):4d},  sizes {_size_range(class_graphs)}")
    return list(class_graphs), list(class_texts), num_valid


def make_synthetic_dataset(
    node_min: int=20,
    node_max: int=160,
    max_each_class: int=0,
    save_folder: str="raw_datasets",
    file_name: str="synthetic.pkl",
    overwrite_if_exists: bool=False,
    seed: int=0,
):
    """Create synthetic dataset and save to file

    Builds graphs from 16 families: barbell, lollipop, clique ring, 2D grid,
    triangular grid, Watts-Strogatz, star, path, binary tree, ternary tree,
    starlike tree, banana tree, firecracker, sunlet, helm and fan. It keeps the
    graphs that pass `is_graph_ok` and captions each with `make_captions`.
    Each family can optionally be down-sampled. Finally, a list of
    `(graph, text)` pairs is pickled to `save_folder/file_name`.

    Args:
      node_min: Minimum nodes that graphs can have (Default value = 20)
      node_max: Maximum nodes that graphs can have (Default value = 160)
      max_each_class: Max number of graphs kept in each class; 0 keeps all of
          them (Default value = 0)
      save_folder: Folder to save dataset, created if missing (Default value = "raw_datasets")
      file_name: Dataset file name (Default value = "synthetic.pkl")
      overwrite_if_exists: Whether to overwrite if a dataset with the given name already exists
          in the folder (Default value = False)
      seed: Seed for Python's global `random` module (Default value = 0)

    Raises:
      FileExistsError: If the dataset file exists and `overwrite_if_exists` is False.
    """
    random.seed(seed)

    num_texts = 1  # Hardcoding `num_texts` to 1, this idea wasn't pursued

    # If dataset exists and overwrite is False then stop
    dataset_path = os.path.join(save_folder, file_name)
    if save_folder:  # "" means the current directory, which already exists
        os.makedirs(save_folder, exist_ok=True)
    if os.path.exists(dataset_path) and not overwrite_if_exists:
        raise FileExistsError(
            f"File at {dataset_path} already exists. Set `--overwrite` to allow overwriting"
        )

    _LOG.info(f"Making synthetic dataset, graph size range: {node_min} - {node_max} nodes")

    # Node orderings per class. Only one ordering per class is used. Other
    # variants were tried but not kept, e.g. partial(bfs_node_order, start_node=0),
    # partial(dfs_node_order, start_node=0) and
    # partial(bfs_node_order / dfs_node_order, start_highest_deg=True).
    default_orders = (default_order,)

    # All candidate collections below are generators, so graphs are built only
    # when consumed. This keeps random draws in the same order as checking and
    # captioning.

    ##### BARBELL #####
    bb_bell_min, bb_bell_max = 8, 40
    bb_bar_min, bb_bar_max = 6, 36
    barbells = (
        (nx.barbell_graph(bell, bar), order_fn, dict(bell_size=bell, bar_size=bar))
        for bell, bar, order_fn in product(
            range(bb_bell_min, bb_bell_max + 1),
            range(bb_bar_min, bb_bar_max + 1),
            default_orders,
        )
    )

    ##### LOLLIPOP #####
    head_min, head_max = 20, 65
    stick_min, stick_max = 10, 45
    lollipops = (
        (nx.lollipop_graph(head, stick), order_fn, dict(head_size=head, stick_size=stick))
        for head, stick, order_fn in product(
            range(head_min, head_max + 1),
            range(stick_min, stick_max + 1),
            default_orders,
        )
    )

    ##### CLIQUE RING #####
    num_each = 3
    num_cq_min, num_cq_max = 4, 10
    cq_size_min, cq_size_max = 9, 16
    add_noise = True
    noise_frequency = 0.2

    def clique_rings():
        for ne, num_cliques, clique_size, order_fn in product(
            range(num_each),
            range(num_cq_min, num_cq_max + 1),
            range(cq_size_min, cq_size_max + 1),
            default_orders,
        ):
            if clique_size * num_cliques < node_min:
                continue
            if add_noise and ne != 0:  # Some cliques will have different sizes
                clique_sizes = clique_size_noise(clique_size, num_cliques, noise_frequency)
            else:  # Otherwise all cliques are the same size
                clique_sizes = [clique_size] * num_cliques
            yield (
                clique_ring_graph(clique_sizes),
                order_fn,
                dict(num_cliques=num_cliques, clique_size=clique_size),
            )

    ##### 2D GRID #####
    square_min, square_max = 3, 32
    grid_orders = (partial(dfs_node_order, start_node=(0, 0)),)
    grids_2d = (
        (nx.grid_2d_graph(x, y), order_fn, dict(x=x, y=y))
        for x, order_fn in product(range(square_min, square_max + 1), grid_orders)
        for y in range(square_min, x + 1)
    )

    ##### TRIANGULAR GRID #####
    tri_min, tri_max = 4, 24
    tri_orders = (partial(bfs_node_order, start_node=(0, 0)),)
    tri_grids = (
        (nx.triangular_lattice_graph(x, y), order_fn, dict(x=x, y=y))
        for x, order_fn in product(range(tri_min, tri_max + 1), tri_orders)
        for y in range(tri_min, x + 1)
    )

    ##### WATTS STROGATZ #####
    watts_size_min, watts_size_max = max(node_min, 30), node_max
    watts_deg_min, watts_deg_max = 6, 12
    rewiring_probs = [0.002, 0.01]
    watts = (
        (
            nx.connected_watts_strogatz_graph(num_nodes, degs, p=p, tries=500),
            order_fn,
            dict(degs=degs, p=p),
        )
        for num_nodes, degs, p, order_fn in product(
            range(watts_size_min, watts_size_max + 1),
            range(watts_deg_min, watts_deg_max + 1, 2),
            rewiring_probs,
            default_orders,
        )
    )

    ##### STAR #####
    # Exclusive upper bound: star_graph(n) has n + 1 nodes, so the largest
    # star generated here has `spokes_max` nodes.
    spokes_min, spokes_max = node_min - 1, 100
    stars = (
        (nx.star_graph(num_spokes), order_fn, {})
        for num_spokes, order_fn in product(range(spokes_min, spokes_max), default_orders)
    )

    ##### PATH / BINARY TREE / TERNARY TREE #####
    size_min, size_max = node_min, node_max
    paths = (
        (nx.path_graph(num_nodes), order_fn, {})
        for num_nodes, order_fn in product(range(size_min, size_max + 1), default_orders)
    )
    binary_trees = (
        (nx.full_rary_tree(2, num_nodes), order_fn, {})
        for num_nodes, order_fn in product(range(size_min, size_max + 1), default_orders)
    )
    ternary_trees = (
        (nx.full_rary_tree(3, num_nodes), order_fn, {})
        for num_nodes, order_fn in product(range(size_min, size_max + 1), default_orders)
    )

    ##### STARLIKE TREE #####
    sl_spokes_min, sl_spokes_max = 7, 30
    spoke_len_min, spoke_len_max = 2, 6
    starlikes = (
        (
            starlike_tree(num_spokes, spoke_len),
            order_fn,
            dict(num_spokes=num_spokes, spoke_len=spoke_len),
        )
        for num_spokes, spoke_len, order_fn in product(
            range(sl_spokes_min, sl_spokes_max + 1),
            range(spoke_len_min, spoke_len_max + 1),
            default_orders,
        )
    )

    ##### BANANA TREE / FIRECRACKER GRAPH #####
    stars_min, stars_max = 3, 15
    star_size_min, star_size_max = 5, 15
    bananas = (
        (banana_tree(n_stars, s_size), order_fn, dict(num_stars=n_stars, star_size=s_size))
        for n_stars, s_size, order_fn in product(
            range(stars_min, stars_max + 1),
            range(star_size_min, star_size_max + 1),
            default_orders,
        )
    )
    firecrackers = (
        (firecracker_graph(n_stars, s_size), order_fn, dict(num_stars=n_stars, star_size=s_size))
        for n_stars, s_size, order_fn in product(
            range(stars_min, stars_max + 1),
            range(star_size_min, star_size_max + 1),
            default_orders,
        )
    )

    ##### SUNLET GRAPH #####
    cycle_min, cycle_max = 4, node_max // 2
    sunlets = (
        (sunlet_graph(cycle_length), order_fn, dict(cycle_length=cycle_length))
        for cycle_length, order_fn in product(range(cycle_min, cycle_max + 1), default_orders)
    )

    ##### HELM GRAPH #####
    wheel_min, wheel_max = 5, node_max // 2
    helms = (
        (helm_graph(wheel_size), order_fn, dict(wheel_size=wheel_size))
        for wheel_size, order_fn in product(range(wheel_min, wheel_max + 1), default_orders)
    )

    ##### FAN GRAPH #####
    k_min, k_max = 1, 3
    path_min, path_max = 15, 80
    fans = (
        (fan_graph(km=k, path_len=path_len), order_fn, dict(k_size=k, path_len=path_len))
        for k, path_len, order_fn in product(
            range(k_min, k_max + 1),
            range(path_min, path_max + 1),
            default_orders,
        )
    )

    # (class name for captions, log label, candidates, prepare pipeline).
    # The pipeline per family matches the original per-family ordering of
    # re-ordering / DiGraph conversion / validation.
    classes = (
        ("barbell", "barbells", barbells, _reorder_then_check),
        ("lollipop", "lollipops", lollipops, _reorder_then_check),
        ("clique_ring", "clique rings", clique_rings(), _reorder_then_check),
        ("2d_grid", "2d grids", grids_2d, _reorder_then_check),
        ("tri_grid", "triangles", tri_grids, _reorder_then_check),
        ("watts", "watts", watts, _direct_then_reorder_then_check),
        ("star", "stars", stars, _reorder_then_check),
        ("path", "path", paths, _reorder_then_check),
        ("binary", "binary tree", binary_trees, _reorder_then_check),
        ("ternary", "ternary tree", ternary_trees, _reorder_then_check),
        ("starlike", "starlikes", starlikes, _check_then_reorder),
        ("banana_tree", "banana trees", bananas, _check_then_reorder),
        ("firecracker", "firecrackers", firecrackers, _check_then_reorder),
        ("sunlet", "sunlets", sunlets, _check_then_reorder),
        ("helm", "helms", helms, _check_then_reorder),
        ("fan", "fans", fans, _check_then_reorder),
    )

    graphs, texts = [], []
    total = 0  # Number of valid graphs before per-class down-sampling
    for graph_class, label, candidates, prepare in classes:
        class_graphs, class_texts, num_valid = _collect_class(
            graph_class,
            label,
            candidates,
            prepare,
            node_min=node_min,
            node_max=node_max,
            num_texts=num_texts,
            max_each_class=max_each_class,
        )
        graphs += class_graphs
        texts += class_texts
        total += num_valid

    ### Checks and save to file ###

    assert all(node_min <= len(g) <= node_max for g in graphs)
    assert len(graphs) == len(texts)

    # Save dataset
    with open(dataset_path, "wb") as f:
        pkl.dump(list(zip(graphs, texts)), f)

    _LOG.info(
        f"\nMade synthetic graphs, kept {len(graphs)} / {total}. Sizes range: {_size_range(graphs)}"
        f". Unique texts: {len(set(texts))} / {len(texts)}\n"
    )

    # Check for identical texts on different graphs
    all_texts = [j for i in texts for j in i] if num_texts > 1 else texts
    duplicate_texts = [i for i, j in Counter(all_texts).items() if j > 1]
    if duplicate_texts:
        print("\n\n***** Texts that appear more than once *****\n")
        print("\n\n".join(duplicate_texts))
