import random

import numpy as np


def _sample_positive_integers_max_sum(N, M):
    """
    Sample M positive integers whose sum is at most N.

    A total S is drawn uniformly from [M, N]; every part starts at 1 and the
    remaining S - M units are split using M - 1 sorted random breakpoints.

    Raises
    ------
    ValueError
        If M > N (M positive integers cannot sum to at most N).
    """
    if M > N:
        raise ValueError("M must be <= N")

    # Random total in [M, N] so that every part can be at least 1.
    S = np.random.randint(M, N + 1)
    remainder = S - M  # surplus to distribute on top of the base of 1 each

    # M-1 breakpoints in [0, remainder] cut the surplus into M chunks.
    breaks = np.sort(np.random.choice(range(remainder + 1), M - 1, replace=True))
    diffs = np.diff(np.concatenate(([0], breaks, [remainder])))
    return diffs + 1  # shift so each part is >= 1


def _shuffled_copy(value):
    """
    Return a copy of `value` shuffled along its first axis, or None if `value` is None.

    The None check must happen *before* copying: np.copy(None) yields a 0-d
    object array (never None), which would silently defeat the callers'
    None guards.
    """
    if value is None:
        return None
    shuffled = np.copy(value)
    np.random.shuffle(shuffled)  # permutes along the first axis only
    return shuffled


def _resample_noise(scm):
    """Clear all values, draw fresh exogenous noise and recompute the variables."""
    scm.reset_values(reset_noise=True)
    scm.sample_noise_variables()
    scm.compute_variables()


def _compute_under(scm, interventions):
    """
    Keep the current noise, apply hard interventions and recompute the variables.

    `interventions` maps variable -> value and is applied in insertion order.
    The caller is responsible for removing the interventions afterwards.
    """
    scm.reset_values(reset_noise=False)  # keep noise, drop old endogenous values
    for var, val in interventions.items():
        scm.do_hard_intervention(var, val)
    scm.compute_variables()


def _elementwise_close(a, b, tol):
    """
    Flatten both inputs and return a boolean array of |a - b| <= tol.

    Equivalent to calling np.allclose(., ., atol=tol, rtol=0) on every element
    pair, but vectorised.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    return np.isclose(a, b, atol=tol, rtol=0.0)


def _rows_close(a, b, tol):
    """
    Return one boolean per sample (first axis).

    An entry is True when every component of that sample matches within tol
    (absolute tolerance only).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a.reshape(len(a), -1)
    b = b.reshape(len(b), -1)
    return np.isclose(a, b, atol=tol, rtol=0.0).all(axis=1)


def _composition_trial(scm, endogenous_vars, tol):
    """
    Run one Composition trial and return (passes, tests), counted per element.

    Checks W_{do(X=x)}(u) = w  =>  Y_{do(X=x), do(W=w)}(u) = Y_{do(X=x)}(u)
    for disjoint random sets X, W, Y.
    """
    sizes = _sample_positive_integers_max_sum(len(endogenous_vars), 3)
    X = set(random.sample(endogenous_vars, int(sizes[0])))
    remaining = [v for v in endogenous_vars if v not in X]
    W = set(random.sample(remaining, int(sizes[1])))
    remaining = [v for v in remaining if v not in W]
    Y = set(random.sample(remaining, int(sizes[2])))
    if not Y or not W:
        print("Composition: Y or W are empty, this might be a bug")
        return 0, 0

    _resample_noise(scm)

    # do(X=x) uses a shuffled copy of X's current values.
    x_vals = {x: _shuffled_copy(x.value) for x in X}
    if any(v is None for v in x_vals.values()):
        return 0, 0

    # Under do(X=x): read w = W_{do(X=x)}(u) and Y_{do(X=x)}(u).
    _compute_under(scm, x_vals)
    w_vals = {w: w.value for w in W}
    y_vals_doX = {y: y.value for y in Y}
    scm.remove_interventions(list(X))
    if any(v is None for v in w_vals.values()) or any(
        v is None for v in y_vals_doX.values()
    ):
        return 0, 0

    # Under do(X=x), do(W=w): read Y_{do(X=x), do(W=w)}(u).
    _compute_under(scm, {**x_vals, **w_vals})
    y_vals_doXW = {y: y.value for y in Y}
    scm.remove_interventions(list(X) + list(W))
    if any(v is None for v in y_vals_doXW.values()):
        return 0, 0

    passes = tests = 0
    for y in Y:
        close = _elementwise_close(y_vals_doX[y], y_vals_doXW[y], tol)
        tests += int(close.size)
        passes += int(np.count_nonzero(close))
    return passes, tests


def _effectiveness_trial(scm, endogenous_vars, tol):
    """
    Run one Effectiveness trial and return (passes, tests), counted per element.

    Checks X_{do(X=x), do(W=w)}(u) = x for disjoint random sets X, W.
    """
    sizes = _sample_positive_integers_max_sum(len(endogenous_vars), 2)
    X = set(random.sample(endogenous_vars, int(sizes[0])))
    remaining = [v for v in endogenous_vars if v not in X]
    W = set(random.sample(remaining, int(sizes[1])))
    if not W:
        print("Effectiveness has empty W: might be a bug!")
        return 0, 0

    _resample_noise(scm)

    x_vals = {x: _shuffled_copy(x.value) for x in X}
    w_vals = {w: _shuffled_copy(w.value) for w in W}
    if any(v is None for v in x_vals.values()) or any(
        v is None for v in w_vals.values()
    ):
        return 0, 0

    _compute_under(scm, {**x_vals, **w_vals})
    x_vals_after = {x: x.value for x in X}
    scm.remove_interventions(list(X) + list(W))
    if any(v is None for v in x_vals_after.values()):
        return 0, 0

    passes = tests = 0
    for x in X:
        close = _elementwise_close(x_vals[x], x_vals_after[x], tol)
        tests += int(close.size)
        passes += int(np.count_nonzero(close))
    return passes, tests


def _reversibility_trial(scm, endogenous_vars, tol):
    """
    Run one Reversibility trial and return (passes, tests), counted per sample.

    For a random set X and distinct single variables Y, W outside X:
      y  := Y_{do(X=x), do(W=w)}(u)
      w' := W_{do(X=x), do(Y=y)}(u)
    The premise holds on samples where w' == w; only those samples are tested
    for the conclusion Y_{do(X=x)}(u) == y.
    """
    sizes = _sample_positive_integers_max_sum(len(endogenous_vars) - 2, 1)
    X = set(random.sample(endogenous_vars, int(sizes[0])))
    remaining = [v for v in endogenous_vars if v not in X]
    Y, W = random.sample(remaining, 2)

    _resample_noise(scm)

    x_vals = {x: _shuffled_copy(x.value) for x in X}
    # The intervention value must come from W itself (not from any other variable).
    w_val = _shuffled_copy(W.value)
    if any(v is None for v in x_vals.values()) or w_val is None:
        return 0, 0

    # y := Y_{do(X=x), do(W=w)}(u)
    _compute_under(scm, {**x_vals, W: w_val})
    y_val = Y.value
    scm.remove_interventions(list(X) + [W])
    if y_val is None:
        return 0, 0

    # w' := W_{do(X=x), do(Y=y)}(u)
    _compute_under(scm, {**x_vals, Y: y_val})
    w_val2 = W.value
    scm.remove_interventions(list(X) + [Y])
    if w_val2 is None:
        return 0, 0

    # Premise holds on the samples where w' matches w.
    premise = _rows_close(w_val, w_val2, tol)
    n_premise = int(np.count_nonzero(premise))
    if n_premise == 0:
        return 0, 0

    # Conclusion: Y_{do(X=x)}(u) == y on the premise samples.
    _compute_under(scm, x_vals)
    y_val2 = Y.value
    scm.remove_interventions(list(X))
    if y_val2 is None:
        return 0, 0

    conclusion = _rows_close(y_val, y_val2, tol)
    return int(np.count_nonzero(conclusion & premise)), n_premise


def verify_structural_counterfactual_axioms(
    scm, n_tests: int = 1, tolerance: float = 1e-5, seed: int = None
):
    """
    Tests Pearl's three structural counterfactual axioms on the given SCM.

    Axioms:
      1) Composition:
         If W_{do(X=x)}(u) = w, then
         Y_{do(X=x), do(W=w)}(u) = Y_{do(X=x)}(u).

      2) Effectiveness:
         X_{do(X=x), do(W=w)}(u) = x.

      3) Reversibility:
         If Y_{do(X=x), do(W=w)}(u) = y  AND  W_{do(X=x), do(Y=y)}(u) = w,
         then  Y_{do(X=x)}(u) = y.

    Each trial draws fresh noise, picks random disjoint variable sets and
    intervention values (a shuffle of the variables' current values), and
    compares every sample held by the SCM at once.

    Parameters
    ----------
    scm : SCM
        A structural causal model object. It must expose `variables` (a dict
        of variables with `.value` and `.exogenous`), `reset_values`,
        `sample_noise_variables`, `compute_variables`, `do_hard_intervention`
        and `remove_interventions`.

    n_tests : int
        Number of random trials per axiom. Each trial already compares every
        sample in the SCM's vectorised value arrays, so 1 is usually enough.

    tolerance : float
        Absolute tolerance (rtol = 0) used when comparing values.

    seed : int or None
        If given, seeds both `random` and `np.random` for reproducibility.

    Returns
    -------
    results_dict : dict
        {
          "composition_success_rate": float or None,
          "effectiveness_success_rate": float or None,
          "reversibility_success_rate": float or None,
          "num_tests_composition": int,
          "num_tests_effectiveness": int,
          "num_tests_reversibility": int
        }
        A rate is None when no comparable samples were produced. All rates
        are NaN when the SCM has fewer than three endogenous variables.
        Composition/Effectiveness count individual elements; Reversibility
        counts samples satisfying the premise.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Only endogenous variables can be intervened on / compared.
    endogenous_vars = [var for var in scm.variables.values() if not var.exogenous]
    if len(endogenous_vars) < 3:
        print(
            "At least three endogenous variables are required to test the structural counterfactual axioms."
        )
        return {
            "composition_success_rate": np.nan,
            "effectiveness_success_rate": np.nan,
            "reversibility_success_rate": np.nan,
            "num_tests_composition": 0,
            "num_tests_effectiveness": 0,
            "num_tests_reversibility": 0,
        }

    def run_trials(trial):
        """Accumulate (passes, tests) over n_tests runs of `trial`."""
        passes = tests = 0
        for _ in range(n_tests):
            p, t = trial(scm, endogenous_vars, tolerance)
            passes += p
            tests += t
        return passes, tests

    def rate(passes, tests):
        return passes / tests if tests > 0 else None

    # Order matters for reproducibility: the trials share the global RNGs.
    composition_passes, composition_tests = run_trials(_composition_trial)
    effectiveness_passes, effectiveness_tests = run_trials(_effectiveness_trial)
    reversibility_passes, reversibility_tests = run_trials(_reversibility_trial)

    return {
        "composition_success_rate": rate(composition_passes, composition_tests),
        "effectiveness_success_rate": rate(effectiveness_passes, effectiveness_tests),
        "reversibility_success_rate": rate(reversibility_passes, reversibility_tests),
        "num_tests_composition": composition_tests,
        "num_tests_effectiveness": effectiveness_tests,
        "num_tests_reversibility": reversibility_tests,
    }


if __name__ == "__main__":
    # Script entry point:
    #   1) build several random SCMs that differ only in expected_edges,
    #   2) verify the structural counterfactual axioms on each of them,
    #   3) print the success rates together with the SCM's variable names.
    import os
    import sys

    import torch

    # Make the sibling `causal_profiler` directory importable.
    this_dir = os.path.dirname(os.path.realpath(__file__))
    sys.path.insert(
        0, os.path.abspath(os.path.join(this_dir, os.pardir, "causal_profiler"))
    )

    from space_of_interest import (
        MechanismFamily,
        NoiseDistribution,
        QueryType,
        SpaceOfInterest,
        VariableDataType,
    )

    from causal_profiler import CausalProfiler, ErrorMetric

    # Seed every RNG involved so that runs are reproducible.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    def build_scm(edges):
        """Build a 4-node, binary, tabular Markovian SCM with `edges` expected edges."""
        profiler = CausalProfiler(
            space_of_interest=SpaceOfInterest(
                number_of_nodes=(4, 4),
                expected_edges=edges,
                variable_dimensionality=(1, 1),
                mechanism_family=MechanismFamily.TABULAR,
                noise_distribution=NoiseDistribution.UNIFORM,
                noise_args=[0, 1],
                number_of_noise_regions="N",
                number_of_categories=2,
                variable_type=VariableDataType.DISCRETE,
                markovian=True,
                semi_markovian=False,
                causal_graph=None,
                control_positivity=False,
                number_of_queries=1,
                query_type=QueryType.CONDITIONAL,
                number_of_data_points=5,
            ),
            metric=ErrorMetric.L2,
            return_adjacency_matrix=False,
            n_samples=10000,
        )
        # Generating samples is what instantiates the SCM on the sampler.
        profiler.generate_samples_and_queries()
        return profiler.sampler._scm

    # Sweep over edge densities to see how the axioms behave.
    for edges in (0, 2, 4):
        print(f"\n=== Testing SCM with expected_edges = {edges} ===")
        scm = build_scm(edges)

        results = verify_structural_counterfactual_axioms(
            scm, tolerance=1e-7, seed=123  # fixed seed for reproducibility
        )

        all_vars = list(scm.variables.values())
        endogenous_names = [v.name for v in all_vars if not v.exogenous]
        exogenous_names = [v.name for v in all_vars if v.exogenous]

        print("Variables in this SCM:", list(scm.variables.keys()))
        print("Endogenous variables:", endogenous_names)
        print("Exogenous variables:", exogenous_names)
        print("Results:\n", results)
        print("==============================================")
