# benchmarks/draw_factor_graph.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import networkx as nx


@dataclass
class SimpleFG:
    """Name-only description of a bipartite factor graph.

    Serves as a fallback container when the graph is assembled from raw
    pieces instead of a project-specific factor-graph object.

    Attributes:
        variables: Names of the variable nodes.
        factors: Names of the factor nodes.
        edges: Connections between the two node sets, one pair per edge.
    """

    variables: List[str]
    factors: List[str]
    # Each pair is ordered (factor, variable).
    edges: List[Tuple[str, str]]


def _extract_fg(fg) -> SimpleFG:
    """
    Extract variable names, factor names and edges from a factor-graph object.

    A ``SimpleFG`` is returned unchanged. Any other object is read by duck
    typing:

      - ``fg.variables`` (optional): iterable of variables; each contributes
        its ``.name`` if present, otherwise ``str(variable)``.
      - ``fg.factors`` (required): iterable of factors; each contributes its
        ``.name`` (or ``str(factor)``) and must expose its incident variables
        through one of ``.neighbors``, ``.scope``, ``.variables`` or ``.vars``
        (checked in that order; attributes whose value is ``None`` are
        skipped).

    Variables that only show up in a factor's scope are appended to the
    variable list, so every edge endpoint is a known variable.

    Returns:
        A ``SimpleFG`` with de-duplicated names and ``(factor, variable)``
        edges, each in first-seen order.

    Raises:
        ValueError: if ``fg`` has ``num_variables`` but no ``variables``, if
            ``fg`` has no ``factors``, or if a factor exposes none of the
            supported incidence attributes.
    """
    if isinstance(fg, SimpleFG):
        return fg

    # Variables
    var_names: List[str] = []
    if hasattr(fg, "variables"):
        var_names.extend(getattr(v, "name", str(v)) for v in fg.variables)
    elif hasattr(fg, "num_variables"):
        raise ValueError("FactorGraph has num_variables but no .variables list; add an accessor.")

    # Factors
    if not hasattr(fg, "factors"):
        raise ValueError("Expected fg.factors")

    fac_names: List[str] = []
    edges: List[Tuple[str, str]] = []
    for f in fg.factors:
        fac_name = getattr(f, "name", str(f))
        fac_names.append(fac_name)

        # First incidence attribute that is present and not None wins.
        neigh = None
        for attr in ("neighbors", "scope", "variables", "vars"):
            neigh = getattr(f, attr, None)
            if neigh is not None:
                break

        if neigh is None:
            raise ValueError(f"Cannot infer neighbors for factor {fac_name}. "
                             f"Expected attribute like .neighbors or .scope.")

        # Normalize neighbors to names; edges are stored as (factor, variable).
        for u in neigh:
            edges.append((fac_name, getattr(u, "name", str(u))))

    # Deduplicate while preserving first-seen order (dict keys keep insertion order).
    edges = list(dict.fromkeys(edges))
    fac_names = list(dict.fromkeys(fac_names))
    # Declared variables come first, followed by any variable that is only
    # referenced from a factor scope; otherwise such a variable would be an
    # edge endpoint without being listed as a variable.
    var_names = list(dict.fromkeys([*var_names, *(v_name for _, v_name in edges)]))

    return SimpleFG(var_names, fac_names, edges)


def draw_factor_graph_napp_style(
    fg,
    *,
    ax=None,
    with_messages: bool = True,
    factor_label_style: str = "psi",   # "psi" -> ψ_1, ψ_2 ... ; "name" -> use factor names
    node_size_var: int = 1200,
    node_size_fac: int = 900,
    edge_width: float = 1.8,
    message_color: str = "0.5",        # gray
    edge_color: str = "0.15",          # near-black
    font_size: int = 12,
    seed: int = 7,
):
    """
    Draw a bipartite factor graph with Napp–Adams-like styling:
      - variables: circles
      - factors: squares
      - edges: black
      - optional gray directed message arrows (two per undirected edge)

    Args:
        fg: A ``SimpleFG`` or any object ``_extract_fg`` can read.
        ax: Axes to draw on; a new 6.5 x 3.2 inch figure is created when None.
        with_messages: If True, draw a pair of opposite message arrows,
            labeled P (variable -> factor) and S (factor -> variable), on
            every edge.
        factor_label_style: ``"name"`` labels factors with their names; any
            other value labels them ψ_1, ψ_2, ... in list order.
        node_size_var: Marker size of variable nodes.
        node_size_fac: Marker size of factor nodes.
        edge_width: Line width of the structural edges.
        message_color: Color of message arrows and their labels.
        edge_color: Color of structural edges and node outlines.
        font_size: Font size of node labels.
        seed: Currently unused by this function; kept so existing callers
            that pass it keep working.

    Returns (fig, ax).
    """
    sfg = _extract_fg(fg)

    B = nx.Graph()
    B.add_nodes_from(sfg.variables, bipartite=0, kind="var")
    B.add_nodes_from(sfg.factors, bipartite=1, kind="fac")
    B.add_edges_from(sfg.edges)

    # Layout: force left-right bipartite look (variables on left, factors on right)
    # If you want the exact Napp feel (factors above/around), adjust manually later.
    pos = nx.bipartite_layout(B, sfg.variables, align="vertical", scale=2.2)

    # Slight extra horizontal separation between the two columns.
    # A set gives O(1) membership tests instead of scanning the list per node.
    var_set = set(sfg.variables)
    pos = {
        n: (x - 0.6, y) if n in var_set else (x + 0.6, y)
        for n, (x, y) in pos.items()
    }

    if ax is None:
        fig, ax = plt.subplots(figsize=(6.5, 3.2))
    else:
        fig = ax.figure

    ax.set_axis_off()

    # Draw undirected structure edges
    nx.draw_networkx_edges(
        B, pos, ax=ax,
        width=edge_width,
        edge_color=edge_color,
        alpha=1.0
    )

    # Draw variable nodes (circles)
    nx.draw_networkx_nodes(
        B, pos, nodelist=sfg.variables, ax=ax,
        node_shape="o",
        node_size=node_size_var,
        node_color="white",
        edgecolors=edge_color,
        linewidths=2.0
    )

    # Draw factor nodes (squares)
    nx.draw_networkx_nodes(
        B, pos, nodelist=sfg.factors, ax=ax,
        node_shape="s",
        node_size=node_size_fac,
        node_color="white",
        edgecolors=edge_color,
        linewidths=2.0
    )

    # Labels: x_i for variables, ψ_j for factors (or names).
    # The index is wrapped in braces ("$x_{10}$"): without them mathtext
    # subscripts only the first character, so index 10 would render as x_1
    # followed by a full-size 0.
    var_labels = {
        # if it already looks like x1, keep it
        v: str(v) if str(v).startswith("x") else f"$x_{{{i}}}$"
        for i, v in enumerate(sfg.variables, start=1)
    }

    if factor_label_style == "name":
        fac_labels = {a: str(a) for a in sfg.factors}
    else:
        fac_labels = {a: f"$\\psi_{{{j}}}$" for j, a in enumerate(sfg.factors, start=1)}

    nx.draw_networkx_labels(B, pos, labels=var_labels, ax=ax, font_size=font_size)
    nx.draw_networkx_labels(B, pos, labels=fac_labels, ax=ax, font_size=font_size)

    # Optional message arrows (gray) and labels S/P near edges
    if with_messages:
        # Both arrows use the SAME rad. The arc3 control point is offset
        # perpendicular to the start->end direction, so swapping the endpoints
        # already flips the side of the bend; negating rad as well would flip
        # it back and put both arrows on one identical curve.
        # NOTE: the two labels sit at the same fraction of the edge (t=0.45
        # from the variable equals t=0.55 from the factor) and are separated
        # only by their dy offsets.
        for (a, v) in sfg.edges:
            # arrow: v -> a labeled P^(v->a)
            _draw_directed_edge(ax, pos[v], pos[a], color=message_color, rad=0.12)
            _place_edge_text(ax, pos[v], pos[a], text="$P$", color=message_color, t=0.45, dy=0.04)

            # arrow: a -> v labeled S^(a->v)
            _draw_directed_edge(ax, pos[a], pos[v], color=message_color, rad=0.12)
            _place_edge_text(ax, pos[a], pos[v], text="$S$", color=message_color, t=0.55, dy=-0.04)

    return fig, ax


def _draw_directed_edge(ax, p0, p1, color="0.6", rad=0.12):
    """Draw a curved arrow on *ax* that starts at p0 and points at p1.

    The arrow is an empty-text matplotlib annotation; ``rad`` is forwarded to
    the ``arc3`` connection style and ``color`` to the arrow itself.
    """
    arrow_style = {
        "arrowstyle": "-|>",
        "color": color,
        "lw": 1.3,
        # pull both ends back from the anchor points
        "shrinkA": 18,
        "shrinkB": 18,
        "connectionstyle": f"arc3,rad={rad}",
    }
    ax.annotate("", xy=p1, xytext=p0, arrowprops=arrow_style, zorder=1)


def _place_edge_text(ax, p0, p1, text="S", color="0.6", t=0.5, dy=0.0):
    """Write a small label on *ax* along the segment from p0 to p1.

    The anchor is the point at fraction ``t`` of the way from p0 to p1,
    shifted vertically by ``dy``; the text is centered on it over a
    semi-transparent white rounded box.
    """
    w0 = 1 - t
    label_x = w0 * p0[0] + t * p1[0]
    label_y = w0 * p0[1] + t * p1[1] + dy
    backdrop = {
        "boxstyle": "round,pad=0.12",
        "facecolor": "white",
        "edgecolor": "none",
        "alpha": 0.8,
    }
    ax.text(
        label_x, label_y, text,
        color=color, fontsize=11, ha="center", va="center",
        bbox=backdrop,
    )


if __name__ == "__main__":
    # Minimal demo built from raw pieces: two variables, two unary factors
    # and one pairwise factor connecting both variables.
    demo_edges = [
        ("psi1", "x1"),
        ("psi2", "x2"),
        ("psi3", "x1"),
        ("psi3", "x2"),
    ]
    demo = SimpleFG(
        variables=["x1", "x2"],
        factors=["psi1", "psi2", "psi3"],
        edges=demo_edges,
    )
    fig, ax = draw_factor_graph_napp_style(
        demo,
        with_messages=True,
        factor_label_style="psi",
    )
    plt.show()
