import os

import matplotlib.pyplot as plt
import networkx as nx


def visualize_graph(
    G,
    layout="spring",
    directed=False,
    seed=None,
    save_path="datasets/temp/visual/spring.png",
):
    """
    Draw ``G`` with the chosen layout and optionally save the image as a PNG.

    Parameters:
    - G: networkx.Graph (or DiGraph). Each node's ``"color"`` attribute is
      used as its fill colour; nodes without one are drawn ``"grey"``.
    - layout: str, one of ["spring", "spectral", "circular", "kamada_kawai", "random"]
    - directed: bool, if True ``arrows=True`` is passed to ``nx.draw``.
      NOTE(review): whether arrowheads actually appear for an undirected
      ``G`` depends on networkx's drawing defaults; pass a DiGraph to be sure.
    - seed: int or None, forwarded only to the "spring" and "random" layouts
      (the other layouts ignore it).
    - save_path: str or None, file path for the PNG
      (e.g., save_path="datasets/temp/visual/spring.png"). Missing parent
      directories are created. A bare filename saves to the current
      working directory. Pass a falsy value to skip saving.

    Returns:
    - None.
    - If ``save_path`` is given, the figure is written as a 300-dpi PNG and
      then closed.
    - Otherwise the figure is left open as the current matplotlib figure, so
      the caller can display it (e.g. with ``plt.show()``).

    Raises:
    - ValueError: if ``layout`` is not one of the supported names.
    """

    # Select the layout
    if layout == "spring":
        pos = nx.spring_layout(G, seed=seed)
    elif layout == "spectral":
        pos = nx.spectral_layout(G)
    elif layout == "circular":
        pos = nx.circular_layout(G)
    elif layout == "kamada_kawai":
        pos = nx.kamada_kawai_layout(G)
    elif layout == "random":
        pos = nx.random_layout(G, seed=seed)
    else:
        raise ValueError(
            "Invalid layout. Choose from 'spring', 'spectral', 'circular', 'kamada_kawai', or 'random'."
        )

    # Per-node colours, falling back to grey when no "color" attribute is set.
    node_colors = [G.nodes[node].get("color", "grey") for node in G.nodes]

    draw_kwargs = {
        "with_labels": False,
        "node_color": node_colors,
        "edge_color": "black",
        "node_size": 300,
        "font_size": 8,
    }
    # Only pass `arrows` when requested so the undirected call relies on
    # networkx's own default, exactly as before.
    if directed:
        draw_kwargs["arrows"] = True

    fig = plt.figure(figsize=(5, 5))
    nx.draw(G, pos, **draw_kwargs)

    if save_path:
        # dirname() is "" for a bare filename; os.makedirs("") would raise.
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        try:
            fig.savefig(save_path, format="png", dpi=300)
        finally:
            # Free the figure even if saving fails, so figures don't accumulate.
            plt.close(fig)
