import random
import networkx as nx


def generate_random_graph(n, p=0.3, seed=None):
    """
    Build a G(n, p) Erdős–Rényi random graph whose nodes are all colored grey.

    Every one of the n*(n-1)/2 possible undirected edges is included
    independently with probability p.

    Properties:
    - connected: maybe
    - acyclic: maybe
    - cyclic: maybe
    - bipartite: maybe
    - has_degree_{N}: maybe for various N

    Parameters:
    - n (int): Number of nodes in the graph.
    - p (float): Probability of edge creation (default: 0.3).
    - seed (int, optional): Random seed for reproducibility.

    Returns:
    - G (networkx.Graph): The generated graph; every node has color="grey".
    """
    graph = nx.erdos_renyi_graph(n, p, seed=seed)
    nx.set_node_attributes(graph, values="grey", name="color")
    return graph


def generate_random_connected_graph(n, p=0.3, seed=None):
    """
    Generates a random connected graph with n nodes.

    The graph is a connected Watts–Strogatz small-world graph with k=2: a ring
    in which every node is joined to its nearest neighbour on each side, after
    which each ring edge is rewired with probability p. Rewiring never changes
    the number of edges, so for n >= 3 the result has exactly n edges. Being
    connected, it therefore contains exactly one cycle.

    Properties:
    - connected: True
    - acyclic: False for n >= 3 (True only for n <= 2)
    - cyclic: True for n >= 3 (exactly one cycle)
    - bipartite: maybe (a single-cycle graph is bipartite iff the cycle is even)
    - has_degree_{N}: maybe for various N

    Parameters:
    - n (int): Number of nodes (must be at least 1).
    - p (float): Probability of rewiring each ring edge (default: 0.3).
    - seed (int, optional): Random seed for reproducibility.

    Returns:
    - G (networkx.Graph): Random connected graph with all nodes colored grey.

    Raises:
    - networkx.NetworkXError: propagated from networkx when n < 1, or when no
      connected graph is produced within the generator's retry limit.
    """
    if n == 1:
        # Watts–Strogatz needs k <= n, so k=2 cannot be used for one node.
        # A single isolated node is trivially connected.
        G = nx.empty_graph(1)
    else:
        # Generate a connected Watts-Strogatz small-world graph
        G = nx.connected_watts_strogatz_graph(n, 2, p, seed=seed)

    # Set all nodes to grey
    nx.set_node_attributes(G, "grey", "color")

    return G


def generate_random_tree(n, p=0.3, seed=None):
    """
    Build a random spanning tree of a connected Watts–Strogatz graph.

    A connected small-world graph with k=2 and rewiring probability p is
    generated first. A BFS tree is then grown from a randomly chosen root. Because
    the source graph is connected, the BFS reaches every node, so the tree
    spans all n nodes.

    Properties:
    - connected: True
    - acyclic: True
    - tree: True
    - cyclic: False
    - has_degree_1: True (trees always have leaves)

    Parameters:
    - n (int): Number of nodes (must be at least 2).
    - p (float): Rewiring probability (default: 0.3).
    - seed (int, optional): Random seed; it drives both graph generation and
      the choice of BFS root.

    Returns:
    - T (networkx.Graph): Undirected tree with all nodes colored grey.
    """
    small_world = nx.connected_watts_strogatz_graph(n, 2, p, seed=seed)

    root = random.Random(seed).choice(list(small_world.nodes))

    # bfs_tree returns a DiGraph oriented away from the root; drop the orientation.
    tree = nx.Graph(nx.bfs_tree(small_world, root))

    nx.set_node_attributes(tree, values="grey", name="color")
    return tree


def generate_random_2_component_graph(
    n, p=0.3, seed=None, color1="blue", color2="orange", *, verbose=True
):
    """
    Generates a random disconnected graph with n nodes, consisting of two connected components.
    One randomly chosen node in each component is colored.

    The n nodes are split at random into two groups. Each group becomes a
    connected Watts–Strogatz graph (ring lattice with one neighbour on each
    side, i.e. k=2, edges rewired with probability p). The two graphs are then
    combined with a disjoint union, so component 1 holds nodes 0..n1-1 and
    component 2 holds nodes n1..n-1.

    Each component has at least 2 nodes when n < 6, and at least 3 nodes
    otherwise.

    Properties:
    - connected: False
    - has_components: 2
    - has_colored_node: True

    Parameters:
    - n (int): Total number of nodes (must be at least 4).
    - p (float): Probability of rewiring edges (default: 0.3).
    - seed (int, optional): Random seed for reproducibility. It drives the split
      and node choice; the components are generated from seed+1 and seed+2.
    - color1 (str): Color for a node in the first component (default: "blue").
    - color2 (str): Color for a node in the second component (default: "orange").
    - verbose (bool, keyword-only): If True (default), print the component
      sizes and the colored nodes to stdout.

    Returns:
    - G (networkx.Graph): A random disconnected graph with two connected components.
    - blue_node (int): A node in the first component that is colored with color1.
    - orange_node (int): A node in the second component that is colored with color2.

    Raises:
    - ValueError: if n is too small to split into two components (n < 4).
    """
    # The generator below always uses k=2, and Watts–Strogatz requires n >= k,
    # so a component needs at least 2 nodes. Larger graphs ask for 3 so that
    # neither component is a bare single edge.
    min_nodes = 2 if n < 6 else 3
    if n < 2 * min_nodes:
        raise ValueError(
            f"n must be at least {2 * min_nodes} to generate two connected components"
        )

    # Use a local random generator for reproducibility.
    rnd = random.Random(seed)
    # Split n into two parts, each with at least min_nodes.
    n1 = rnd.randint(min_nodes, n - min_nodes)
    n2 = n - n1
    if verbose:
        print(f"Component 1 will have {n1} nodes; Component 2 will have {n2} nodes.")

    # Generate distinct seeds for each component if a seed is provided.
    seed1 = seed + 1 if seed is not None else None
    seed2 = seed + 2 if seed is not None else None

    # Generate two connected graphs using the Watts–Strogatz model.
    G1 = nx.connected_watts_strogatz_graph(n1, 2, p, seed=seed1)
    G2 = nx.connected_watts_strogatz_graph(n2, 2, p, seed=seed2)

    # disjoint_union relabels G1 as 0..n1-1 and G2 as n1..n1+n2-1.
    G = nx.disjoint_union(G1, G2)

    # Set all nodes to grey initially.
    nx.set_node_attributes(G, "grey", "color")

    # Choose one random node from each component.
    blue_node = rnd.choice(range(n1))
    orange_node = rnd.choice(range(n1, n1 + n2))

    G.nodes[blue_node]["color"] = color1
    G.nodes[orange_node]["color"] = color2
    if verbose:
        print(f"Node {blue_node} in component 1 is colored {color1}.")
        print(f"Node {orange_node} in component 2 is colored {color2}.")

    return G, blue_node, orange_node


def generate_star_graph(n, p=None, seed=None):  # pylint: disable=unused-argument
    """
    Build an n-node star graph whose labels are randomly permuted.

    networkx labels the hub 0 and the spokes 1..n-1. The labels 0..n-1 are
    shuffled with a seeded RNG so that the hub's label is not predictable. All
    nodes are colored grey.

    Properties:
    - connected: True
    - acyclic: True
    - tree: True
    - has_center: True
    - has_leaves: True
    - has_degree_1: True
    - has_degree_{n-1}: True (center node)

    Parameters:
    - n (int): Number of nodes (1 center + n-1 leaves)
    - p (ignored): Included for compatibility with the other generators
    - seed (int, optional): Random seed for reproducibility of the relabeling

    Returns:
    - G (networkx.Graph): Star graph with scrambled node labels
    """
    hub_and_spokes = nx.star_graph(n - 1)

    new_labels = list(hub_and_spokes.nodes())
    random.Random(seed).shuffle(new_labels)
    relabeled = nx.relabel_nodes(
        hub_and_spokes, dict(zip(hub_and_spokes.nodes(), new_labels))
    )

    nx.set_node_attributes(relabeled, values="grey", name="color")
    return relabeled


def generate_random_bipartite_graph(n, p=0.3, seed=None):
    """
    Build a connected random bipartite graph and scramble its node labels.

    The nodes are split into a "left" side of n // 2 nodes and a "right" side
    holding the rest; for odd n, the right side gets the extra node. Each
    left/right pair is joined with probability p. While the graph is
    disconnected, random left/right edges are added. Finally the labels
    0..n-1 are randomly permuted so a node's ID does not reveal its side.
    Each node keeps a `bipartite` attribute (0 or 1) and gets color="grey".

    Properties:
    - connected: True
    - bipartite: True

    Parameters:
    - n (int): Total number of nodes (networkx cannot test connectivity of an
      empty graph, so n should be at least 1).
    - p (float): Edge probability for bipartite connections.
    - seed (int, optional): Random seed.

    Returns:
    - G (networkx.Graph): A connected bipartite graph with scrambled node IDs.
    - part0 (list): Node labels of partition 0.
    - part1 (list): Node labels of partition 1.
    """
    rng = random.Random(seed)

    half = n // 2
    left = list(range(half))
    right = list(range(half, n))

    graph = nx.Graph()
    graph.add_nodes_from(left, bipartite=0)
    graph.add_nodes_from(right, bipartite=1)

    # One Bernoulli(p) draw per (left, right) pair, in row-major order.
    graph.add_edges_from(
        (a, b) for a in left for b in right if rng.random() < p
    )

    # Patch connectivity with extra random cross-partition edges.
    while not nx.is_connected(graph):
        graph.add_edge(rng.choice(left), rng.choice(right))

    # Random permutation of 0..n-1 so labels no longer encode the partition.
    permuted = list(range(n))
    rng.shuffle(permuted)
    relabel = dict(zip(range(n), permuted))
    graph = nx.relabel_nodes(graph, relabel)

    left_labels = [relabel[a] for a in left]
    right_labels = [relabel[b] for b in right]

    nx.set_node_attributes(graph, values="grey", name="color")

    return graph, left_labels, right_labels
