"""
Example showing how to use predefined graphs with CausalProfiler.

In this example we load a graph structure from a YAML file to initialize the space of interest.
We then generate data and queries from the predefined graph.
"""

from causal_profiler import CausalProfiler, SpaceOfInterest
from causal_profiler.constants import (
    MechanismFamily,
    QueryType,
    VariableDataType,
)
import argparse
import random
import numpy as np
import torch

# Fixed seed shared by every random number generator, for reproducibility.
SEED = 42


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also forces cuDNN into deterministic mode and disables its benchmark
    autotuner, which may otherwise pick different kernels between runs.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)


def _print_banner(title):
    """Print *title* framed above and below by a 70-character rule of ``=``."""
    rule = "=" * 70
    print(rule)
    print(title)
    print(rule)


def _format_adjacency(adj, names):
    """Return one printable line per parent node of the adjacency mapping *adj*.

    ``adj`` maps a parent index to its child indices, and ``names`` is indexed
    by those same indices to obtain printable variable names. A parent with no
    children is rendered as ``"<name> (no children)"``.
    """
    lines = []
    for parent_idx, children_idxs in adj.items():
        parent_name = names[parent_idx]
        if children_idxs:
            child_names = ", ".join(names[i] for i in children_idxs)
            lines.append(f"{parent_name} -> {child_names}")
        else:
            lines.append(f"{parent_name} (no children)")
    return lines


def main(predefined_graph_file):
    """Demonstrate CausalProfiler on a graph loaded from *predefined_graph_file*.

    Configures a ``SpaceOfInterest`` around the predefined graph (binary
    discrete variables, tabular mechanisms, 3 ATE queries, 1000 data points).
    It generates one dataset with queries and ground-truth estimates, and
    prints the data shapes, the graph's adjacency list and every query with
    its estimate. It then regenerates three more times from the same graph.

    Args:
        predefined_graph_file: Path to the YAML file describing the graph.
    """
    _print_banner("CausalProfiler: Predefined Graph Example")
    print()

    # Define the space of interest with a predefined graph
    space = SpaceOfInterest(
        # Predefined graph configuration
        predefined_graph_file=predefined_graph_file,
        # Variable configuration (applies to all nodes in the graph)
        variable_type=VariableDataType.DISCRETE,
        number_of_categories=2,  # Binary variables
        # Mechanism configuration
        mechanism_family=MechanismFamily.TABULAR,
        # Query configuration
        number_of_queries=3,
        query_type=QueryType.ATE,
        # Data configuration
        number_of_data_points=1000,
    )

    print("Space of Interest Configuration:")
    print("-" * 70)
    print(f"Predefined Graph File: {space.predefined_graph_file}")
    print(f"Variable Type: {space.variable_type}")
    print(f"Number of Categories: {space.number_of_categories}")
    print(f"Mechanism Family: {space.mechanism_family}")
    print(f"Query Type: {space.query_type}")
    print(f"Number of Queries: {space.number_of_queries}")
    print(f"Number of Data Points: {space.number_of_data_points}")
    print()

    # Create a CausalProfiler with the space of interest
    profiler = CausalProfiler(space_of_interest=space)

    # Generate data and queries
    print("Generating data and queries from predefined graph...")
    data, (queries, estimates), graph = profiler.generate_samples_and_queries()
    print("Done!")
    print()

    _print_banner("Generated Data")
    print("\nObservable variables in the dataset:")
    for var_name, var_data in data.items():
        print(f"  {var_name}: shape {var_data.shape}")
    print("\nNote: Hidden variables are not included in the data.")
    print()

    _print_banner("Graph Structure")
    print(graph)
    adj, names = graph  # graph is an (adjacency mapping, names) pair
    print("\nAdjacency list (parent -> children):")
    for line in _format_adjacency(adj, names):
        print(line)
    print()

    _print_banner("Generated Queries and Ground Truth Estimates")
    print(f"\nGenerated {len(queries)} queries:\n")
    for i, (query, estimate) in enumerate(zip(queries, estimates), 1):
        print(f"Query {i}:")
        print(f"  {query}")
        print(f"  Ground Truth Estimate: {estimate:.4f}")
        print()

    # Demonstrate multiple runs with the same graph
    _print_banner("Multiple Generations with Same Graph Structure")
    print("\nGenerating 3 more datasets with the same graph structure...")
    for run in range(1, 4):
        # Only queries/estimates are reported here; data and graph are discarded.
        _, (queries, estimates), _ = profiler.generate_samples_and_queries()
        print(f"\nRun {run}:")
        print(f"  Generated {len(queries)} queries")
        # len() rather than truthiness: estimates may be a multi-element tensor,
        # whose truth value is ambiguous.
        if len(estimates) == 0:
            print("  No estimates generated")
        else:
            print(f"  Sample estimate: {estimates[0]:.4f}")


if __name__ == "__main__":
    import os

    # Example:
    # python run.py examples/components/predefined_graphs/confounded_graph.yaml
    parser = argparse.ArgumentParser(description="Run with a predefined causal graph")
    parser.add_argument(
        "predefined_graph_file",
        nargs="?",  # optional positional; falls back to the bundled example graph
        default="examples/components/predefined_graphs/example_graph.yaml",
        # %(default)s keeps the help text in sync with the default above.
        help="Path to predefined graph YAML file (default: %(default)s)",
    )

    args = parser.parse_args()
    # The default path is relative to the current working directory. Fail
    # early with a usage error instead of relying on whatever the profiler
    # raises for a missing path.
    if not os.path.isfile(args.predefined_graph_file):
        parser.error(f"predefined graph file not found: {args.predefined_graph_file}")
    main(predefined_graph_file=args.predefined_graph_file)
