"""
Benchmark Factor Graph Generators

Generates various factor graph families for stress testing:
1. Chains with varying length
2. Trees with varying depth/branching
3. Loopy cores with tendrils
4. Grids with pruned leaves
5. Random factor graphs with planted cores
"""

import numpy as np
from typing import Tuple, List, Optional
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph


def random_binary_table(shape: Tuple[int, ...], seed: Optional[int] = None) -> np.ndarray:
    """
    Return a random, strictly positive table of the requested shape.

    Entries are uniform samples rescaled to the interval [0.1, 1.0]. The
    0.1 floor keeps every entry away from zero.

    Args:
        shape: Dimensions of the table to generate.
        seed: When given, NumPy's *global* RNG is reseeded before sampling.
            This also affects every later ``np.random`` call.

    Returns:
        A float array with shape ``shape``.
    """
    if seed is not None:
        np.random.seed(seed)
    uniform = np.random.rand(*shape)
    return 0.1 + 0.9 * uniform


def generate_chain(n_vars: int, seed: int = 42) -> FactorGraph:
    """
    Build a chain-structured factor graph: x0 -- x1 -- ... -- x_{n-1}.

    Every variable is binary (domain [0, 1]). Each pair of neighbours is
    joined by a random pairwise factor ``f{i}_{i+1}``. The two endpoints
    also receive random unary factors ``u_start`` and ``u_end``.

    Args:
        n_vars: Number of variables in the chain. An empty chain fails with
            IndexError when the endpoint factors are attached.
        seed: Seed for NumPy's global RNG, reset at the start of the call.

    Returns:
        The populated FactorGraph, named ``chain_{n_vars}``.
    """
    np.random.seed(seed)
    fg = FactorGraph(f"chain_{n_vars}")

    nodes = [fg.add_variable(Variable(f"x{k}", [0, 1])) for k in range(n_vars)]

    # Endpoint evidence (for n_vars == 1 both unaries land on the same node).
    head, tail = nodes[0], nodes[-1]
    fg.add_factor(Factor("u_start", [head], random_binary_table((2,))))
    fg.add_factor(Factor("u_end", [tail], random_binary_table((2,))))

    # Pairwise couplings between consecutive variables.
    for k, (left, right) in enumerate(zip(nodes, nodes[1:])):
        fg.add_factor(Factor(f"f{k}_{k+1}", [left, right], random_binary_table((2, 2))))

    return fg


def generate_binary_tree(depth: int, seed: int = 42) -> FactorGraph:
    """
    Build a complete binary tree of binary variables.

    Nodes use heap numbering: node k has children 2k+1 and 2k+2. Every
    parent/child pair gets a random pairwise factor ``f{parent}_{child}``,
    and every node on the last level gets a random unary factor ``u{k}``.

    depth=1: 1 node (root only)
    depth=2: 3 nodes (root + 2 children)
    depth=3: 7 nodes
    etc.

    Args:
        depth: Number of levels. Should be >= 1; depth 0 yields a
            non-integer leaf offset and raises TypeError.
        seed: Seed for NumPy's global RNG, reset at the start of the call.

    Returns:
        The populated FactorGraph, named ``tree_depth{depth}``.
    """
    np.random.seed(seed)
    fg = FactorGraph(f"tree_depth{depth}")

    n_nodes = 2**depth - 1
    nodes = [fg.add_variable(Variable(f"x{k}", [0, 1])) for k in range(n_nodes)]

    # Parent -> child edges, left child before right child.
    for parent, parent_var in enumerate(nodes):
        for child in (2 * parent + 1, 2 * parent + 2):
            if child < n_nodes:
                fg.add_factor(Factor(f"f{parent}_{child}", [parent_var, nodes[child]],
                                     random_binary_table((2, 2))))

    # The last level starts at index 2**(depth-1) - 1 in heap order.
    first_leaf = 2**(depth - 1) - 1
    for leaf, leaf_var in enumerate(nodes[first_leaf:], start=first_leaf):
        fg.add_factor(Factor(f"u{leaf}", [leaf_var], random_binary_table((2,))))

    return fg


def generate_loopy_core_with_tendrils(core_size: int, tendril_length: int,
                                       n_tendrils_per_node: int = 1,
                                       seed: int = 42) -> FactorGraph:
    """
    Build a single cycle of binary variables with chain "tendrils" hanging off it.

    The core is the cycle c0 -- c1 -- ... -- c_{core_size-1} -- c0, using pairwise
    factors ``fc{i}_{j}``. Each core node then grows ``n_tendrils_per_node``
    chains of ``tendril_length`` new variables ``t{id}_{step}``. Each chain's
    last node receives a unary factor ``ut{id}``. Tendril ids are numbered
    consecutively across all core nodes.

    Args:
        core_size: Number of nodes in the cycle. Intended to be >= 3; this is
            not checked here.
        tendril_length: Length of each tendril chain. With 0, the unary
            factor is attached directly to the core node.
        n_tendrils_per_node: Number of tendrils per core node.
        seed: Seed for NumPy's global RNG, reset at the start of the call.

    Returns:
        The populated FactorGraph, named ``loopy_core{core_size}_tendril{tendril_length}``.
    """
    np.random.seed(seed)
    fg = FactorGraph(f"loopy_core{core_size}_tendril{tendril_length}")

    ring = [fg.add_variable(Variable(f"c{k}", [0, 1])) for k in range(core_size)]

    # Close the loop: c_k is linked to c_{(k+1) mod core_size}.
    for k, node in enumerate(ring):
        nxt = (k + 1) % core_size
        fg.add_factor(Factor(f"fc{k}_{nxt}", [node, ring[nxt]],
                             random_binary_table((2, 2))))

    tendril_id = 0
    for anchor in ring:
        for _ in range(n_tendrils_per_node):
            tip = anchor
            for step in range(tendril_length):
                segment = fg.add_variable(Variable(f"t{tendril_id}_{step}", [0, 1]))
                fg.add_factor(Factor(f"ft{tendril_id}_{step}", [tip, segment],
                                     random_binary_table((2, 2))))
                tip = segment

            # Evidence at the free end of the tendril.
            fg.add_factor(Factor(f"ut{tendril_id}", [tip], random_binary_table((2,))))
            tendril_id += 1

    return fg


def _grid_boundary(rows: int, cols: int) -> List[Tuple[int, int]]:
    """
    Return the boundary cells of a rows x cols grid, each cell listed once.

    Order: the left then right cell of every row, followed by the top then
    bottom cell of every interior column. On grids with at least two rows
    and two columns no cell repeats. On single-row or single-column grids
    this order would list cells twice, so repeats are dropped, keeping the
    first occurrence.
    """
    cells = []
    for i in range(rows):
        cells.append((i, 0))
        cells.append((i, cols - 1))
    for j in range(1, cols - 1):
        cells.append((0, j))
        cells.append((rows - 1, j))
    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(cells))


def generate_grid_with_pruned_leaves(rows: int, cols: int,
                                      prune_fraction: float = 0.3,
                                      seed: int = 42) -> FactorGraph:
    """
    Generate a grid graph with some boundary nodes "pruned" (given unary factors).

    Every cell (i, j) is a binary variable ``x{i}_{j}``. Horizontal neighbours
    are coupled by ``fh{i}_{j}`` and vertical neighbours by ``fv{i}_{j}``,
    both random pairwise factors. A random subset of the distinct boundary
    cells then receives unary factors ``u0``, ``u1``, ...

    Args:
        rows, cols: Grid dimensions.
        prune_fraction: Fraction of the distinct boundary nodes that get a
            unary factor. The count is rounded down. Values above 1 make
            ``np.random.choice`` raise ValueError.
        seed: Seed for NumPy's global RNG, reset at the start of the call.

    Returns:
        The populated FactorGraph, named ``grid_{rows}x{cols}``.
    """
    np.random.seed(seed)
    fg = FactorGraph(f"grid_{rows}x{cols}")

    # Create variables in row-major order.
    cells = {}
    for i in range(rows):
        for j in range(cols):
            cells[(i, j)] = fg.add_variable(Variable(f"x{i}_{j}", [0, 1]))

    # Add grid edges.
    for i in range(rows):
        for j in range(cols):
            # Right neighbour
            if j < cols - 1:
                fg.add_factor(Factor(f"fh{i}_{j}", [cells[(i, j)], cells[(i, j + 1)]],
                                     random_binary_table((2, 2))))
            # Down neighbour
            if i < rows - 1:
                fg.add_factor(Factor(f"fv{i}_{j}", [cells[(i, j)], cells[(i + 1, j)]],
                                     random_binary_table((2, 2))))

    # Unary factors on a random subset of boundary nodes. De-duplicating the
    # boundary prevents two unaries on the same node for 1-wide grids.
    boundary = _grid_boundary(rows, cols)
    n_unary = int(len(boundary) * prune_fraction)
    unary_nodes = np.random.choice(len(boundary), size=n_unary, replace=False)

    for idx, node_idx in enumerate(unary_nodes):
        i, j = boundary[node_idx]
        fg.add_factor(Factor(f"u{idx}", [cells[(i, j)]], random_binary_table((2,))))

    return fg


def generate_random_with_planted_core(n_total: int, core_size: int,
                                       edge_density: float = 0.3,
                                       seed: int = 42) -> FactorGraph:
    """
    Generate a random factor graph with a planted, fully connected core.

    Variables ``x0 .. x{n_total-1}`` are binary. The first ``core_size`` of
    them form a clique with one random pairwise factor per unordered pair
    (``fcore0``, ``fcore1``, ...). Each remaining (peripheral) variable x_i
    is linked to between 1 and min(3, core_size) distinct core variables,
    chosen uniformly at random, via factors ``f{i}_{c}``. It also gets its
    own unary factor ``u{i}``.

    Args:
        n_total: Total number of variables (core + periphery).
        core_size: Number of variables in the dense core. Must satisfy
            0 <= core_size <= n_total, and must be >= 1 whenever there are
            peripheral variables.
        edge_density: Currently unused. The core is always complete and each
            peripheral node always gets 1-3 core edges. The parameter is kept
            so existing call signatures remain valid.
        seed: Seed for NumPy's global RNG, reset at the start of the call.

    Returns:
        The populated FactorGraph, named ``random_{n_total}_core{core_size}``.

    Raises:
        ValueError: If ``core_size`` is outside the valid range.
    """
    if core_size < 0 or core_size > n_total:
        raise ValueError(
            f"core_size must be between 0 and n_total ({n_total}), got {core_size}")
    if core_size == 0 and n_total > 0:
        raise ValueError(
            "core_size must be >= 1 when there are peripheral variables to attach")

    np.random.seed(seed)
    fg = FactorGraph(f"random_{n_total}_core{core_size}")

    # Create variables; the first core_size of them form the core.
    all_vars = []
    for i in range(n_total):
        all_vars.append(fg.add_variable(Variable(f"x{i}", [0, 1])))
    core_vars = all_vars[:core_size]

    # Core is fully connected: one factor per pair i < j.
    edge_count = 0
    for i in range(core_size):
        for j in range(i + 1, core_size):
            fg.add_factor(Factor(f"fcore{edge_count}", [core_vars[i], core_vars[j]],
                                 random_binary_table((2, 2))))
            edge_count += 1

    # Attach each peripheral node to a few distinct core nodes.
    for i in range(core_size, n_total):
        # randint's upper bound is exclusive, so this draws 1..min(3, core_size).
        n_connections = np.random.randint(1, min(4, core_size + 1))
        connected_to = np.random.choice(core_size, size=n_connections, replace=False)

        for c in connected_to:
            fg.add_factor(Factor(f"f{i}_{c}", [all_vars[i], all_vars[c]],
                                 random_binary_table((2, 2))))

        # Add unary factor to peripheral node.
        fg.add_factor(Factor(f"u{i}", [all_vars[i]], random_binary_table((2,))))

    return fg


def compute_graph_stats(fg: FactorGraph) -> dict:
    """
    Summarise a factor graph and estimate the size of a Napp-style CRN built from it.

    The CRN numbers are rough estimates:
      * one "bundle" per directed factor/variable incidence;
      * species and recycling/product reaction counts assume every variable
        has 2 values (binary), regardless of actual cardinalities;
      * sum-message reactions use the real cardinalities: the product of
        the variables' cardinalities for each factor.

    Returns:
        A dict with keys ``n_variables``, ``n_factors``, ``n_directed_edges``,
        ``n_bundles``, ``n_species_estimate`` and ``n_reactions_estimate``.
    """
    n_vars = fg.num_variables
    n_factors = fg.num_factors

    # Each factor-variable incidence carries one message in each direction.
    n_directed_edges = 2 * sum(len(factor.variables) for factor in fg.factors)
    n_bundles = n_directed_edges

    values_per_var = 2  # binary-variable assumption for the estimates below
    # Every bundle and variable needs one species per value plus an "unassigned" one.
    n_species_estimate = (n_bundles + n_vars) * (values_per_var + 1)

    n_recycling = (n_bundles + n_vars) * values_per_var
    n_sum_reactions = sum(
        np.prod([v.cardinality for v in f.variables])
        for f in fg.factors
    )
    n_product_reactions = n_vars * values_per_var  # simplified estimate

    stats = {
        'n_variables': n_vars,
        'n_factors': n_factors,
        'n_directed_edges': n_directed_edges,
        'n_bundles': n_bundles,
        'n_species_estimate': n_species_estimate,
    }
    stats['n_reactions_estimate'] = n_recycling + n_sum_reactions + n_product_reactions
    return stats


if __name__ == "__main__":
    # Smoke-test every generator family and print its size statistics.
    print("Testing graph generators...\n")

    benchmarks = [
        ("Chain(10)", lambda: generate_chain(10)),
        ("BinaryTree(4)", lambda: generate_binary_tree(4)),
        ("LoopyCore(4)+Tendrils(3)", lambda: generate_loopy_core_with_tendrils(4, 3)),
        ("Grid(4x4)", lambda: generate_grid_with_pruned_leaves(4, 4)),
        ("Random(20,core=5)", lambda: generate_random_with_planted_core(20, 5)),
    ]

    # (printed label, stats key) pairs, in display order.
    report_rows = [
        ("Variables", "n_variables"),
        ("Factors", "n_factors"),
        ("Directed edges", "n_directed_edges"),
        ("Est. species", "n_species_estimate"),
        ("Est. reactions", "n_reactions_estimate"),
    ]

    for label, build in benchmarks:
        summary = compute_graph_stats(build())
        print(f"{label}:")
        for title, key in report_rows:
            print(f"  {title}: {summary[key]}")
        print()
