"""
Examples of Factor Graph Construction

This file demonstrates how to construct various factor graphs with the
modular framework in ``core``, recreating examples from the literature
(notably Napp & Adams) and illustrating the associated poset structure.
"""

import numpy as np
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import Variable, Factor, FactorGraph, build_factor_graph



def example_simple_two_variable():
    """
    Recreate the simple two-variable factor graph from Napp & Adams Figure 1a.

    Variables: x1, x2 (both binary)
    Factors: ψ1(x1), ψ2(x2), ψ3(x1, x2)

    Joint: P(x1, x2) ∝ ψ1(x1)ψ2(x2)ψ3(x1, x2)

    The graph is built step by step (explicit Variable and Factor objects
    added to a FactorGraph). The function then prints its structure, the
    partition function, and the exact marginal of each variable.

    Returns:
        The constructed FactorGraph.
    """
    print("=" * 60)
    print("Example: Simple Two-Variable Factor Graph (Napp & Adams Fig 1a)")
    print("=" * 60)

    # Method 1: Step-by-step construction
    fg = FactorGraph("TwoVariable")

    # Define variables
    x1 = Variable("x1", [0, 1])  # Binary: values 0 and 1
    x2 = Variable("x2", [0, 1])  # Binary: values 0 and 1

    fg.add_variables(x1, x2)

    # Define factors
    # ψ1(x1): unary factor on x1
    psi1_table = np.array([1.0, 0.5])  # ψ1(0)=1, ψ1(1)=0.5
    psi1 = Factor("psi1", [x1], psi1_table)

    # ψ2(x2): unary factor on x2
    psi2_table = np.array([0.8, 1.0])  # ψ2(0)=0.8, ψ2(1)=1
    psi2 = Factor("psi2", [x2], psi2_table)

    # ψ3(x1, x2): binary factor
    # Rows = x1 values, Cols = x2 values
    psi3_table = np.array([
        [1.0, 0.2],   # x1=0: ψ3(0,0)=1, ψ3(0,1)=0.2
        [0.2, 1.0],   # x1=1: ψ3(1,0)=0.2, ψ3(1,1)=1
    ])
    psi3 = Factor("psi3", [x1, x2], psi3_table)

    fg.add_factors(psi1, psi2, psi3)

    print(f"\nFactor Graph: {fg}")
    print(f"Variables: {fg.variables}")
    print(f"Factors: {fg.factors}")
    print(f"Edges: {fg.get_edges()}")
    print(f"Is tree: {fg.is_tree()}")

    # Compute exact marginals
    print("\nExact marginals (by enumeration):")
    Z = fg.compute_partition_function()
    print(f"  Partition function Z = {Z:.4f}")

    for var in fg.variables:
        marginal = fg.compute_marginal_exact(var)
        # Round to 4 decimals so the output matches the other examples.
        print(f"  P({var.name}) = {np.round(marginal, 4)}")

    return fg


def example_four_variable_tree():
    """
    Recreate the four-variable tree factor graph from Napp & Adams Figure 2a.

    Variables: x1, x2 (binary), x3 (ternary), x4 (binary)
    Factors: ψ1(x1), ψ2(x2), ψ3(x3), ψ7(x4), ψ4(x1,x2), ψ5(x2,x3), ψ6(x3,x4)

    The graph is built with ``build_factor_graph`` from plain dict
    specifications. The function then prints its structure and the exact
    marginal of each variable.

    Returns:
        The constructed FactorGraph.
    """
    print("\n" + "=" * 60)
    print("Example: Four-Variable Tree Factor Graph (Napp & Adams Fig 2a)")
    print("=" * 60)

    # Factor tables from Figure 2c
    # Unary factors: entry i is the factor value when the variable equals i.
    psi1_table = np.array([1.0, 0.1])
    psi2_table = np.array([1.0, 0.1])
    psi3_table = np.array([2.0, 1.0, 1.0])
    psi7_table = np.array([1.0, 1.0])

    # Pairwise factors: rows index the first variable in the scope,
    # columns the second.
    psi4_table = np.array([
        [1.0, 0.1],   # x1=0
        [0.1, 3.0],   # x1=1
    ])

    psi5_table = np.array([
        [0.1, 2.0, 0.1],  # x2=0
        [3.0, 0.1, 1.0],  # x2=1
    ])

    # ψ6(x3, x4): the paper only shows this table for x3 in {0, 1}. Because
    # x3 is ternary here, a third row is required; the x3=2 row is an
    # assumption of this example, not a value from the paper.
    psi6_table = np.array([
        [0.1, 0.1],  # x3=0
        [1.0, 0.1],  # x3=1
        [0.5, 0.5],  # x3=2 (not from the paper)
    ])

    # Use the convenience function
    variables_spec = {
        "x1": [0, 1],
        "x2": [0, 1],
        "x3": [0, 1, 2],
        "x4": [0, 1],
    }

    factors_spec = {
        "psi1": (["x1"], psi1_table),
        "psi2": (["x2"], psi2_table),
        "psi3": (["x3"], psi3_table),
        "psi7": (["x4"], psi7_table),
        "psi4": (["x1", "x2"], psi4_table),
        "psi5": (["x2", "x3"], psi5_table),
        "psi6": (["x3", "x4"], psi6_table),
    }

    fg = build_factor_graph(variables_spec, factors_spec, "FourVariableTree")

    print(f"\nFactor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")

    # Show structure
    print("\nGraph structure:")
    for factor in fg.factors:
        neighbors = [v.name for v in fg.neighbors_of_factor(factor)]
        print(f"  {factor.name} -- {neighbors}")

    # Compute exact marginals
    print("\nExact marginals:")
    for var in fg.variables:
        marginal = fg.compute_marginal_exact(var)
        print(f"  P({var.name}) = {np.round(marginal, 4)}")

    return fg


def example_loopy_graph():
    """
    Create a loopy (cyclic) factor graph based on Napp & Adams Figure 2b.

    The figure's graph has binary variables x1, x2, x3 and factors
    ψ1(x1), ψ2(x1,x2), ψ3(x1,x3), ψ4(x2,x3), ψ5(x2,x3), ψ6(x3).

    This example builds a simplified version that keeps only the cycle
    structure (ψ5 and ψ6 are omitted):
        - ψ1(x1): a uniform unary factor (all ones).
        - ψ2(x1,x2), ψ3(x1,x3), ψ4(x2,x3): identical anti-correlating
          pairwise factors (0.1 when the two values agree, 1.0 when they
          differ).
    The three pairwise factors connect x1–x2–x3–x1, so the graph is not a
    tree.

    Returns:
        The constructed FactorGraph.
    """
    print("\n" + "=" * 60)
    print("Example: Loopy Factor Graph (Napp & Adams Fig 2b)")
    print("=" * 60)

    fg = FactorGraph("LoopyGraph")

    # Binary variables
    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))
    x3 = fg.add_variable(Variable("x3", [0, 1]))

    # Unary factors (leaf factors for conditioning). ψ1 is uniform, so it
    # does not change the joint distribution as written.
    fg.add_factor(Factor("psi1", [x1], np.array([1.0, 1.0])))

    # Pairwise factors creating a cycle
    # Anti-correlating factors (prefer different values)
    anti_corr = np.array([
        [0.1, 1.0],
        [1.0, 0.1],
    ])

    # Each factor gets its own copy so the tables are not shared arrays.
    fg.add_factor(Factor("psi2", [x1, x2], anti_corr.copy()))
    fg.add_factor(Factor("psi3", [x1, x3], anti_corr.copy()))
    fg.add_factor(Factor("psi4", [x2, x3], anti_corr.copy()))

    print(f"\nFactor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")  # Should be False

    print("\nGraph structure (forms a triangle/cycle):")
    for factor in fg.factors:
        neighbors = [v.name for v in fg.neighbors_of_factor(factor)]
        print(f"  {factor.name} -- {neighbors}")

    # Exact marginals
    print("\nExact marginals:")
    for var in fg.variables:
        marginal = fg.compute_marginal_exact(var)
        print(f"  P({var.name}) = {np.round(marginal, 4)}")

    return fg


def example_poset_structure():
    """
    Demonstrate the poset structure of a factor graph.

    Builds a tiny graph (binary x1, x2 with factors f1(x1) and f2(x1, x2))
    and converts it to its associated poset. It then prints:
        - the poset's size and maximum chain length,
        - its linear and colinear points,
        - the result of retracting the first linear point (if any), and
        - the core together with the retraction sequence that produced it.

    Returns:
        A tuple ``(fg, poset)`` with the factor graph and its poset.
    """
    print("\n" + "=" * 60)
    print("Example: Poset Structure of Factor Graphs")
    print("=" * 60)

    # Create a simple factor graph
    fg = FactorGraph("PosetDemo")

    x1 = fg.add_variable(Variable("x1", [0, 1]))
    x2 = fg.add_variable(Variable("x2", [0, 1]))

    fg.add_factor(Factor("f1", [x1], np.array([1.0, 1.0])))
    fg.add_factor(Factor("f2", [x1, x2], np.ones((2, 2))))

    # Convert to poset
    poset = fg.to_poset()

    print(f"\nFactor Graph: {fg}")
    print("\nAssociated Poset:")
    print(poset)

    print("\nPoset properties:")
    print(f"  Number of elements: {len(poset)}")
    print(f"  Maximum chain length: {poset.chain_length()}")

    # Check for linear/colinear points
    linear = poset.find_linear_points()
    colinear = poset.find_colinear_points()

    print(f"\n  Linear points: {linear}")
    print(f"  Colinear points: {colinear}")

    # If there's a linear point, demonstrate retraction. Each entry of
    # `linear` is unpacked as (element, element it retracts onto).
    if linear:
        element, a_up = linear[0]
        print(f"\n  Retracting linear point {element} to {a_up}...")
        reduced_poset = poset.retract_linear_point(element)
        print(f"  Reduced poset: {reduced_poset}")

    # Compute core from the original `poset`, not from `reduced_poset`.
    # NOTE(review): assumes retract_linear_point returns a new poset rather
    # than mutating `poset` in place — confirm in core.
    core, retractions = poset.compute_core()
    print(f"\nCore of poset: {core}")
    print(f"Retraction sequence: {retractions}")

    return fg, poset


def example_custom_factor_graph():
    """
    Show how users can easily specify their own factor graphs.

    Prints an illustrative specification and builds the same graph with
    ``build_factor_graph`` from plain dicts. It then prints each factor's
    scope and table, followed by the exact marginal of every variable.

    A pairwise factor's table has one row per value of the first variable in
    its scope and one column per value of the second. So ``edge_12`` over
    (x1: 2 values, x2: 3 values) is 2x3, and ``edge_23`` over (x2: 3 values,
    x3: 2 values) is 3x2.

    Returns:
        The constructed FactorGraph.
    """
    print("\n" + "=" * 60)
    print("Example: Custom Factor Graph Specification")
    print("=" * 60)

    # User-friendly specification. This text is illustrative only; the table
    # shapes it shows must match the real tables built below.
    print("\nUser specifies:")
    print("""
    variables = {
        "x1": [0, 1],           # Binary
        "x2": [0, 1, 2],        # Ternary  
        "x3": [0, 1],           # Binary
    }
    
    factors = {
        "prior_x1": (["x1"], array([0.7, 0.3])),
        "prior_x2": (["x2"], array([0.5, 0.3, 0.2])),
        "edge_12": (["x1", "x2"], <2x3 array>),
        "edge_23": (["x2", "x3"], <3x2 array>),
    }
    """)

    # Actual construction
    variables_spec = {
        "x1": [0, 1],
        "x2": [0, 1, 2],
        "x3": [0, 1],
    }

    factors_spec = {
        "prior_x1": (["x1"], np.array([0.7, 0.3])),
        "prior_x2": (["x2"], np.array([0.5, 0.3, 0.2])),
        # 2x3: rows = x1 values, columns = x2 values
        "edge_12": (["x1", "x2"], np.array([
            [1.0, 0.5, 0.2],  # x1=0
            [0.2, 0.5, 1.0],  # x1=1
        ])),
        # 3x2: rows = x2 values, columns = x3 values
        "edge_23": (["x2", "x3"], np.array([
            [1.0, 0.1],  # x2=0
            [0.5, 0.5],  # x2=1
            [0.1, 1.0],  # x2=2
        ])),
    }

    fg = build_factor_graph(variables_spec, factors_spec, "CustomGraph")

    print(f"\nConstructed Factor Graph: {fg}")
    print(f"Is tree: {fg.is_tree()}")

    print("\nFactor details:")
    for factor in fg.factors:
        print(f"\n  {factor.name}:")
        print(f"    Scope: {[v.name for v in factor.variables]}")
        print(f"    Table shape: {factor.table.shape}")
        print(f"    Table:\n{factor.table}")

    print("\nExact marginals:")
    for var in fg.variables:
        marginal = fg.compute_marginal_exact(var)
        print(f"  P({var.name}) = {np.round(marginal, 4)}")

    return fg


if __name__ == "__main__":
    # Script entry point: run each demo in order. The return values are
    # bound to module-level names but are not otherwise used here.
    separator = "=" * 60

    fg1 = example_simple_two_variable()
    fg2 = example_four_variable_tree()
    fg3 = example_loopy_graph()
    fg4, poset4 = example_poset_structure()
    fg5 = example_custom_factor_graph()

    # Closing banner.
    for line in ("\n" + separator, "All examples completed successfully!", separator):
        print(line)
