"""Utils for making synthetic graph+text dataset"""


import random

import networkx as nx

from itertools import combinations

from build_datasets.synthetic.synthetic_class_descs import get_class_description


###############################################################################
#####                     GRAPH GENERATORS AND HELPERS                    #####
###############################################################################


def clique_ring_graph(cq_sizes):
    """Build a ring of cliques in which the cliques may differ in size.

    Different to the networkx implementation:
    - Cliques may have different sizes
    - Bridges go from last node (by id) in clique to first node in next clique

    Args:
        cq_sizes: Sequence of clique sizes, in ring order. Sizes are expected
            to be positive integers (zero or negative sizes are not handled).

    Returns:
        A `nx.DiGraph` rebuilt from the adjacency matrix of the undirected
        ring, with integer node ids following the clique order.
    """
    # Loop invariant: the final node count, used to wrap the last bridge to 0
    total_nodes = sum(cq_sizes)

    G = nx.Graph()
    bridge_edges = []
    for clique_size in cq_sizes:
        clique_nodes = range(len(G), len(G) + clique_size)
        # Add the nodes explicitly: a clique of size 1 has no internal edges,
        # so relying on add_edges_from alone would never create its node and
        # every later node id / bridge would be shifted
        G.add_nodes_from(clique_nodes)
        G.add_edges_from(combinations(clique_nodes, 2))
        # Last node of this clique -> first node of the next (wrapping to 0)
        bridge_edges.append((len(G) - 1, len(G) % total_nodes))
    G.add_edges_from(bridge_edges)
    adj = nx.to_numpy_array(G)
    return nx.from_numpy_array(adj, create_using=nx.DiGraph)


def clique_size_noise(normal_size, num_cliques, noise_frequency=0.05, noise_amount=0):
    """Add noise to the ring of cliques sizes.

    Args:
        normal_size: The un-noised size of every clique
        num_cliques: How many cliques are in the ring
        noise_frequency: Probability that any single clique gets resized
        noise_amount: How many nodes to add or remove when a clique is
            resized. Defaults to 0, which keeps the noise switched off exactly
            as the previous hard-coded value did (the original code carried
            the note "Change back to 3"). Nothing stops a resized clique from
            reaching a size of zero or below if `noise_amount >= normal_size`.

    Returns:
        A tuple of `num_cliques` clique sizes.
    """
    sizes = (normal_size,) * num_cliques
    if random.random() < 0.2:  # Keep all the same size
        return sizes

    new_sizes = []
    for s in sizes:
        # Add/remove nodes to the clique some of the time
        if random.random() < noise_frequency:
            s += random.choice((-noise_amount, noise_amount))
        new_sizes.append(s)
    return tuple(new_sizes)


def starlike_tree(num_spokes, spoke_length):
    """Build a starlike tree: a tree in which exactly one node has degree
    greater than two. Put differently, it is a star whose arms are *path
    graphs* hanging off the central node instead of single leaves. Formally the
    arms of a starlike tree may differ in length; here every spoke is given
    the same length `spoke_length`. Node 0 is the centre.
    """
    tree = nx.Graph()
    tree.add_node(0)
    for _ in range(num_spokes):
        first_id = len(tree)
        spoke_ids = range(first_id, first_id + spoke_length)
        # Chain the spoke's nodes together (path graph) ...
        tree.add_edges_from(zip(spoke_ids, spoke_ids[1:]))
        # ... then hang the spoke off the centre by its first node
        tree.add_edge(0, spoke_ids[0])
    return tree



def _star_graph(num_leaves):
    """Star graph but the first node is a leaf and the second is the centre.

    The star is built with `nx.star_graph(num_leaves)` and then re-indexed in
    BFS order starting from node 1, by round-tripping through an adjacency
    matrix. Uses the `*_numpy_array` converters because the `*_numpy_matrix`
    ones were removed in NetworkX 3.0.
    """
    graph = nx.star_graph(num_leaves)
    node_list = list(nx.bfs_tree(graph, source=1).nodes())
    return nx.from_numpy_array(nx.to_numpy_array(graph, nodelist=node_list))


def banana_tree(num_stars, star_size):
    """Banana tree graph - `num_stars` copies of a star graph with `star_size`
    nodes each, where one leaf of every star is joined to an extra root node
    (node 0).
    https://mathworld.wolfram.com/BananaTree.html
    """
    assert num_stars > 1 and star_size > 3

    tree = nx.empty_graph(1)
    for _ in range(num_stars):
        # The star's first node is a leaf; after the union it takes this id
        attach_leaf = len(tree)
        tree = nx.disjoint_union(tree, _star_graph(star_size - 1))
        tree.add_edge(0, attach_leaf)
    return tree


def firecracker_graph(num_stars, star_size):
    """Firecracker graph - `num_stars` star graphs with `star_size` nodes each.
    A leaf of each star is joined to a leaf of the previous star.
    https://mathworld.wolfram.com/FirecrackerGraph.html
    """
    assert num_stars > 1 and star_size > 3

    chain = nx.Graph()
    for idx in range(num_stars):
        # First node of the star about to be appended (a leaf of that star)
        new_leaf = len(chain)
        chain = nx.disjoint_union(chain, _star_graph(star_size - 1))
        if idx == 0:
            continue  # Nothing to link the first star back to
        chain.add_edge((idx - 1) * star_size, new_leaf)
    return chain


def sunlet_graph(cycle_length):
    """Sunlet graph - a cycle graph on `cycle_length` nodes in which each
    cycle node also carries one pendant leaf.
    https://mathworld.wolfram.com/SunletGraph.html

    Even node ids form the cycle; each odd id is the leaf of the even id
    just before it.
    """
    total = cycle_length * 2
    g = nx.empty_graph(1)
    for hub in range(0, total, 2):
        # Pendant leaf first, so it receives the next free id
        g.add_edge(hub, hub + 1)
        # Then the next cycle node, wrapping round to node 0 at the end
        g.add_edge(hub, (hub + 2) % total)
    return g


def helm_graph(wheel_size):
    """Helm graph - a wheel graph on `wheel_size` nodes in which each
    rim (non-centre) node also carries one pendant leaf.
    https://mathworld.wolfram.com/HelmGraph.html
    """
    helm = nx.wheel_graph(wheel_size)
    rim_nodes = range(1, wheel_size)
    # Leaves take the next free ids, one per rim node, in rim order
    for leaf, rim in enumerate(rim_nodes, start=len(helm)):
        helm.add_edge(rim, leaf)
    return helm


def fan_graph(km, path_len):
    """Fan graph - the graph join of the empty graph on `km` nodes
    and the path graph on `path_len` nodes.
    In short, an (almost) complete bipartite graph, except that the second
    set of nodes is also connected in a path graph.
    https://mathworld.wolfram.com/FanGraph.html

    The joined graph has prefixed node labels ("g-", "h-"), so it is rebuilt
    from its adjacency matrix to get plain integer node ids. Uses the
    `*_numpy_array` converters because the `*_numpy_matrix` ones were removed
    in NetworkX 3.0.
    """
    x = nx.empty_graph(km)
    y = nx.path_graph(path_len)
    g = nx.full_join(x, y, rename=("g-", "h-"))

    node_list = list(g.nodes())
    return nx.from_numpy_array(nx.to_numpy_array(g, nodelist=node_list))


###############################################################################
#####                            NODE ORDERING                            #####
###############################################################################


def default_order(graph):
    """No-op ordering: hand back the very same graph object, leaving the
    default (networkx) node order untouched.
    """
    return graph


def bfs_node_order(graph, start_node=0, start_highest_deg=False):
    """Convert graph's node ordering to BFS.

    Args:
        graph: The graph to reorder
        start_node: Node the breadth-first traversal starts from
        start_highest_deg: If true, ignore `start_node` and start from the
            node returned by `highest_degree_node`

    Returns:
        A new graph rebuilt from the adjacency matrix laid out in BFS order.
        Uses the `*_numpy_array` converters because the `*_numpy_matrix` ones
        were removed in NetworkX 3.0.
    """
    if start_highest_deg:
        start_node = highest_degree_node(graph)
    node_list = list(nx.bfs_tree(graph, source=start_node).nodes())
    return nx.from_numpy_array(nx.to_numpy_array(graph, nodelist=node_list))


def dfs_node_order(graph, start_node=0, start_highest_deg=False):
    """Convert graph's node ordering to DFS.

    Args:
        graph: The graph to reorder
        start_node: Node the depth-first traversal starts from
        start_highest_deg: If true, ignore `start_node` and start from the
            node returned by `highest_degree_node`

    Returns:
        A new graph rebuilt from the adjacency matrix laid out in DFS order.
        Uses the `*_numpy_array` converters because the `*_numpy_matrix` ones
        were removed in NetworkX 3.0.
    """
    if start_highest_deg:
        start_node = highest_degree_node(graph)
    node_list = list(nx.dfs_tree(graph, source=start_node).nodes())
    return nx.from_numpy_array(nx.to_numpy_array(graph, nodelist=node_list))


def reverse_node_order(graph):
    """Reverse the default node order.

    The node list is built as `len(graph) - 1` down to 0, so the graph's
    nodes are taken to be the integers `0 .. len(graph) - 1`. Returns a new
    graph rebuilt from the adjacency matrix in that order. Uses the
    `*_numpy_array` converters because the `*_numpy_matrix` ones were removed
    in NetworkX 3.0.
    """
    node_list = list(range(len(graph) - 1, -1, -1))
    return nx.from_numpy_array(nx.to_numpy_array(graph, nodelist=node_list))


def random_node_order(graph):
    """Randomise the node order.

    Shuffles the integers `0 .. len(graph) - 1` with the module-level `random`
    generator and returns a new graph rebuilt from the adjacency matrix in
    that order. Uses the `*_numpy_array` converters because the
    `*_numpy_matrix` ones were removed in NetworkX 3.0.
    """
    node_list = list(range(len(graph)))
    random.shuffle(node_list)
    return nx.from_numpy_array(nx.to_numpy_array(graph, nodelist=node_list))


def highest_degree_node(graph):
    """Return the id of the node whose degree is largest (the first such node
    in `graph.degree()` order when several tie).
    """
    best_node, _ = max(graph.degree(), key=lambda pair: pair[1])
    return best_node


def highest_degree(graph):
    """Return the largest degree found among the graph's nodes"""
    return max(deg for _, deg in graph.degree())


###############################################################################
#####                      TEXT CAPTIONS FOR GRAPHS                       #####
###############################################################################


# Sentence openers; `make_prefix` picks one of these at random
_prefixes = (
    "this is ",
    "the graph is ",
    "this network is ",
    "the graph represents ",
    "it's a synthetic graph, ",
    "a network of ",
    "synthetic graph, ",
    "we see ",
    "here we have ",
    "in this network there's ",
    "we observe ",
    "the following graph is ",
)

# Templates for stating the exact node count; `desc_num_nodes` picks one at
# random and fills the `{}` placeholder with `len(graph)`
_descs_for_num_nodes = (
    "the number of nodes is {}",
    "it has {} nodes in total",
    "the graph contains {} nodes",
    "there are {} nodes in it",
    "it's got {} vertices in total",
    "there's {} vertices in the network",
    "this network's {} vertices in size",
    "in total, this graph has {} nodes",
    "in this network, there's {} nodes",
)

# Attach num nodes desc on the end half the time
# (compared against `random.random()` in the `desc_*` functions)
_num_nodes_freq = 0.5

# Occasionally don't set a prefix for more variety
# (compared against `random.random()` in `make_graph_caption`)
_prefix_freq = 0.97


def make_graph_caption(graph, graph_class, **kwargs):
    """Assemble one complete caption for a graph out of three pieces, in order:
    an optional prefix, the class description (with a size word filled in),
    and a sentence about the graph's features.

    Args:
        graph: The graph being captioned
        graph_class: The name of the graph type, e.g. 'barbell', 'path'
        kwargs: Graph class dependent args, passed through to `describe_graph`

    Returns:
        The caption string
    """
    pieces = []

    # 1) Prefix - left out a small fraction of the time
    if random.random() < _prefix_freq:
        pieces.append(make_prefix())

    # 2) Class description template, with the size wording slotted in
    template = get_class_description(graph_class)
    pieces.append(template.format(make_graph_size(graph)))
    pieces.append(". ")

    # 3) Class-specific sentence describing some graph features
    pieces.append(describe_graph(graph, graph_class, **kwargs))

    return "".join(pieces)


def make_captions(graph, graph_class, num, **kwargs):
    """Generate `num` captions for a graph, retrying to avoid repeats.

    Each caption gets up to 50 attempts to differ from the ones already
    collected; if all attempts collide, a message is printed and the duplicate
    is kept anyway. Returns the caption itself when `num == 1`, otherwise the
    list of captions.
    """
    max_tries = 50
    seen = set()
    captions = []
    for _ in range(num):
        attempts = 0
        while True:
            cap = make_graph_caption(graph, graph_class, **kwargs)
            attempts += 1
            if cap not in seen:
                break
            if attempts == max_tries:
                print(f"Duplicate caption for a {graph_class} graph")
                break
        seen.add(cap)
        captions.append(cap)

    if num == 1:
        return captions[0]
    return captions


def make_prefix():
    """Pick, at random, a few words that introduce the caption sentence"""
    opener = random.choice(_prefixes)
    return opener


def make_graph_size(graph):
    """Turn the node count of the graph into a size word, chosen at random
    from the options of the first tier whose upper bound exceeds the count.
    """
    num_nodes = len(graph)
    # (exclusive upper bound on node count, wording options), smallest first
    tiers = (
        (40, ["tiny", "very small", "one of the smallest", "miniature"]),
        (60, ["small", "below average sized", "fairly small"]),
        (80, ["medium sized", "middle sized", "an average size"]),
        (100, ["large", "above average sized"]),
        (120, ["fairly large", "big", "huge"]),
    )
    for upper_bound, wordings in tiers:
        if num_nodes < upper_bound:
            return random.choice(wordings)
    # Everything from 120 nodes upwards shares the top tier
    return random.choice(["gigantic", "giant sized"])


def desc_num_nodes(graph):
    """Produce a randomly phrased sentence stating the graph's exact node count"""
    template = random.choice(_descs_for_num_nodes)
    num_nodes = len(graph)
    return template.format(num_nodes)


def describe_graph(graph, graph_class, num_desc=2, **kwargs):
    """Describe some features of the graph based on its class.

    Args:
        graph: The graph being described
        graph_class: Name of the graph type; selects the `desc_*` function
        num_desc: Currently unused
        kwargs: Graph class dependent args, forwarded to the `desc_*` function

    Returns:
        The description text produced by the class-specific function

    Raises:
        ValueError: If `graph_class` has no description function registered
    """

    class_fn_map = {
        "barbell": desc_barbell,
        "lollipop": desc_lollipop,
        "clique_ring": desc_clique_ring,
        "2d_grid": desc_square_grid,
        "tri_grid": desc_tri_grid,
        "hex_grid": desc_hex_grid,
        "watts": desc_watts,
        "star": desc_star,
        "path": desc_path,
        "binary": desc_binary,
        "ternary": desc_ternary,
        "starlike": desc_starlike,
        "banana_tree": desc_banana,
        "firecracker": desc_firecracker,
        "sunlet": desc_sunlet,
        "helm": desc_helm,
        "fan": desc_fan,
    }
    fn = class_fn_map.get(graph_class)
    if fn is None:
        # Fail with a readable message instead of "'NoneType' is not callable"
        known = ", ".join(sorted(class_fn_map))
        raise ValueError(
            f"Unknown graph class {graph_class!r}; expected one of: {known}"
        )
    return fn(graph, **kwargs)


def desc_barbell(graph, bell_size, bar_size, **kwargs):
    """Describe a barbell graph by its bell size and bar length.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the bar length is {bar_size} and the bells both have {bell_size} nodes",
        f"both bells have {bell_size} nodes connected by a bar of {bar_size} nodes",
        f"the {bell_size}-node bells are connected through by a {bar_size}-node bar",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_lollipop(graph, head_size, stick_size, **kwargs):
    """Describe a lollipop graph by its head size and stick length.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the head has {head_size} nodes and the stick length is {stick_size}",
        f"it contains a stick of length {stick_size} connected with a head of size {head_size}",
        f"the lollipop head contains {head_size} nodes, while the long stick has {stick_size} nodes",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_clique_ring(graph, num_cliques, clique_size, **kwargs):
    """Describe a ring of cliques: how many cliques and roughly how big.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"there are {num_cliques} cliques in the ring, each with around {clique_size} nodes",
        f"the {num_cliques} cliques of approx {clique_size} nodes are connected by a bridge to form a ring",
        f"in this ring of {num_cliques} cliques, each has approximately {clique_size} nodes",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_square_grid(graph, x, y, **kwargs):
    """Describe a square grid by its width `x` and height `y`.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the grid dimensions are {x:2d} nodes wide and {y:2d} nodes high",
        f"this grid has {x:2d} nodes across and {y:2d} nodes up",
        f"it forms a {x}-by-{y} lattice",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_tri_grid(graph, x, y, **kwargs):
    """Describe a triangular grid by its width `x` and height `y`.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the triangular grid's dimensions are: {x:2d} wide, {y:2d} high",
        f"this triangular lattice is {x:2d} across and {y:2d} up",
        f"it forms a {x} by {y} grid of triangles",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_hex_grid(graph, x, y, **kwargs):
    """For hexagonal grids. How wide/long it is.

    Args:
        graph: The graph being described (used for the optional node count)
        x: Number of hexagons across
        y: Number of hexagons up
        kwargs: Ignored

    Returns:
        One randomly chosen phrasing; with probability `_num_nodes_freq` a
        sentence with the exact node count is appended.
    """
    hex_descs = (
        f"the hexagonal grid's dimensions are: {x:2d} wide, {y:2d} high",
        f"this lattice consists of {x:2d} hexagons across and {y:2d} up",
        # No trailing space here: it used to leak into the caption as
        # "... by  3 . <node count>" or as trailing whitespace
        f"it requires {x*y:2d} hexagons: {x:2d} hexagons by {y:2d}",
    )
    desc = random.choice(hex_descs)
    if random.random() < _num_nodes_freq:
        desc += ". " + desc_num_nodes(graph)
    return desc


def desc_watts(graph, degs, p, **kwargs):
    """Describe a Watts-Strogatz graph: each node's initial degree `degs` and
    the rewiring probability `p` (printed to three decimals).

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"each node is connected to {degs} neighbors and rewiring probability is {p:.03f}",
        f"all nodes begin with degree {degs}, then with probability {p:.03f}, each edge is rewired",
        f"the vertices are joined to their {degs} neighbors then randomly rewired with probability {p:.03f}",
        f"nodes initially connect to the nearest {degs} nodes, then edges are rewired with probability {p:.03f}",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_star(graph, **kwargs):
    """Describe a star graph: a randomly chosen structural sentence, always
    followed by a sentence with the exact node count. Extra kwargs are ignored.
    """
    phrasings = (
        "it's a tree where a central node shares an edge with all other nodes",
        "one node in the middle connects to all other nodes",
    )
    sentence = random.choice(phrasings)
    return sentence + ". " + desc_num_nodes(graph)


def desc_path(graph, **kwargs):
    """Describe a path graph: a randomly chosen structural sentence, always
    followed by a sentence with the exact node count. Extra kwargs are ignored.
    """
    phrasings = (
        "the two nodes on each end have degree one, and all others have degree two",
        "it can be drawn with all vertices and edges lying on a single straight line",
    )
    sentence = random.choice(phrasings)
    return sentence + ". " + desc_num_nodes(graph)


def desc_binary(graph, **kwargs):
    """Describe a binary tree by its depth, i.e. the largest shortest-path
    length from node 0 (taken as the root) to any node. The sentence with the
    exact node count is always appended. Extra kwargs are ignored.
    """
    distances_from_root = nx.shortest_path_length(graph, 0)
    depth = max(distances_from_root.values())
    phrasings = (
        f"the max depth from the tree root is {depth}",
        f"the maximum depth from the root is {depth}",
        f"the greatest depth node from the root is {depth}",
        f"this tree goes {depth} nodes deep",
    )
    sentence = random.choice(phrasings)
    return sentence + ". " + desc_num_nodes(graph)


def desc_ternary(graph, **kwargs):
    """Describe a ternary tree by its depth, i.e. the largest shortest-path
    length from node 0 (taken as the root) to any node. The sentence with the
    exact node count is always appended. Extra kwargs are ignored.
    """
    distances_from_root = nx.shortest_path_length(graph, 0)
    depth = max(distances_from_root.values())
    phrasings = (
        f"the max depth from the root is {depth}",
        f"the maximum depth from the tree root is {depth}",
        f"the furthest depth node from the root is {depth}",
        f"this tree goes {depth} nodes deep",
    )
    sentence = random.choice(phrasings)
    return sentence + ". " + desc_num_nodes(graph)


def desc_starlike(graph, num_spokes, spoke_len, **kwargs):
    """Describe a starlike tree: the number of spokes and the spoke length.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the graph has {num_spokes} spokes of length {spoke_len}",
        f"there are {num_spokes} spokes, each with {spoke_len} nodes",
        f"this star contains {num_spokes} spokes of {spoke_len} nodes",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_banana(graph, num_stars, star_size, **kwargs):
    """Describe a banana tree: the number of stars and the size of each.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the graph has {num_stars} stars with {star_size} nodes, connected to a root node",
        f"there are {num_stars} star graphs of {star_size} nodes each joined to a root node",
        f"there are {star_size}-node star graphs, {num_stars} in total, joined to the banana tree root",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_firecracker(graph, num_stars, star_size, **kwargs):
    """Describe a firecracker graph: the number of stars and the size of each.

    One phrasing is picked at random; with probability `_num_nodes_freq` a
    sentence with the exact node count is appended. Extra kwargs are ignored.
    """
    phrasings = (
        f"the firecracker has {num_stars} stars with {star_size} nodes, each connected to another through a leaf node",
        f"there are {num_stars} stars of {star_size} nodes in the firecracker graph",
    )
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_sunlet(graph, cycle_length, **kwargs):
    """Describe a sunlet graph by its cycle length.

    Only one phrasing exists, still drawn through `random.choice`; with
    probability `_num_nodes_freq` a sentence with the exact node count is
    appended. Extra kwargs are ignored.
    """
    phrasings = (f"the sunlet's cycle length is {cycle_length}",)
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_helm(graph, wheel_size, **kwargs):
    """Describe a helm graph by its wheel size.

    Only one phrasing exists, still drawn through `random.choice`; with
    probability `_num_nodes_freq` a sentence with the exact node count is
    appended. Extra kwargs are ignored.
    """
    phrasings = (f"the helm's wheel size is {wheel_size}",)
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)


def desc_fan(graph, k_size, path_len, **kwargs):
    """Describe a fan graph by its two parameters, `k_size` and `path_len`.

    Only one phrasing exists, still drawn through `random.choice`; with
    probability `_num_nodes_freq` a sentence with the exact node count is
    appended. Extra kwargs are ignored.
    """
    phrasings = (f"the fan graph has parameters k={k_size} and path length={path_len}",)
    sentence = random.choice(phrasings)
    if random.random() >= _num_nodes_freq:
        return sentence
    return sentence + ". " + desc_num_nodes(graph)
