"""
Comprehensive graph generators for 12 different graph classes.

Supports: complete, erdos-renyi, d-regular, watts-strogatz, sbm, delaunay,
euclid-mst, k-partite, grid, torus, hypercube, apollonian
"""

import math
import numpy as np
import networkx as nx
import torch
from scipy.spatial import Delaunay
from scipy.spatial.distance import pdist, squareform
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree


def _giant_component(G: nx.Graph) -> nx.Graph:
    """
    Return a copy of the largest connected component of G.

    Ties between equally large components go to the component that
    ``nx.connected_components`` yields first. A graph without nodes has no
    components; an empty copy is returned instead of raising IndexError.
    """
    # max() finds the largest component in one pass; sorting all of them is
    # unnecessary just to pick the first.
    largest = max(nx.connected_components(G), key=len, default=())
    return G.subgraph(largest).copy()


def _maybe_make_connected(G: nx.Graph, ensure_connected: bool) -> nx.Graph:
    """
    Return G, or its largest connected component when connectivity is requested.

    Connectivity is only checked when ``ensure_connected`` is truthy; an
    already-connected graph is returned unchanged (same object, no copy).
    """
    if not ensure_connected or nx.is_connected(G):
        return G
    return _giant_component(G)


def _dense_row_stochastic_from_graph(G: "nx.Graph", seed=None, device="cpu", dtype=torch.float64):
    """
    Build a dense row-stochastic transition matrix K from graph G.

    Each node i gets i.i.d. Uniform[0, 1) weights on its neighbors, drawn from
    ``np.random.default_rng(seed)`` node by node in index order. The weights
    are normalized so that row i sums to one. A node without neighbors gets a
    self-loop (K[i, i] = 1), so every row is a probability distribution.

    Args:
        G: graph whose nodes become row/column indices 0..n-1 in ``G.nodes()``
            iteration order (it is relabelled when its labels differ).
        seed: seed for the weight RNG.
        device: torch device of the result.
        dtype: torch dtype of the result.

    Returns:
        (n, n) tensor K with nonnegative entries and unit row sums.
    """
    rng = np.random.default_rng(seed)
    n = G.number_of_nodes()

    # Ensure nodes are 0..n-1 in stable (iteration) order.
    mapping = {u: i for i, u in enumerate(G.nodes())}
    if any(u != mapping[u] for u in G.nodes()):
        G = nx.relabel_nodes(G, mapping, copy=True)

    # Assemble in a NumPy float64 buffer with one vectorized write per row.
    # Writing torch elements one at a time costs a Python-level tensor
    # indexing call per edge, which dominates runtime on dense graphs.
    # The random draws and the stored values match an element-wise fill.
    K_np = np.zeros((n, n), dtype=np.float64)
    for i in range(n):
        nbrs = [int(j) for j in G.neighbors(i)]
        if not nbrs:
            K_np[i, i] = 1.0
            continue
        w = rng.random(len(nbrs))
        K_np[i, nbrs] = w / w.sum()
    K = torch.as_tensor(K_np, dtype=dtype, device=device)

    # Numerical guard: enforce exact row-stochasticity in the target dtype.
    rowsum = K.sum(dim=1, keepdim=True)
    mask = rowsum.squeeze(1) > 0
    K[mask] = K[mask] / rowsum[mask]
    # Any pathological all-zero row becomes a self-loop.
    zero_mask = ~mask
    if torch.any(zero_mask):
        idx = torch.nonzero(zero_mask, as_tuple=True)[0]
        K[idx, idx] = 1.0
    return K


@torch.no_grad()
def _stationary_power_dense(K: torch.Tensor, max_iters: int = 10000, tol: float = 1e-10) -> torch.Tensor:
    """
    Stationary distribution pi (pi K = pi) of a dense row-stochastic matrix K.

    Runs power iteration on the *lazy* chain (I + K) / 2 instead of K itself.
    The lazy chain has exactly the same stationary distribution but is
    aperiodic. Plain power iteration on K^T oscillates forever on periodic
    chains, such as random walks on bipartite graphs, and returns a
    non-stationary vector. The price is at most about twice as many iterations
    on chains that were already aperiodic.

    Args:
        K: (n, n) dense row-stochastic matrix.
        max_iters: maximum number of iterations.
        tol: stop once the L1 stationarity residual ||K^T p - p||_1 < tol.

    Returns:
        (n,) probability vector with K's dtype and device. If K is reducible
        (several closed classes), it is one of its stationary distributions.
        If max_iters is reached first, it is the last iterate.
    """
    n = K.size(0)
    if n == 0:
        return torch.empty(0, dtype=K.dtype, device=K.device)
    p = torch.full((n,), 1.0 / n, dtype=K.dtype, device=K.device)
    KT = K.transpose(0, 1)
    for _ in range(max_iters):
        q = KT @ p
        # Measure true stationarity of p, not just the step size of the
        # (slower) lazy iteration.
        if torch.linalg.vector_norm(q - p, ord=1) < tol:
            break
        p = 0.5 * (p + q)  # one step of the lazy chain (I + K)^T / 2
        p.clamp_(min=0)
        s = p.sum()
        if s <= 0:
            p.fill_(1.0 / n)
        else:
            p /= s
    return p


def _to_stochastic_and_pi_torch(G: nx.Graph, seed=None, device="cpu", dtype=torch.float64):
    """
    Turn graph G into a ``(K, pi)`` pair.

    K is the randomly weighted row-stochastic transition tensor of G, and pi
    is its stationary distribution (power iteration, 10000 steps, tol 1e-10).
    """
    transition = _dense_row_stochastic_from_graph(G, seed=seed, device=device, dtype=dtype)
    return transition, _stationary_power_dense(transition, max_iters=10000, tol=1e-10)


# ==================== GRAPH GENERATORS ====================

def gen_complete(n, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    Complete graph K_n: every node is adjacent to every other node.

    Returns ``(K, pi, info)`` with ``info = {"class": "complete", "n": n}``.
    """
    graph = _maybe_make_connected(nx.complete_graph(n), ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(graph, seed=seed, device=device, dtype=dtype)
    info = {"class": "complete", "n": n}
    return K, pi, info


def gen_erdos_renyi(n, p=0.1, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True, max_tries=10):
    """
    Erdős–Rényi G(n, p) random graph: each edge is present independently with probability p.

    With ``ensure_connected``, up to ``max_tries`` independent samples (at
    least one) are drawn until one is connected. If none is, the largest
    connected component of the last sample is used, and
    ``info["note"] = "giant_component"`` records that the graph has fewer than
    n nodes.

    Returns:
        (K, pi, info) with info keys "class", "n" (actual node count), "p" and
        optionally "note".
    """
    rng = np.random.default_rng(seed)
    for _ in range(max(1, max_tries)):
        G = nx.erdos_renyi_graph(n, p, seed=int(rng.integers(1 << 32)))
        # Test the raw sample. Shrinking it to its giant component first would
        # make every sample look connected and end the retries immediately.
        if not ensure_connected or nx.is_connected(G):
            break
    note = None
    if ensure_connected and not nx.is_connected(G):
        G = _giant_component(G)
        note = "giant_component"
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    info = {"class": "erdos-renyi", "n": G.number_of_nodes(), "p": p}
    if note:
        info["note"] = note
    return K, pi, info


def gen_d_regular(n, d=4, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True, max_tries=10):
    """
    Random d-regular graph: every node has exactly d neighbors.

    Up to ``max_tries`` samples are drawn while ``ensure_connected`` is set
    and the sample is disconnected. If the last one is still disconnected,
    its largest component is used and ``info["note"]`` is "giant_component".

    Raises:
        ValueError: if d is outside [0, n) or n * d is odd.
    """
    if d < 0 or d >= n:
        raise ValueError("Require 0 <= d < n.")
    if n * d % 2 != 0:
        raise ValueError("Feasibility: n*d must be even.")
    draw = np.random.default_rng(seed)
    for _attempt in range(max_tries):
        graph = nx.random_regular_graph(d, n, seed=int(draw.integers(1 << 32)))
        if not ensure_connected or nx.is_connected(graph):
            break
    fell_back = ensure_connected and not nx.is_connected(graph)
    if fell_back:
        graph = _giant_component(graph)
    K, pi = _to_stochastic_and_pi_torch(graph, seed=seed, device=device, dtype=dtype)
    info = {"class": "d-regular", "n": graph.number_of_nodes(), "d": d}
    if fell_back:
        info["note"] = "giant_component"
    return K, pi, info


def gen_watts_strogatz(n, k=6, beta=0.1, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    Watts–Strogatz small-world graph.

    A ring in which every node links to its k nearest neighbors (k/2 on each
    side) has its edges randomly rewired with probability beta.

    Raises:
        ValueError: if k is odd or not in [0, n).
    """
    if k % 2 != 0:
        raise ValueError("Watts–Strogatz requires even k.")
    if not 0 <= k < n:
        raise ValueError("Require 0 <= k < n.")
    graph_seed = int(np.random.default_rng(seed).integers(1 << 32))
    graph = _maybe_make_connected(nx.watts_strogatz_graph(n, k, beta, seed=graph_seed), ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(graph, seed=seed, device=device, dtype=dtype)
    info = {"class": "watts-strogatz", "n": graph.number_of_nodes(), "k": k, "beta": beta}
    return K, pi, info


def gen_sbm(n, sizes=None, P=None, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    Stochastic block model with ``len(sizes)`` communities.

    Args:
        sizes: community sizes summing to n. Defaults to two halves.
        P: (k, k) matrix of edge probabilities between communities. Defaults
            to 0.2 on the diagonal (within a community) and 0.02 elsewhere.

    Raises:
        ValueError: if sizes does not sum to n or P is not k x k.
    """
    if sizes is None:
        sizes = [n // 2, n - n // 2]
    if sum(sizes) != n:
        raise ValueError("sizes must sum to n")
    num_blocks = len(sizes)
    if P is None:
        # Assortative default: dense within blocks, sparse across blocks.
        P = np.where(np.eye(num_blocks, dtype=bool), 0.2, 0.02)
    P = np.asarray(P, dtype=float)
    if P.shape != (num_blocks, num_blocks):
        raise ValueError("P must be k x k")
    sbm_seed = int(np.random.default_rng(seed).integers(1 << 32))
    graph = nx.Graph(nx.stochastic_block_model(sizes, P, seed=sbm_seed))
    graph = _maybe_make_connected(graph, ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(graph, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "sbm", "n": graph.number_of_nodes(), "sizes": list(sizes), "P": P}


def gen_delaunay(n, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    Delaunay triangulation of n uniform random points in the unit square.

    Each triangle contributes its three sides as edges. Input points that
    qhull leaves out of the triangulation (reported in ``Delaunay.coplanar``,
    e.g. numerically coincident points) would become isolated nodes. With
    ``ensure_connected``, the largest connected component is kept, as in the
    other generators.

    Returns:
        (K, pi, info). info["pts"] holds all n sampled points, indexed by the
        original node ids. info["n"] is the number of nodes actually kept.
    """
    rng = np.random.default_rng(seed)
    pts = rng.random((n, 2))
    tri = Delaunay(pts)
    G = nx.Graph()
    G.add_nodes_from(range(n))
    # Sides (s0,s1), (s1,s2), (s2,s0) of every simplex, in the order the
    # triangles are listed. A simplex never repeats a vertex, so no self-loops.
    sides = tri.simplices[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2)
    G.add_edges_from(map(tuple, sides.tolist()))
    # Previously ensure_connected was accepted but ignored here.
    G = _maybe_make_connected(G, ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "delaunay", "n": G.number_of_nodes(), "pts": pts}


def gen_euclid_mst(n, seed=None, device="cpu", dtype=torch.float64):
    """
    Euclidean minimum spanning tree over n uniform random points in the unit square.

    SciPy reduces the full pairwise-distance graph to its MST. The transition
    weights in K are drawn at random by ``_to_stochastic_and_pi_torch``, not
    taken from the tree's edge lengths.
    """
    rng = np.random.default_rng(seed)
    pts = rng.random((n, 2))
    distances = csr_matrix(squareform(pdist(pts, metric="euclidean")))
    tree = nx.from_scipy_sparse_array(minimum_spanning_tree(distances))
    K, pi = _to_stochastic_and_pi_torch(tree, seed=seed, device=device, dtype=dtype)
    info = {"class": "euclid-mst", "n": n, "pts": pts}
    return K, pi, info


def gen_k_partite(n, parts=None, complete=True, p_across=0.2, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    k-partite graph: nodes are split into consecutive groups; edges only join different groups.

    Args:
        parts: group sizes summing to n. The default is two equal halves for
            even n and three groups (n//3, n//3, remainder) for odd n.
        complete: connect every cross-group pair. Otherwise each cross-group
            pair is kept independently with probability ``p_across``.

    Raises:
        ValueError: if parts does not sum to n.
    """
    if parts is None:
        if n % 2 == 0:
            parts = [n // 2] * 2
        else:
            third = n // 3
            parts = [third, third, n - 2 * third]
    if sum(parts) != n:
        raise ValueError("parts must sum to n")
    bounds = np.cumsum([0] + list(parts))
    groups = [list(range(lo, hi)) for lo, hi in zip(bounds[:-1], bounds[1:])]
    G = nx.Graph()
    G.add_nodes_from(range(n))
    rng = np.random.default_rng(seed)
    for a, left in enumerate(groups):
        for right in groups[a + 1:]:
            for u in left:
                for v in right:
                    # A random draw is consumed only in the sparse (non-complete) case.
                    if complete or rng.random() < p_across:
                        G.add_edge(u, v)
    G = _maybe_make_connected(G, ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    info = {
        "class": "k-partite",
        "n": G.number_of_nodes(),
        "parts": list(parts),
        "complete": complete,
        "p_across": None if complete else p_across,
    }
    return K, pi, info


def gen_grid(n, rows=None, cols=None, seed=None, device="cpu", dtype=torch.float64, ensure_connected=True):
    """
    2D grid graph (4-neighborhood lattice) with at most n nodes.

    Shape selection:
      * rows and cols both None: rows = floor(sqrt(n)), cols = ceil(n / rows);
      * only one given: the other is ceil(n / given), the smallest value with
        rows * cols >= n (a lone value used to be silently discarded);
      * both given: used as-is. If rows * cols < n, the graph simply has fewer
        than n nodes.
    Nodes are numbered in ``grid_2d_graph`` node order (row-major) and only
    the first n are kept, so the last row may be partially filled.
    info["n"] is the actual node count.

    Raises:
        ValueError: if n < 1 or a given rows/cols is < 1.
    """
    if n < 1:
        raise ValueError("grid requires n >= 1.")
    if (rows is not None and rows < 1) or (cols is not None and cols < 1):
        raise ValueError("rows and cols must be >= 1 when given.")
    if rows is None and cols is None:
        rows = int(math.floor(math.sqrt(n)))
        cols = int(math.ceil(n / rows))
    elif cols is None:
        cols = int(math.ceil(n / rows))
    elif rows is None:
        rows = int(math.ceil(n / cols))
    G = nx.grid_2d_graph(rows, cols)
    # Relabel 2D -> 1D and trim to n
    mapping = {xy: i for i, xy in enumerate(G.nodes())}
    G = nx.relabel_nodes(G, mapping)
    if G.number_of_nodes() > n:
        keep = set(range(n))
        G = G.subgraph(keep).copy()
    G = _maybe_make_connected(G, ensure_connected)
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "grid", "rows": rows, "cols": cols, "n": G.number_of_nodes()}


def gen_torus(n, rows=None, cols=None, seed=None, device="cpu", dtype=torch.float64):
    """
    2D torus graph: a rows x cols grid with periodic (wrap-around) boundaries.

    Shape selection:
      * rows and cols both None: rows = floor(sqrt(n)) and
        cols = round(n / floor(sqrt(n))), each clamped to at least 2;
      * only one given: the other is round(n / given), clamped to at least 2
        (a lone value used to be silently discarded);
      * both given: used as-is.
    The torus has rows * cols nodes, which need not equal n.
    info["n"] reports rows * cols.

    Raises:
        ValueError: if n < 1 or a given rows/cols is < 1.
    """
    if n < 1:
        raise ValueError("torus requires n >= 1.")
    if (rows is not None and rows < 1) or (cols is not None and cols < 1):
        raise ValueError("rows and cols must be >= 1 when given.")
    if rows is None and cols is None:
        base = int(math.floor(math.sqrt(n)))
        rows = max(2, base)
        cols = max(2, int(round(n / base)))
    elif cols is None:
        cols = max(2, int(round(n / rows)))
    elif rows is None:
        rows = max(2, int(round(n / cols)))
    G = nx.grid_2d_graph(rows, cols, periodic=True)
    mapping = {xy: i for i, xy in enumerate(G.nodes())}
    G = nx.relabel_nodes(G, mapping)
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "torus", "rows": rows, "cols": cols, "n": rows * cols}


def gen_hypercube(n, d=None, seed=None, device="cpu", dtype=torch.float64):
    """
    d-dimensional hypercube Q_d: 2**d nodes, adjacent when their bit labels differ in one bit.

    Either give only n (which must be a power of two) or give d as well, in
    which case n must equal 2**d.

    Raises:
        ValueError: if n and d are inconsistent or n is not a power of two.
    """
    if d is not None:
        if n != 2 ** d:
            raise ValueError("If d is given, n must equal 2**d.")
    else:
        exponent = math.log2(n)
        if abs(exponent - round(exponent)) > 1e-9:
            raise ValueError("Hypercube requires n=2^d (or specify d explicitly).")
        d = int(round(exponent))
    cube = nx.hypercube_graph(d)
    cube = nx.relabel_nodes(cube, {node: idx for idx, node in enumerate(cube.nodes())})
    K, pi = _to_stochastic_and_pi_torch(cube, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "hypercube", "d": d, "n": 2 ** d}


def _apollonian_edges(n, rng):
    """Edge list of a random Apollonian network on nodes 0..n-1 (requires n >= 3)."""
    edges = [(0, 1), (1, 2), (0, 2)]
    faces = [(0, 1, 2)]
    for v in range(3, n):
        # Pick a triangular face uniformly at random, put v inside it, and
        # split it into the three triangles that now contain v.
        idx = int(rng.integers(len(faces)))
        a, b, c = faces[idx]
        edges.extend([(a, v), (b, v), (c, v)])
        faces[idx] = (a, b, v)
        faces.append((a, c, v))
        faces.append((b, c, v))
    return edges


def gen_apollonian(n, seed=None, device="cpu", dtype=torch.float64):
    """
    Random Apollonian network: planar, scale-free graph.

    Starts from a triangle and repeatedly inserts a new node into a uniformly
    chosen triangular face, joining it to the face's three corners. The
    result is a planar triangulation with 3n - 6 edges. NetworkX has no
    generator for this class, so it is built here.

    Raises:
        ValueError: if n < 3.
    """
    if n < 3:
        raise ValueError("Apollonian network requires n >= 3.")
    rng = np.random.default_rng(seed)
    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_edges_from(_apollonian_edges(n, rng))
    K, pi = _to_stochastic_and_pi_torch(G, seed=seed, device=device, dtype=dtype)
    return K, pi, {"class": "apollonian", "n": n}


# ==================== UNIFIED FACTORY ====================

def generate_graph(kind: str, n: int, **kwargs):
    """
    Generate a graph of the requested class together with its random-walk data.

    Args:
        kind: graph class, case-insensitive. Underscores and whitespace are
            treated as hyphens, so "erdos_renyi", "Erdos Renyi" and
            "erdos-renyi" are equivalent. One of: complete, erdos-renyi,
            d-regular, watts-strogatz, sbm, delaunay, euclid-mst, k-partite,
            grid, torus, hypercube, apollonian.
        n: requested number of nodes. Some classes return a different size
            (e.g. after giant-component extraction or grid/torus shaping);
            check K.shape or the returned info.
        **kwargs: forwarded to the class-specific generator (seed, device, dtype, ...).

    Returns:
        K: (m, m) row-stochastic transition matrix, m = number of nodes generated
        pi: (m,) stationary distribution of K
        info: dict with graph metadata

    Raises:
        ValueError: for an unknown kind, or if a generator does not return a 3-tuple.
    """
    # Canonical spelling: lower-case, with "_" and runs of whitespace turned into "-".
    key = "-".join(kind.lower().replace("_", " ").split())
    funcs = {
        "complete": gen_complete,
        "erdos-renyi": gen_erdos_renyi,
        "d-regular": gen_d_regular,
        "watts-strogatz": gen_watts_strogatz,
        "sbm": gen_sbm,
        "delaunay": gen_delaunay,
        "euclid-mst": gen_euclid_mst,
        "k-partite": gen_k_partite,
        "grid": gen_grid,
        "torus": gen_torus,
        "hypercube": gen_hypercube,
        "apollonian": gen_apollonian,
    }
    if key not in funcs:
        raise ValueError(f"Unknown graph type '{key}'. Available: {list(funcs.keys())}")

    result = funcs[key](n, **kwargs)

    # Guard the (K, pi, info) contract in case a generator is changed later.
    if not (isinstance(result, tuple) and len(result) == 3):
        raise ValueError(f"Unexpected generator return shape: {type(result)}")
    K, pi, info = result
    return K, pi, info
