import py3Dmol
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from rdkit import Chem
import networkx as nx

def view_mol3d(mol, style="stick", width=400, height=400):
    """
    Render the 3D conformer of an RDKit molecule with py3Dmol.

    Intended for interactive use inside a Jupyter notebook.

    Args:
        mol: RDKit molecule that already carries at least one conformer.
        style (str): py3Dmol style name applied to every atom (e.g. "stick").
        width (int): Viewer width in pixels.
        height (int): Viewer height in pixels.

    Returns:
        The result of ``show()`` on the created py3Dmol viewer; the viewer is
        displayed as a side effect.

    Raises:
        ValueError: If ``mol`` has no conformer.
    """
    has_coordinates = mol.GetNumConformers() > 0
    if not has_coordinates:
        raise ValueError("Molecule has no conformer. Generate 3D coordinates first.")

    # Serialize first so the viewer is only built for a molecule RDKit can export.
    mol_block = Chem.MolToMolBlock(mol)

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(mol_block, 'mol')
    viewer.setStyle({style: {}})
    viewer.zoomTo()
    return viewer.show()


def view_mo2d(mol, atom_types=None):
    """
    Plot a 2D PCA projection of an RDKit molecule's 3D conformer.

    The 3D atom positions of the molecule's default conformer are projected
    onto their first two principal components and drawn as a scatter plot.
    Each atom gets its own colour, cycling through matplotlib's 20-colour
    "tab20" colormap. Intended for interactive use in a Jupyter notebook.

    Args:
        mol: RDKit molecule carrying at least one conformer with at least
            two atoms.
        atom_types (sequence, optional): Per-atom labels indexed by atom
            index. Any sequence supporting ``len()`` and integer indexing
            (list, tuple, NumPy array) is accepted. When ``None`` or empty,
            atoms are labelled with their index.

    Raises:
        ValueError: If ``mol`` has no conformer or fewer than two atoms.
    """
    if mol.GetNumConformers() == 0:
        raise ValueError("Molecule has no conformer. Generate 3D coordinates first.")

    pos3d = mol.GetConformer().GetPositions()
    num_points = pos3d.shape[0]
    if num_points < 2:
        # PCA(n_components=2) needs at least two samples; fail with a clear message.
        raise ValueError("PCA projection needs a molecule with at least 2 atoms.")

    pos2d = PCA(n_components=2).fit_transform(pos3d)

    # One colour per atom; integer lookups index the 20-entry colormap directly,
    # so colours repeat every 20 atoms.
    cmap = plt.get_cmap("tab20")
    colors = [cmap(i % 20) for i in range(num_points)]

    # Explicit None/len check instead of truthiness: NumPy arrays have an
    # ambiguous truth value, and an empty sequence keeps the index-label fallback.
    use_types = atom_types is not None and len(atom_types) > 0

    plt.figure(figsize=(6, 6))
    plt.scatter(pos2d[:, 0], pos2d[:, 1], c=colors, s=100)
    for i, (x, y) in enumerate(pos2d):
        label = atom_types[i] if use_types else str(i)
        plt.text(x, y, label, fontsize=12, ha='center', va='center', color='black', weight='bold')

    plt.axis('equal')
    plt.title("PCA Projection with Index-Based Coloring")
    plt.show()



def view_graph2d(G, node_label_key=None, edge_label_key=None, layout="spring", figsize=(6, 5),
                 title="NetworkX Molecular Graph"):
    """
    Plot a NetworkX graph with optional node and edge labels.

    Args:
        G (nx.Graph): The graph to plot.
        node_label_key (str, optional): Node attribute key to use as label.
            Nodes that lack this attribute (and all nodes when the key is
            ``None``) are labelled with their node id.
        edge_label_key (str, optional): Edge attribute key to display on edges.
            Only edges carrying the attribute get a label.
        layout (str): Layout algorithm: 'spring', 'kamada', 'shell', 'circular'.
            'spring' uses a fixed seed (42) so repeated calls give the same layout.
        figsize (tuple): Size of the matplotlib figure.
        title (str): Figure title.

    Raises:
        ValueError: If ``layout`` is not one of the supported names.
    """
    # Layout callables are looked up lazily, so only the selected algorithm runs.
    layouts = {
        "spring": lambda g: nx.spring_layout(g, seed=42),
        "kamada": nx.kamada_kawai_layout,
        "shell": nx.shell_layout,
        "circular": nx.circular_layout,
    }
    try:
        layout_fn = layouts[layout]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the old if/elif chain also rejected.
        raise ValueError(
            f"Unsupported layout {layout!r}; expected one of {sorted(layouts)}"
        ) from None
    pos = layout_fn(G)

    # Fall back to the node id for nodes missing the attribute, so no node is
    # drawn unlabelled.
    if node_label_key is not None:
        node_labels = {n: data.get(node_label_key, n) for n, data in G.nodes(data=True)}
    else:
        node_labels = {n: n for n in G.nodes()}

    # Draw nodes and edges
    plt.figure(figsize=figsize)
    nx.draw(G, pos,
            with_labels=True,
            labels=node_labels,
            node_color='lightblue',
            node_size=600,
            edge_color='gray',
            font_size=10)

    # Draw edge labels if requested
    if edge_label_key is not None:
        edge_labels = nx.get_edge_attributes(G, edge_label_key)
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red')

    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
