import argparse
import os
import random
import sys
import time
from itertools import combinations

import networkx as nx
import numpy as np
import pandas as pd
import torch


from causal_profiler.space_of_interest import (
    MechanismFamily,
    NoiseDistribution,
    QueryType,
    SpaceOfInterest,
    VariableDataType,
)
from verify_data_ci import verify_data_conditional_independence
from verify_do_calculus import verify_do_calculus
from verify_query_estimator_ci import verify_query_estimator_conditional_independence

# Verification functions
from verify_query_estimator_conditionals import verify_query_estimator_conditionals
from verify_structural_counterfactual_axioms import (
    verify_structural_counterfactual_axioms,
)

from causal_profiler import CausalProfiler, ErrorMetric

######################################################
# Parameter grids
######################################################

# Named sweeps selectable with --parameter-grid. main() iterates over the
# Cartesian product of each grid's lists and feeds the values to
# SpaceOfInterest / the sampling loop as follows:
#   NUM_VARIABLES_LIST     -> number_of_nodes=(N, N)
#   EDGES_RATIOS           -> expected_edges=str(int(ratio * N * (N - 1)))
#   NUM_CATEGORIES_LIST    -> number_of_categories=(k, k)
#   NUM_NOISE_REGIONS_LIST -> number_of_noise_regions=str(r)
#   DATASET_SIZES          -> number_of_data_points
#   SAMPLES_PER_COMBO      -> calls to profiler.generate_samples_and_queries()
#                             per parameter combination
PARAMETER_GRIDS = {
    "test1": dict(
        NUM_VARIABLES_LIST=range(3, 21),
        EDGES_RATIOS=[0.2, 0.4, 0.6, 0.8, 1.0],
        NUM_CATEGORIES_LIST=[2, 3, 4, 5, 7, 10],
        NUM_NOISE_REGIONS_LIST=[2, 10, 30, 50, 100, 200, 500],
        DATASET_SIZES=[100_000],
        SAMPLES_PER_COMBO=2,
    ),
    # Only 5 data points per dataset -- presumably a quick smoke run.
    "test2": dict(
        NUM_VARIABLES_LIST=range(3, 6),
        EDGES_RATIOS=[0.2, 0.4],
        NUM_CATEGORIES_LIST=[2, 3],
        NUM_NOISE_REGIONS_LIST=[2, 10],
        DATASET_SIZES=[5],
        SAMPLES_PER_COMBO=1,
    ),
    "test3": dict(
        NUM_VARIABLES_LIST=range(3, 7),
        EDGES_RATIOS=[0.2, 0.4, 0.6],
        NUM_CATEGORIES_LIST=[2, 3],
        NUM_NOISE_REGIONS_LIST=[2, 10, 30],
        DATASET_SIZES=[100_000],
        SAMPLES_PER_COMBO=1,
    ),
    "test4": dict(
        NUM_VARIABLES_LIST=range(3, 8),
        EDGES_RATIOS=[0.2, 0.4, 0.8, 1.0],
        NUM_CATEGORIES_LIST=[2, 3, 4, 5],
        NUM_NOISE_REGIONS_LIST=[2, 10, 100, 200, 500],
        DATASET_SIZES=[100_000],
        SAMPLES_PER_COMBO=10,
    ),
    "toy": dict(
        NUM_VARIABLES_LIST=range(3, 4),
        EDGES_RATIOS=[0.1, 0.2, 0.3, 0.4, 0.5],
        NUM_CATEGORIES_LIST=[3],
        NUM_NOISE_REGIONS_LIST=[3],
        DATASET_SIZES=[5_000],
        SAMPLES_PER_COMBO=1,
    ),
    "toy2": dict(
        NUM_VARIABLES_LIST=[4],
        EDGES_RATIOS=[0.1],
        NUM_CATEGORIES_LIST=[2],
        NUM_NOISE_REGIONS_LIST=[3],
        DATASET_SIZES=[10_000],
        SAMPLES_PER_COMBO=1,
    ),
    "test5": dict(
        NUM_VARIABLES_LIST=[3, 5, 10],
        EDGES_RATIOS=[0.1, 0.5, 0.7],
        NUM_CATEGORIES_LIST=[2, 5, 7],
        NUM_NOISE_REGIONS_LIST=[3, 5, 10],
        DATASET_SIZES=[50_000],
        SAMPLES_PER_COMBO=5,
    ),
    "test6": dict(
        NUM_VARIABLES_LIST=[4, 5],
        EDGES_RATIOS=[0.1, 0.4],
        NUM_CATEGORIES_LIST=[5],
        NUM_NOISE_REGIONS_LIST=[100],
        DATASET_SIZES=[50_000],
        SAMPLES_PER_COMBO=5,
    ),
    "test7": dict(
        NUM_VARIABLES_LIST=[4, 5, 6],
        EDGES_RATIOS=[0.1, 0.4],
        NUM_CATEGORIES_LIST=[2, 5],
        NUM_NOISE_REGIONS_LIST=[5, 100],
        DATASET_SIZES=[50_000],
        SAMPLES_PER_COMBO=2,
    ),
    "test8": dict(
        NUM_VARIABLES_LIST=[4, 5, 6],
        EDGES_RATIOS=[0.1, 0.4],
        NUM_CATEGORIES_LIST=[2, 3, 10],
        NUM_NOISE_REGIONS_LIST=[5, 10],
        DATASET_SIZES=[50_000],
        SAMPLES_PER_COMBO=5,
    ),
}

# Default configuration
DEFAULT_GRID_NAME = "toy"  # default value of --parameter-grid
PROP_HIDDEN_LIST = [0.0]  # swept as SpaceOfInterest proportion_of_hidden_variables
DEFAULT_SEED = 43  # default value of --seed

######################################################
# Utility functions
######################################################


def configure_seed(seed):
    """Seed every random number generator this script relies on.

    The same ``seed`` is passed to Python's ``random`` module, NumPy's global
    RNG, ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``. Afterwards
    ``torch.backends.cudnn.deterministic`` is set to ``True`` and
    ``torch.backends.cudnn.benchmark`` to ``False`` for reproducible runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


def save_dataframe(df, output_dir, filename):
    """Append ``df`` to the CSV file ``output_dir/filename``, creating it if needed.

    Each call opens, writes and closes the file, so rows written by earlier
    calls are already on disk if the script later crashes. The header row is
    written only when the target file is missing or empty; otherwise rows are
    appended without a header. Callers must therefore keep the same columns,
    in the same order, across calls for a given file.

    Args:
        df: Records to write. The DataFrame index is not written.
        output_dir: Directory holding the CSV. It is created, with parents,
            if missing. An empty string means the current working directory.
        filename: Name of the CSV file inside ``output_dir``.
    """
    # os.makedirs("") raises FileNotFoundError, so only create real paths.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, filename)

    # A zero-byte file has no header yet (e.g. the process died before the
    # first write). Treat it like a missing file, otherwise every later row
    # would be written headerless and the first one misread as the header.
    try:
        needs_header = os.path.getsize(file_path) == 0
    except FileNotFoundError:
        needs_header = True

    # Append mode also creates the file when it does not exist yet.
    df.to_csv(file_path, mode="a", header=needs_header, index=False)


######################################################
# Main
######################################################


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--parameter-grid",
        choices=PARAMETER_GRIDS.keys(),
        default=DEFAULT_GRID_NAME,
        help="Which parameter grid to use.",
    )
    parser.add_argument(
        "--verifications-to-run",
        nargs="+",
        default=[
            "l1_estimator_conditionals",
            "l1_data_ci",
            "l1_estimator_ci",
            "l3_structural_counterfactual_axioms",
            "l2_do_calculus",
        ],
        help="List of verifications to run, e.g. l1_estimator_conditionals l1_data_ci l1_estimator_ci l2_do_calculus l3_structural_counterfactual_axioms",
    )
    parser.add_argument(
        "--output-dir", default=".", help="Directory where CSV results are saved."
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed.")
    args = parser.parse_args()

    # Log arguments
    print("Running with arguments:", args)

    # Configure seeds
    configure_seed(args.seed)

    # Retrieve parameter grid
    grid_params = PARAMETER_GRIDS[args.parameter_grid]
    num_vars_list = grid_params["NUM_VARIABLES_LIST"]
    edges_ratios = grid_params["EDGES_RATIOS"]
    num_categories_list = grid_params["NUM_CATEGORIES_LIST"]
    num_noise_regions_list = grid_params["NUM_NOISE_REGIONS_LIST"]
    dataset_sizes = grid_params["DATASET_SIZES"]
    samples_per_combo = grid_params["SAMPLES_PER_COMBO"]

    # Accumulators
    # for verify_query_estimator_conditionals(A|C)
    l1_estimator_conditionals_records = []
    # for verify_data_conditional_independence(A,B|C)
    l1_data_ci_records = []
    # for verify_query_estimator_conditional_independence(A,B|C)
    l1_estimator_ci_records = []
    # for verify_structural_counterfactual_axioms(SCM)
    l3_structural_counterfactual_axioms_records = []
    # for verify_do_calculus
    l2_do_calculus_records = []

    # Timers
    conditional_time_total = 0.0
    conditional_count = 0
    data_ci_time_total = 0.0
    data_ci_count = 0
    est_ci_time_total = 0.0
    est_ci_count = 0
    counterfactual_axioms_time_total = 0.0
    counterfactual_axioms_count = 0
    do_calculus_time_total = 0.0
    do_calculus_count = 0

    # Orchestrate
    for N in num_vars_list:
        for edge_ratio in edges_ratios:
            for hidden_prop in PROP_HIDDEN_LIST:
                for num_cats in num_categories_list:
                    for noise_regions in num_noise_regions_list:
                        for dataset_size in dataset_sizes:
                            print(
                                f"\n--- PARAMS: N={N}, "
                                f"Edge Ratio={edge_ratio}, "
                                f"Hidden Prop={hidden_prop}, "
                                f"Num Cats={num_cats}, "
                                f"Noise Regions={noise_regions}, "
                                f"Data Size={dataset_size} ---"
                            )

                            # Build space-of-interest
                            expected_edges = int(edge_ratio * N * (N - 1))
                            space = SpaceOfInterest(
                                number_of_nodes=(N, N),
                                variable_dimensionality=(1, 1),
                                expected_edges=str(expected_edges),
                                mechanism_family=MechanismFamily.TABULAR,
                                noise_distribution=NoiseDistribution.UNIFORM,
                                noise_args=[-1, 1],
                                number_of_noise_regions=str(noise_regions),
                                variable_type=VariableDataType.DISCRETE,
                                number_of_categories=(num_cats, num_cats),
                                proportion_of_hidden_variables=hidden_prop,
                                number_of_queries=1,  # not testing the queries here
                                query_type=QueryType.CONDITIONAL,
                                number_of_data_points=dataset_size,
                            )

                            profiler = CausalProfiler(
                                space_of_interest=space,
                                metric=ErrorMetric.L2,
                                return_adjacency_matrix=False,
                            )

                            for _ in range(samples_per_combo):
                                # Generate data
                                (
                                    data_dict,
                                    (queries, estimates),
                                    (graph_dict, index_to_var),
                                ) = profiler.generate_samples_and_queries()

                                # Convert adjacency to networkx DiGraph
                                dag = nx.DiGraph()
                                for node_idx in range(len(index_to_var)):
                                    dag.add_node(node_idx)
                                for parent_idx, children_idxs in graph_dict.items():
                                    for child_idx in children_idxs:
                                        dag.add_edge(parent_idx, child_idx)

                                var_names = [index_to_var[i] for i in dag.nodes()]
                                var_to_index = {
                                    vname: i for i, vname in enumerate(var_names)
                                }

                                ##########################################
                                # (A) Single conditionals: P(A|C)
                                # Verify single-variable distributions P(A|C)
                                # For each variable A, pick all subsets C of size 1..3
                                # from the other variables in the DAG.
                                ##########################################
                                if (
                                    "l1_estimator_conditionals"
                                    in args.verifications_to_run
                                ):
                                    for A_name in var_names:
                                        others = [v for v in var_names if v != A_name]
                                        # All combos of size 1, 2, or 3
                                        possible_C_sets = []

                                        for r in [1, 2, 3]:
                                            for combo in combinations(others, r):
                                                possible_C_sets.append(list(combo))

                                        # Evaluate for each combo
                                        for C_names in possible_C_sets:
                                            start_time = time.time()
                                            cond_res = verify_query_estimator_conditionals(
                                                query_estimator=profiler.sampler.query_estimator,
                                                scm=profiler.sampler._scm,
                                                A=A_name,
                                                C=C_names,
                                                data_dict=data_dict,
                                                use_multi_query=False,  # multi-query not implemented
                                                js_threshold=0.05,
                                                aggregation_method="mean",
                                                n_samples=50000,
                                                min_count=2,
                                            )
                                            end_time = time.time()
                                            conditional_time_total += (
                                                end_time - start_time
                                            )
                                            conditional_count += 1
                                            # TODO: Simplify record
                                            rec = {
                                                key: cond_res[key]
                                                for key in [
                                                    "agreement_accepted",
                                                    "aggregation",
                                                    "threshold",
                                                    "aggregation_method",
                                                    "num_c_values",
                                                    "num_queries",
                                                ]
                                                # "test_type": "single_conditional",
                                                # "A": A_name,
                                                # "C": C_names,
                                                # **cond_res,
                                                # "N": N,
                                                # "edge_ratio": edge_ratio,
                                                # "hidden_prop": hidden_prop,
                                                # "num_cats": num_cats,
                                                # "noise_regions": noise_regions,
                                                # "dataset_size": dataset_size,
                                            }
                                            save_dataframe(
                                                pd.DataFrame([rec]),
                                                args.output_dir,
                                                "l1_single_conditionals_results.csv",
                                            )
                                            l1_estimator_conditionals_records.append(
                                                rec
                                            )

                                ##########################################
                                # (B & C) Conditional Independence: (A,B | C)
                                ##########################################
                                if ("l1_data_ci" in args.verifications_to_run) or (
                                    "l1_estimator_ci" in args.verifications_to_run
                                ):
                                    for i in range(len(var_names)):
                                        for j in range(i + 1, len(var_names)):
                                            # The set of all other variables (besides A,B)
                                            A_name = var_names[i]
                                            B_name = var_names[j]
                                            others = [
                                                v
                                                for v in var_names
                                                if v not in (A_name, B_name)
                                            ]
                                            # Generate combos of size 1..3
                                            c_sets = []
                                            for r in [1, 2, 3]:
                                                for combo in combinations(others, r):
                                                    c_sets.append(list(combo))

                                            for C_names in c_sets:
                                                A_idx = var_to_index[A_name]
                                                B_idx = var_to_index[B_name]
                                                C_idxs = [
                                                    var_to_index[c] for c in C_names
                                                ]

                                                # Check d-separation
                                                if not nx.is_d_separator(
                                                    dag,
                                                    {A_idx},
                                                    {B_idx},
                                                    set(C_idxs),
                                                ):
                                                    # Skip if not d-separated
                                                    continue

                                                ##########################################################
                                                # (B) Data-based CI
                                                # Verify conditional independence: A ⟂ B | C_set
                                                # For each pair (A,B), pick C combos of size 1..3 from the rest.
                                                # Skip if DAG does not show that A,B are d-separated by C.
                                                ##########################################################
                                                if (
                                                    "l1_data_ci"
                                                    in args.verifications_to_run
                                                ):
                                                    start_time = time.time()
                                                    data_ci_res = verify_data_conditional_independence(
                                                        data_dict=data_dict,
                                                        A=A_name,
                                                        B=B_name,
                                                        C_set=C_names,
                                                        alpha=0.05,
                                                        correction="BH",
                                                    )
                                                    end_time = time.time()
                                                    data_ci_time_total += (
                                                        end_time - start_time
                                                    )
                                                    data_ci_count += 1
                                                    data_ci_rec = {
                                                        key: data_ci_res[key]
                                                        for key in [
                                                            "alpha",
                                                            "correction",
                                                            "independence_accepted",
                                                            "distribution_results",
                                                            "skipped_distributions",
                                                            "rejection_rate",
                                                        ]
                                                        # "test_type": "data_cond_indep",
                                                        # "A": A_name,
                                                        # "B": B_name,
                                                        # "C": C_names,
                                                        # **data_ci_res,
                                                        # "N": N,
                                                        # "edge_ratio": edge_ratio,
                                                        # "hidden_prop": hidden_prop,
                                                        # "num_cats": num_cats,
                                                        # "noise_regions": noise_regions,
                                                        # "dataset_size": dataset_size,
                                                    }
                                                    save_dataframe(
                                                        pd.DataFrame([data_ci_rec]),
                                                        args.output_dir,
                                                        "l1_data_ci_results.csv",
                                                    )
                                                    l1_data_ci_records.append(
                                                        data_ci_rec
                                                    )

                                                ##########################################################
                                                # (C) Estimator-based CI
                                                # Check distance of P(A, B | C) and P(A | C) * P(B | C) using JS divergence
                                                # For A, B that are d-separated by C
                                                ##########################################################
                                                if (
                                                    "l1_estimator_ci"
                                                    in args.verifications_to_run
                                                ):
                                                    start_time = time.time()
                                                    est_ci_res = verify_query_estimator_conditional_independence(
                                                        query_estimator=profiler.sampler.query_estimator,
                                                        scm=profiler.sampler._scm,
                                                        A=A_name,
                                                        B=B_name,
                                                        C_set=C_names,
                                                        data_dict=data_dict,
                                                        use_multi_query=False,  # multi-query not implemented
                                                        js_threshold=0.15,
                                                        aggregation_method="median",  # to defend against outliers -> extremely large error due to lack of data
                                                        n_samples=50000,
                                                    )
                                                    end_time = time.time()
                                                    est_ci_time_total += (
                                                        end_time - start_time
                                                    )
                                                    est_ci_count += 1
                                                    est_ci_rec = {
                                                        key: est_ci_res[key]
                                                        for key in [
                                                            "aggregation",
                                                            "independence_accepted",
                                                            "threshold",
                                                            "aggregation_method",
                                                            "num_queries",
                                                            "C_values",
                                                            "skipped_c_values",
                                                            "js_divergences",
                                                        ]
                                                        # "test_type": "est_cond_indep",
                                                        # "A": A_name,
                                                        # "B": B_name,
                                                        # "C": C_names,
                                                        # **est_ci_res,
                                                        # "N": N,
                                                        # "edge_ratio": edge_ratio,
                                                        # "hidden_prop": hidden_prop,
                                                        # "num_cats": num_cats,
                                                        # "noise_regions": noise_regions,
                                                        # "dataset_size": dataset_size,
                                                    }
                                                    save_dataframe(
                                                        pd.DataFrame([est_ci_rec]),
                                                        args.output_dir,
                                                        "l1_estimator_ci_results.csv",
                                                    )
                                                    l1_estimator_ci_records.append(
                                                        est_ci_rec
                                                    )

                                ##########################################
                                # L3 Structural Counterfactual Axioms
                                ##########################################
                                if (
                                    "l3_structural_counterfactual_axioms"
                                    in args.verifications_to_run
                                ):
                                    start_time = time.time()
                                    axioms_res = (
                                        verify_structural_counterfactual_axioms(
                                            scm=profiler.sampler._scm,
                                            n_tests=10,
                                        )
                                    )
                                    end_time = time.time()
                                    counterfactual_axioms_time_total += (
                                        end_time - start_time
                                    )
                                    counterfactual_axioms_count += 1

                                    axioms_rec = {
                                        "composition_success_rate": axioms_res[
                                            "composition_success_rate"
                                        ],
                                        "effectiveness_success_rate": axioms_res[
                                            "effectiveness_success_rate"
                                        ],
                                        "reversibility_success_rate": axioms_res[
                                            "reversibility_success_rate"
                                        ],
                                    }
                                    save_dataframe(
                                        pd.DataFrame([axioms_rec]),
                                        args.output_dir,
                                        "l3_structural_counterfactual_axioms_results.csv",
                                    )
                                    l3_structural_counterfactual_axioms_records.append(
                                        axioms_rec
                                    )

                                ##########################################
                                # L2 Do-Calculus
                                ##########################################
                                if "l2_do_calculus" in args.verifications_to_run:
                                    # For each set of variables Y, Z, X, W
                                    # We need at least 4 variables for a reasonable test
                                    if len(var_names) >= 4:
                                        # Get unique combinations of 4 variables
                                        variable_combos = []
                                        for i in range(len(var_names)):
                                            for j in range(len(var_names)):
                                                if j != i:
                                                    for k in range(len(var_names)):
                                                        if k != i and k != j:
                                                            for m in range(
                                                                len(var_names)
                                                            ):
                                                                if (
                                                                    m != i
                                                                    and m != j
                                                                    and m != k
                                                                ):
                                                                    variable_combos.append(
                                                                        (i, j, k, m)
                                                                    )

                                        for i, j, k, m in variable_combos:
                                            Y = var_names[i]  # Target variable
                                            Z = [
                                                var_names[j]
                                            ]  # Variable to condition/intervene on
                                            X = [var_names[k]]  # Intervention variable
                                            W = [
                                                var_names[m]
                                            ]  # Additional conditioning variable

                                            start_time = time.time()
                                            do_calculus_res = verify_do_calculus(
                                                query_estimator=profiler.sampler.query_estimator,
                                                scm=profiler.sampler._scm,
                                                Y=Y,
                                                Z=Z,
                                                X=X,
                                                W=W,
                                                data_dict=data_dict,
                                                alpha=0.05,
                                                correction="BH",
                                                n_samples=10000,
                                            )
                                            end_time = time.time()
                                            do_calculus_time_total += (
                                                end_time - start_time
                                            )
                                            do_calculus_count += 1

                                            # Extract key information for recording
                                            do_calculus_rec = {
                                                "variables_Y": Y,
                                                "variables_Z": str(Z),
                                                "variables_X": str(X),
                                                "variables_W": str(W),
                                                "rule1_d_separation": (
                                                    do_calculus_res["rule1_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    if do_calculus_res["rule1_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    is not None
                                                    else False
                                                ),
                                                "rule2_d_separation": (
                                                    do_calculus_res["rule2_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    if do_calculus_res["rule2_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    is not None
                                                    else False
                                                ),
                                                "rule3_d_separation": (
                                                    do_calculus_res["rule3_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    if do_calculus_res["rule3_results"][
                                                        "d_separation_holds"
                                                    ]
                                                    is not None
                                                    else False
                                                ),
                                            }

                                            # Only include statistical test results if d-separation holds
                                            if do_calculus_res["rule1_results"][
                                                "d_separation_holds"
                                            ]:
                                                rule_holds = do_calculus_res[
                                                    "rule1_results"
                                                ].get("rule_holds")
                                                rej_rate = do_calculus_res[
                                                    "rule1_results"
                                                ].get("rejection_rate")

                                                # Add distribution results
                                                dist_results = do_calculus_res[
                                                    "rule1_results"
                                                ].get("distribution_results", [])
                                                skipped_dist = do_calculus_res[
                                                    "rule1_results"
                                                ].get("skipped_distributions", [])

                                                # Add metrics to record
                                                do_calculus_rec.update(
                                                    {
                                                        "rule1_rule_holds": (
                                                            rule_holds
                                                            if rule_holds is not None
                                                            else False
                                                        ),
                                                        "rule1_rejection_rate": (
                                                            rej_rate
                                                            if rej_rate is not None
                                                            else 1.0
                                                        ),
                                                        "rule1_total_distributions": len(
                                                            dist_results
                                                        )
                                                        + len(skipped_dist),
                                                        "rule1_total_skipped": len(
                                                            skipped_dist
                                                        ),
                                                    }
                                                )

                                                # Add individual distribution results
                                                for i, dist in enumerate(dist_results):
                                                    do_calculus_rec[
                                                        f"rule1_distribution_{i}_result"
                                                    ] = dist

                                                # Add skipped distributions
                                                for i, skipped in enumerate(
                                                    skipped_dist
                                                ):
                                                    do_calculus_rec[
                                                        f"rule1_skipped_{i}"
                                                    ] = skipped

                                            if do_calculus_res["rule2_results"][
                                                "d_separation_holds"
                                            ]:
                                                rule_holds = do_calculus_res[
                                                    "rule2_results"
                                                ].get("rule_holds")
                                                rej_rate = do_calculus_res[
                                                    "rule2_results"
                                                ].get("rejection_rate")

                                                # Add distribution results
                                                dist_results = do_calculus_res[
                                                    "rule2_results"
                                                ].get("distribution_results", [])
                                                skipped_dist = do_calculus_res[
                                                    "rule2_results"
                                                ].get("skipped_distributions", [])

                                                # Add metrics to record
                                                do_calculus_rec.update(
                                                    {
                                                        "rule2_rule_holds": (
                                                            rule_holds
                                                            if rule_holds is not None
                                                            else False
                                                        ),
                                                        "rule2_rejection_rate": (
                                                            rej_rate
                                                            if rej_rate is not None
                                                            else 1.0
                                                        ),
                                                        "rule2_total_distributions": len(
                                                            dist_results
                                                        )
                                                        + len(skipped_dist),
                                                        "rule2_total_skipped": len(
                                                            skipped_dist
                                                        ),
                                                    }
                                                )

                                                # Add individual distribution results
                                                for i, dist in enumerate(dist_results):
                                                    do_calculus_rec[
                                                        f"rule2_distribution_{i}_result"
                                                    ] = dist

                                                # Add skipped distributions
                                                for i, skipped in enumerate(
                                                    skipped_dist
                                                ):
                                                    do_calculus_rec[
                                                        f"rule2_skipped_{i}"
                                                    ] = skipped

                                            if do_calculus_res["rule3_results"][
                                                "d_separation_holds"
                                            ]:
                                                rule_holds = do_calculus_res[
                                                    "rule3_results"
                                                ].get("rule_holds")
                                                rej_rate = do_calculus_res[
                                                    "rule3_results"
                                                ].get("rejection_rate")

                                                # Add distribution results
                                                dist_results = do_calculus_res[
                                                    "rule3_results"
                                                ].get("distribution_results", [])
                                                skipped_dist = do_calculus_res[
                                                    "rule3_results"
                                                ].get("skipped_distributions", [])

                                                # Add metrics to record
                                                do_calculus_rec.update(
                                                    {
                                                        "rule3_rule_holds": (
                                                            rule_holds
                                                            if rule_holds is not None
                                                            else False
                                                        ),
                                                        "rule3_rejection_rate": (
                                                            rej_rate
                                                            if rej_rate is not None
                                                            else 1.0
                                                        ),
                                                        "rule3_total_distributions": len(
                                                            dist_results
                                                        )
                                                        + len(skipped_dist),
                                                        "rule3_total_skipped": len(
                                                            skipped_dist
                                                        ),
                                                    }
                                                )

                                                # Add individual distribution results
                                                for i, dist in enumerate(dist_results):
                                                    do_calculus_rec[
                                                        f"rule3_distribution_{i}_result"
                                                    ] = dist

                                                # Add skipped distributions
                                                for i, skipped in enumerate(
                                                    skipped_dist
                                                ):
                                                    do_calculus_rec[
                                                        f"rule3_skipped_{i}"
                                                    ] = skipped

                                            # Only append if at least one d-separation holds
                                            if (
                                                do_calculus_rec["rule1_d_separation"]
                                                or do_calculus_rec["rule2_d_separation"]
                                                or do_calculus_rec["rule3_d_separation"]
                                            ):
                                                save_dataframe(
                                                    pd.DataFrame([do_calculus_rec]),
                                                    args.output_dir,
                                                    "l2_do_calculus_results.csv",
                                                )
                                                l2_do_calculus_records.append(
                                                    do_calculus_rec
                                                )

    ##########################################
    # Print timing info
    ##########################################
    if conditional_count > 0:
        print(
            f"verify_query_estimator_conditionals: total={conditional_time_total:.2f}s, "
            f"avg={conditional_time_total/conditional_count:.4f}s"
        )
    if data_ci_count > 0:
        print(
            f"verify_data_conditional_independence: total={data_ci_time_total:.2f}s, "
            f"avg={data_ci_time_total/data_ci_count:.4f}s"
        )
    if est_ci_count > 0:
        print(
            f"verify_query_estimator_conditional_independence: total={est_ci_time_total:.2f}s, "
            f"avg={est_ci_time_total/est_ci_count:.4f}s"
        )
    if counterfactual_axioms_count > 0:
        print(
            f"verify_structural_counterfactual_axioms: total={counterfactual_axioms_time_total:.2f}s, "
            f"avg={counterfactual_axioms_time_total/counterfactual_axioms_count:.4f}s"
        )
    if do_calculus_count > 0:
        print(
            f"verify_do_calculus: total={do_calculus_time_total:.2f}s, "
            f"avg={do_calculus_time_total/do_calculus_count:.4f}s"
        )

    ##########################################
    # High-level summarization (for debug)
    ##########################################
    def print_positive_percentage(df, column, title):
        """Print a one-line percentage summary of ``df[column]``.

        * Boolean columns print the share of ``True`` among non-null entries.
        * Object columns whose non-null entries are all booleans get the same
          treatment. A True/False column that also holds ``None`` (e.g. a
          skipped test) is stored by pandas as ``object`` dtype.
        * Other numeric columns print their mean scaled by 100.

        Nothing is printed if the column is missing, the frame is empty, or
        the column holds any other kind of data.

        Args:
            df: DataFrame of per-test records.
            column: Name of the column to summarize.
            title: Label printed before the percentage.
        """
        if column not in df.columns or df.empty:
            return
        series = df[column]
        dtype = series.dtype

        if pd.api.types.is_bool_dtype(dtype):
            flags = series
        elif dtype == object:
            non_null = series.dropna()
            # Only treat the column as boolean when every remaining value is a
            # bool; e.g. string columns are left unsummarized as before.
            if non_null.empty or not all(
                isinstance(v, (bool, np.bool_)) for v in non_null
            ):
                return
            flags = non_null
        elif pd.api.types.is_numeric_dtype(dtype):
            # Checked after bool: is_numeric_dtype is also True for bool.
            mean_percentage = series.mean() * 100  # Scale to percentage
            print(f"{title}: {mean_percentage:.2f}%")
            return
        else:
            return

        # value_counts drops missing values by default, so the percentage is
        # taken over non-null entries only.
        counts = flags.value_counts(normalize=True) * 100
        true_percentage = counts.get(True, 0)
        print(f"{title}: True is {true_percentage:.2f}%")

    # TODO: this is unnecessary and uses lots of memory
    l1_single_cond_df = pd.DataFrame(l1_estimator_conditionals_records)
    l1_data_ci_df = pd.DataFrame(l1_data_ci_records)
    l1_estimator_ci_df = pd.DataFrame(l1_estimator_ci_records)
    l3_axioms_df = pd.DataFrame(l3_structural_counterfactual_axioms_records)
    l2_do_calculus_df = pd.DataFrame(l2_do_calculus_records)

    print_positive_percentage(
        l1_single_cond_df, "agreement_accepted", "Agreement Accepted (A|C)"
    )
    print_positive_percentage(
        l1_data_ci_df, "independence_accepted", "Independence Accepted (Data)"
    )
    print_positive_percentage(
        l1_estimator_ci_df, "independence_accepted", "Independence Accepted (Estimator)"
    )

    # Calculate overall do-calculus success rates
    if not l2_do_calculus_df.empty:
        rule1_tests = l2_do_calculus_df["rule1_d_separation"].sum()
        rule2_tests = l2_do_calculus_df["rule2_d_separation"].sum()
        rule3_tests = l2_do_calculus_df["rule3_d_separation"].sum()

        # Initialize counters for the summary table
        rule_data = {
            "Rule 1 (Insertion/deletion of observation)": {
                "tests": 0,
                "passed": 0,
                "failed": 0,
                "skipped": 0,
                "total_distributions": 0,
                "total_skipped_distributions": 0,
                "total_failed_distributions": 0,
            },
            "Rule 2 (Action/observation exchange)": {
                "tests": 0,
                "passed": 0,
                "failed": 0,
                "skipped": 0,
                "total_distributions": 0,
                "total_skipped_distributions": 0,
                "total_failed_distributions": 0,
            },
            "Rule 3 (Insertion/deletion of action)": {
                "tests": 0,
                "passed": 0,
                "failed": 0,
                "skipped": 0,
                "total_distributions": 0,
                "total_skipped_distributions": 0,
                "total_failed_distributions": 0,
            },
        }

        # Process rule 1 results
        if rule1_tests > 0 and "rule1_rule_holds" in l2_do_calculus_df.columns:
            # Count tests where d-separation holds
            valid_tests = l2_do_calculus_df[
                l2_do_calculus_df["rule1_d_separation"] == True
            ]

            # Count successes (rule holds) and failures
            if l2_do_calculus_df["rule1_rule_holds"].dtype == bool:
                rule1_successes = valid_tests["rule1_rule_holds"].sum()
            else:
                rule1_successes = valid_tests[
                    valid_tests["rule1_rule_holds"] == True
                ].shape[0]

            rule1_failures = rule1_tests - rule1_successes

            # Update rule data for table
            rule_data["Rule 1 (Insertion/deletion of observation)"][
                "tests"
            ] = rule1_tests
            rule_data["Rule 1 (Insertion/deletion of observation)"][
                "passed"
            ] = rule1_successes
            rule_data["Rule 1 (Insertion/deletion of observation)"][
                "failed"
            ] = rule1_failures

        # Process rule 2 results
        if rule2_tests > 0 and "rule2_rule_holds" in l2_do_calculus_df.columns:
            # Count tests where d-separation holds
            valid_tests = l2_do_calculus_df[
                l2_do_calculus_df["rule2_d_separation"] == True
            ]

            # Count successes (rule holds) and failures
            if l2_do_calculus_df["rule2_rule_holds"].dtype == bool:
                rule2_successes = valid_tests["rule2_rule_holds"].sum()
            else:
                rule2_successes = valid_tests[
                    valid_tests["rule2_rule_holds"] == True
                ].shape[0]

            rule2_failures = rule2_tests - rule2_successes

            # Update rule data for table
            rule_data["Rule 2 (Action/observation exchange)"]["tests"] = rule2_tests
            rule_data["Rule 2 (Action/observation exchange)"][
                "passed"
            ] = rule2_successes
            rule_data["Rule 2 (Action/observation exchange)"]["failed"] = rule2_failures

        # Process rule 3 results
        if rule3_tests > 0 and "rule3_rule_holds" in l2_do_calculus_df.columns:
            # Count tests where d-separation holds
            valid_tests = l2_do_calculus_df[
                l2_do_calculus_df["rule3_d_separation"] == True
            ]

            # Count successes (rule holds) and failures
            if l2_do_calculus_df["rule3_rule_holds"].dtype == bool:
                rule3_successes = valid_tests["rule3_rule_holds"].sum()
            else:
                rule3_successes = valid_tests[
                    valid_tests["rule3_rule_holds"] == True
                ].shape[0]

            rule3_failures = rule3_tests - rule3_successes

            # Update rule data for table
            rule_data["Rule 3 (Insertion/deletion of action)"]["tests"] = rule3_tests
            rule_data["Rule 3 (Insertion/deletion of action)"][
                "passed"
            ] = rule3_successes
            rule_data["Rule 3 (Insertion/deletion of action)"][
                "failed"
            ] = rule3_failures

        # Print the summary table for all rules
        print("\n" + "=" * 80)
        print("DO-CALCULUS OVERALL SUMMARY TABLE")
        print("=" * 80)

        # Print header
        header = [
            "Rule",
            "Total Tests",
            "Passed",
            "Failed",
            "Skipped",
            "Total Distributions",
            "Total Skipped",
            "Total Failed",
            "Rejection Rate",
        ]
        header_format = "{:<40} {:<12} {:<12} {:<12} {:<12} {:<15} {:<15} {:<15} {:<12}"
        row_format = "{:<40} {:<12} {:<12} {:<12} {:<12} {:<15} {:<15} {:<15} {:<12}"

        print(header_format.format(*header))
        print("-" * 80)

        # Initialize distribution counters in rule_data if they don't exist
        for rule_name in rule_data:
            if "total_distributions" not in rule_data[rule_name]:
                rule_data[rule_name].update(
                    {
                        "total_distributions": 0,
                        "total_skipped": 0,
                        "total_failed": 0,
                    }
                )

        # Calculate distribution metrics from the records
        rule1_dist_total = 0
        rule1_skipped_total = 0
        rule1_failed_total = 0

        rule2_dist_total = 0
        rule2_skipped_total = 0
        rule2_failed_total = 0

        rule3_dist_total = 0
        rule3_skipped_total = 0
        rule3_failed_total = 0

        for rec in l2_do_calculus_records:
            # Rule 1 distribution metrics
            if "rule1_total_distributions" in rec:
                rule1_dist_total += rec["rule1_total_distributions"]
            if "rule1_total_skipped" in rec:
                rule1_skipped_total += rec["rule1_total_skipped"]

            # Count failed distributions by checking for rejected=True
            for key, val in rec.items():
                if key.startswith("rule1_distribution_") and key.endswith("_result"):
                    if val.get("rejected", False):
                        rule1_failed_total += 1

            # Rule 2 distribution metrics
            if "rule2_total_distributions" in rec:
                rule2_dist_total += rec["rule2_total_distributions"]
            if "rule2_total_skipped" in rec:
                rule2_skipped_total += rec["rule2_total_skipped"]

            # Count failed distributions by checking for rejected=True
            for key, val in rec.items():
                if key.startswith("rule2_distribution_") and key.endswith("_result"):
                    if val.get("rejected", False):
                        rule2_failed_total += 1

            # Rule 3 distribution metrics
            if "rule3_total_distributions" in rec:
                rule3_dist_total += rec["rule3_total_distributions"]
            if "rule3_total_skipped" in rec:
                rule3_skipped_total += rec["rule3_total_skipped"]

            # Count failed distributions by checking for rejected=True
            for key, val in rec.items():
                if key.startswith("rule3_distribution_") and key.endswith("_result"):
                    if val.get("rejected", False):
                        rule3_failed_total += 1

        # Update rule_data with distribution totals
        rule_data["Rule 1 (Insertion/deletion of observation)"][
            "total_distributions"
        ] = rule1_dist_total
        rule_data["Rule 1 (Insertion/deletion of observation)"][
            "total_skipped"
        ] = rule1_skipped_total
        rule_data["Rule 1 (Insertion/deletion of observation)"][
            "total_failed"
        ] = rule1_failed_total

        rule_data["Rule 2 (Action/observation exchange)"][
            "total_distributions"
        ] = rule2_dist_total
        rule_data["Rule 2 (Action/observation exchange)"][
            "total_skipped"
        ] = rule2_skipped_total
        rule_data["Rule 2 (Action/observation exchange)"][
            "total_failed"
        ] = rule2_failed_total

        rule_data["Rule 3 (Insertion/deletion of action)"][
            "total_distributions"
        ] = rule3_dist_total
        rule_data["Rule 3 (Insertion/deletion of action)"][
            "total_skipped"
        ] = rule3_skipped_total
        rule_data["Rule 3 (Insertion/deletion of action)"][
            "total_failed"
        ] = rule3_failed_total

        # Print rows
        for rule_name, data in rule_data.items():
            total = data["tests"]
            passed = data["passed"]
            failed = data["failed"]
            skipped = data.get("skipped", 0)

            # Get distribution counts
            total_distributions = data.get("total_distributions", 0)
            total_skipped = data.get("total_skipped", 0)
            total_failed = data.get("total_failed", 0)

            # Calculate rejection rate and format as percentage
            valid_tests = total - skipped
            rejection_rate = (failed / valid_tests) * 100 if valid_tests > 0 else 0.0

            # Print row
            print(
                row_format.format(
                    rule_name,
                    str(total),
                    str(passed),
                    str(failed),
                    str(skipped),
                    str(total_distributions),
                    str(total_skipped),
                    str(total_failed),
                    f"{rejection_rate:.2f}%",  # Format as percentage with 2 decimal places
                )
            )

        print("=" * 80)

    # Print summary table for data conditional independence tests
    if not l1_data_ci_df.empty:
        print("\n" + "=" * 80)
        print("CONDITIONAL INDEPENDENCE TESTS SUMMARY TABLE")
        print("=" * 80)

        # Print header
        header = [
            "Test Type",
            "Total Tests",
            "Passed",
            "Failed",
            "Skipped",
            "Total Distributions",
            "Total Skipped",
            "Total Passed",
            "Total Failed",
            "Rejection Rate",
        ]
        header_format = (
            "{:<30} {:<12} {:<12} {:<12} {:<12} {:<18} {:<15} {:<15} {:<15} {:<12}"
        )
        row_format = (
            "{:<30} {:<12} {:<12} {:<12} {:<12} {:<18} {:<15} {:<15} {:<15} {:<12}"
        )

        print(header_format.format(*header))
        print("-" * 80)

        # For Data CI tests
        total_tests = len(l1_data_ci_df)
        if total_tests > 0:
            # Initialize dictionaries to track metrics by C_set size
            c_set_sizes = {}

            # Initialize a record for the overall results
            overall_metrics = {
                "total": 0,
                "passed": 0,
                "failed": 0,
                "skipped": 0,
                "total_distributions": 0,
                "total_skipped_distributions": 0,
                "total_passed_distributions": 0,
                "total_failed_distributions": 0,
            }

            # Process each test record to calculate metrics by C_set size
            for i, rec in enumerate(l1_data_ci_records):
                # Determine the size of C_set for this test
                c_set_size = None
                try:
                    # Try to extract C_set from filename if it exists
                    if "C" in rec and isinstance(rec["C"], list):
                        c_set_size = len(rec["C"])
                    # If not found, we'll calculate it based on the actual distributions
                    elif "distribution_results" in rec and rec["distribution_results"]:
                        # Get first key which is a tuple of C values
                        first_key = next(iter(rec["distribution_results"]))
                        c_set_size = (
                            len(first_key) if isinstance(first_key, tuple) else 1
                        )
                    elif (
                        "skipped_distributions" in rec and rec["skipped_distributions"]
                    ):
                        first_key = next(iter(rec["skipped_distributions"]))
                        c_set_size = (
                            len(first_key) if isinstance(first_key, tuple) else 1
                        )
                    else:
                        # If C_set size cannot be determined, use "Unknown"
                        c_set_size = "Unknown"
                except (KeyError, StopIteration):
                    c_set_size = "Unknown"

                # Initialize metrics for this C_set size if it doesn't exist
                if c_set_size not in c_set_sizes:
                    c_set_sizes[c_set_size] = {
                        "total": 0,
                        "passed": 0,
                        "failed": 0,
                        "skipped": 0,
                        "total_distributions": 0,
                        "total_skipped_distributions": 0,
                        "total_passed_distributions": 0,
                        "total_failed_distributions": 0,
                    }

                # Increment total for this C_set size
                c_set_sizes[c_set_size]["total"] += 1
                overall_metrics["total"] += 1

                # Check test result (independence accepted/rejected/skipped)
                if "independence_accepted" in rec:
                    if rec["independence_accepted"] is None:
                        c_set_sizes[c_set_size]["skipped"] += 1
                        overall_metrics["skipped"] += 1
                    elif rec["independence_accepted"] is True:
                        c_set_sizes[c_set_size]["passed"] += 1
                        overall_metrics["passed"] += 1
                    else:
                        c_set_sizes[c_set_size]["failed"] += 1
                        overall_metrics["failed"] += 1

                # Count distributions
                if "distribution_results" in rec:
                    dist_results = rec["distribution_results"]
                    n_distributions = len(dist_results)
                    c_set_sizes[c_set_size]["total_distributions"] += n_distributions
                    overall_metrics["total_distributions"] += n_distributions

                    # Count passed and failed at the distribution level
                    for result in dist_results.values():
                        if not result["rejected"]:
                            c_set_sizes[c_set_size]["total_passed_distributions"] += 1
                            overall_metrics["total_passed_distributions"] += 1
                        else:
                            c_set_sizes[c_set_size]["total_failed_distributions"] += 1
                            overall_metrics["total_failed_distributions"] += 1

                # Count skipped distributions
                if "skipped_distributions" in rec:
                    n_skipped = len(rec["skipped_distributions"])
                    c_set_sizes[c_set_size]["total_skipped_distributions"] += n_skipped
                    overall_metrics["total_skipped_distributions"] += n_skipped
                    c_set_sizes[c_set_size][
                        "total_distributions"
                    ] += n_skipped  # Count skipped as part of total
                    overall_metrics["total_distributions"] += n_skipped

            # Print rows for each C_set size
            for c_size, metrics in sorted(c_set_sizes.items()):
                # Calculate rejection rate
                valid_tests = metrics["total"] - metrics["skipped"]
                rejection_rate = (
                    (metrics["failed"] / valid_tests) * 100 if valid_tests > 0 else 0.0
                )

                # Construct label based on C_set size
                if c_size == "Unknown":
                    label = "Data CI (Unknown C size)"
                else:
                    label = f"Data CI (|C| = {c_size})"

                # Print row
                print(
                    row_format.format(
                        label,
                        str(metrics["total"]),
                        str(metrics["passed"]),
                        str(metrics["failed"]),
                        str(metrics["skipped"]),
                        str(metrics["total_distributions"]),
                        str(metrics["total_skipped_distributions"]),
                        str(metrics["total_passed_distributions"]),
                        str(metrics["total_failed_distributions"]),
                        f"{rejection_rate:.2f}%",
                    )
                )

            # Print overall row
            valid_tests = overall_metrics["total"] - overall_metrics["skipped"]
            rejection_rate = (
                (overall_metrics["failed"] / valid_tests) * 100
                if valid_tests > 0
                else 0.0
            )

            # Print a separator line
            print("-" * 80)

            # Print overall row
            print(
                row_format.format(
                    "Data CI (All C sizes)",
                    str(overall_metrics["total"]),
                    str(overall_metrics["passed"]),
                    str(overall_metrics["failed"]),
                    str(overall_metrics["skipped"]),
                    str(overall_metrics["total_distributions"]),
                    str(overall_metrics["total_skipped_distributions"]),
                    str(overall_metrics["total_passed_distributions"]),
                    str(overall_metrics["total_failed_distributions"]),
                    f"{rejection_rate:.2f}%",
                )
            )

        # For Estimator CI tests
        if not l1_estimator_ci_df.empty:
            total_tests = len(l1_estimator_ci_df)
            if total_tests > 0:
                # Initialize dictionaries to track metrics by C_set size
                est_c_set_sizes = {}

                # Initialize a record for the overall results
                est_overall_metrics = {
                    "total": 0,
                    "passed": 0,
                    "failed": 0,
                    "skipped": 0,
                    "total_distributions": 0,
                    "total_skipped_distributions": 0,
                    "total_passed_distributions": 0,
                    "total_failed_distributions": 0,
                }

                # Process each test record to calculate metrics by C_set size
                for i, rec in enumerate(l1_estimator_ci_records):
                    # Determine the size of C_set for this test
                    c_set_size = None
                    try:
                        # Try to extract C_set from C_values if it exists
                        if "C_values" in rec and rec["C_values"]:
                            first_value = (
                                rec["C_values"][0] if rec["C_values"] else None
                            )
                            c_set_size = (
                                len(first_value)
                                if isinstance(first_value, tuple)
                                else 1
                            )
                        elif "skipped_c_values" in rec and rec["skipped_c_values"]:
                            first_value = (
                                rec["skipped_c_values"][0]
                                if rec["skipped_c_values"]
                                else None
                            )
                            c_set_size = (
                                len(first_value)
                                if isinstance(first_value, tuple)
                                else 1
                            )
                        else:
                            # If C_set size cannot be determined, use "Unknown"
                            c_set_size = "Unknown"
                    except (KeyError, IndexError, TypeError):
                        c_set_size = "Unknown"

                    # Initialize metrics for this C_set size if it doesn't exist
                    if c_set_size not in est_c_set_sizes:
                        est_c_set_sizes[c_set_size] = {
                            "total": 0,
                            "passed": 0,
                            "failed": 0,
                            "skipped": 0,
                            "total_distributions": 0,
                            "total_skipped_distributions": 0,
                            "total_passed_distributions": 0,
                            "total_failed_distributions": 0,
                        }

                    # Increment total for this C_set size
                    est_c_set_sizes[c_set_size]["total"] += 1
                    est_overall_metrics["total"] += 1

                    # Check test result (independence accepted/rejected/skipped)
                    if "independence_accepted" in rec:
                        if rec["independence_accepted"] is None:
                            est_c_set_sizes[c_set_size]["skipped"] += 1
                            est_overall_metrics["skipped"] += 1
                        elif rec["independence_accepted"] is True:
                            est_c_set_sizes[c_set_size]["passed"] += 1
                            est_overall_metrics["passed"] += 1
                        else:
                            est_c_set_sizes[c_set_size]["failed"] += 1
                            est_overall_metrics["failed"] += 1

                    # Process C values and their JS divergences
                    if "C_values" in rec and "js_divergences" in rec:
                        threshold = rec.get("threshold", 0.05)
                        c_values = rec.get("C_values", [])
                        js_values = rec.get("js_divergences", [])

                        for i, js_val in enumerate(js_values):
                            if i < len(c_values):  # Make sure we're within bounds
                                est_c_set_sizes[c_set_size]["total_distributions"] += 1
                                est_overall_metrics["total_distributions"] += 1
                                if js_val <= threshold:
                                    est_c_set_sizes[c_set_size][
                                        "total_passed_distributions"
                                    ] += 1
                                    est_overall_metrics[
                                        "total_passed_distributions"
                                    ] += 1
                                else:
                                    est_c_set_sizes[c_set_size][
                                        "total_failed_distributions"
                                    ] += 1
                                    est_overall_metrics[
                                        "total_failed_distributions"
                                    ] += 1

                    # Count skipped distributions
                    if "skipped_c_values" in rec:
                        n_skipped = len(rec["skipped_c_values"])
                        est_c_set_sizes[c_set_size][
                            "total_skipped_distributions"
                        ] += n_skipped
                        est_overall_metrics["total_skipped_distributions"] += n_skipped
                        est_c_set_sizes[c_set_size]["total_distributions"] += n_skipped
                        est_overall_metrics["total_distributions"] += n_skipped

                # Print rows for each C_set size
                for c_size, metrics in sorted(est_c_set_sizes.items()):
                    # Calculate rejection rate
                    valid_tests = metrics["total"] - metrics["skipped"]
                    rejection_rate = (
                        (metrics["failed"] / valid_tests) * 100
                        if valid_tests > 0
                        else 0.0
                    )

                    # Construct label based on C_set size
                    if c_size == "Unknown":
                        label = "Estimator CI (Unknown C size)"
                    else:
                        label = f"Estimator CI (|C| = {c_size})"

                    # Print row
                    print(
                        row_format.format(
                            label,
                            str(metrics["total"]),
                            str(metrics["passed"]),
                            str(metrics["failed"]),
                            str(metrics["skipped"]),
                            str(metrics["total_distributions"]),
                            str(metrics["total_skipped_distributions"]),
                            str(metrics["total_passed_distributions"]),
                            str(metrics["total_failed_distributions"]),
                            f"{rejection_rate:.2f}%",
                        )
                    )

                # Print a separator line
                print("-" * 80)

                # Print overall row for Estimator CI
                valid_tests = (
                    est_overall_metrics["total"] - est_overall_metrics["skipped"]
                )
                rejection_rate = (
                    (est_overall_metrics["failed"] / valid_tests) * 100
                    if valid_tests > 0
                    else 0.0
                )

                print(
                    row_format.format(
                        "Estimator CI (All C sizes)",
                        str(est_overall_metrics["total"]),
                        str(est_overall_metrics["passed"]),
                        str(est_overall_metrics["failed"]),
                        str(est_overall_metrics["skipped"]),
                        str(est_overall_metrics["total_distributions"]),
                        str(est_overall_metrics["total_skipped_distributions"]),
                        str(est_overall_metrics["total_passed_distributions"]),
                        str(est_overall_metrics["total_failed_distributions"]),
                        f"{rejection_rate:.2f}%",
                    )
                )

        print("=" * 80)

    print_positive_percentage(
        l3_axioms_df, "composition_success_rate", "Composition Success Rate"
    )
    print_positive_percentage(
        l3_axioms_df, "effectiveness_success_rate", "Effectiveness Success Rate"
    )
    print_positive_percentage(
        l3_axioms_df, "reversibility_success_rate", "Reversibility Success Rate"
    )
    print("Done!")


if __name__ == "__main__":
    # Script entry point: main() runs only when this file is executed directly,
    # not when it is imported as a module.
    main()
