from itertools import product
from typing import Dict, List, Optional

import numpy as np

if __name__ != "__main__":
    from causal_profiler.query import Query


# Supported ways of aggregating the per-c-tuple JS divergences into one score.
_AGGREGATION_FUNCS = {"max": max, "mean": np.mean, "median": np.median}


def _js_divergence(p, q, eps=1e-12):
    """Return the Jensen-Shannon divergence between two discrete distributions.

    Both inputs are clamped to at least ``eps`` and renormalised to sum to 1, so
    they may contain zeros or be unnormalised. Natural logarithms are used, so
    the result lies in ``[0, ln 2]``.
    """
    p = np.maximum(np.asarray(p, dtype=float), eps)
    q = np.maximum(np.asarray(q, dtype=float), eps)
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    return 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m))


def verify_query_estimator_conditional_independence(
    query_estimator,
    scm,
    A: str,
    B: str,
    C_set: List[str],
    # If provided, extract unique values, otherwise use `num_discrete_values`
    data_dict: Optional[Dict[str, np.ndarray]] = None,
    use_multi_query: bool = False,
    js_threshold: float = 0.05,
    aggregation_method: str = "mean",
    n_samples: int = 10000,
) -> dict:
    """
    Check if A and B are conditionally independent given C_set using the
    Jensen-Shannon (JS) divergence, as a sanity check of a query estimator.

    This is meant to be called when A _||_ B | C_set is already known to hold
    (e.g. from d-separation), so a large divergence points at the estimator.

    For every assignment c of C_set the estimator is asked for
    P(A=a | B=b, C=c), P(A=a | C=c) and P(B=b | C=c). The joint distribution
    P(A, B | C=c) is reconstructed as P(A | B, C=c) * P(B | C=c) and compared,
    via JS divergence, with the product of marginals P(A | C=c) * P(B | C=c).
    The per-c divergences are aggregated and compared against `js_threshold`.

    Side effect: sets ``query_estimator.n_samples = n_samples``.

    Parameters
    ----------
    query_estimator : object
        An estimator with methods:
            - evaluate_query(scm, query) -> float
            - evaluate_queries(scm, list_of_queries) -> List[float or np.nan]
        The latter uses the same sampled dataset for all queries, while the former
        resamples for each query. ``None`` results are treated as NaN.

    scm : SCM
        The structural causal model containing variables in `scm.variables[var_name]`.

    A, B : str
        Names of the two (categorical) variables whose conditional independence we test.

    C_set : List[str]
        List of conditioning variables (may be empty).

    data_dict : dict, optional
        If provided, we extract variable domains from unique observed values in `data_dict`.
        Otherwise, we use `scm.variables[var_name].num_discrete_values`. The only utility of the
        dataset is to restrict the range of variables to an observed dataset.

    use_multi_query : bool, default False
        - If True, calls `evaluate_queries` once for all queries (faster, same dataset).
        - If False, calls `evaluate_query` repeatedly (different dataset each time).

    js_threshold : float, default 0.05
        Threshold below which the aggregated JS divergence is considered "small
        enough" to accept A _||_ B | C.

    aggregation_method : {"max", "mean", "median"}, default "mean"
        How to aggregate the JS divergences over all c-tuples.

    n_samples : int, default 10000
        Number of samples used by the query estimator.

    Intuition when selecting the js_threshold and aggregation_method
    (it also depends on how queries are evaluated and how much data is used):
    - High-variance estimator, avoid outliers: median + 0.10 js threshold
    - Balanced, practical: mean (smooth out noise) + 0.05 js threshold
    - Strict, catch all violations: max + 0.01 js threshold (never got this to work)

    Returns
    -------
    result_dict : dict
        {
            "C_values": list of all c-tuples tested,
            "skipped_c_values": list of c-tuples skipped due to NaN in queries,
            "js_divergences": list of JS divergences (one per non-skipped c-tuple),
            "aggregation": aggregated JS (max, mean or median), NaN if all skipped,
            "independence_accepted": bool, or None if every c-tuple was skipped,
            "threshold": float,
            "aggregation_method": str,
            "num_c_values": int,
            "num_queries": int
        }

    Raises
    ------
    ValueError
        If `aggregation_method` is not one of "max", "mean", "median".
    """
    if aggregation_method not in _AGGREGATION_FUNCS:
        raise ValueError(
            f"aggregation_method must be one of {sorted(_AGGREGATION_FUNCS)}, "
            f"got {aggregation_method!r}"
        )

    query_estimator.n_samples = n_samples

    def get_domain(var_name: str):
        """Values a variable can take: observed in `data_dict`, else 0..k-1 from the SCM."""
        if data_dict is not None:
            return sorted(np.unique(data_dict[var_name].flatten()))
        return list(range(scm.variables[var_name].num_discrete_values))

    A_domain = get_domain(A)
    B_domain = get_domain(B)
    C_domains = [get_domain(c_var) for c_var in C_set]
    # All c-tuples, e.g. for C_set = [C1, C2]: (c1_val, c2_val), ...
    # With an empty C_set this is [()], i.e. a single unconditional check.
    all_c_combinations = list(product(*C_domains))

    var_A = scm.variables[A]
    var_B = scm.variables[B]
    C_vars = [scm.variables[c] for c in C_set]

    all_queries = []
    # Maps a (kind, c_tuple, *values) key to the query's index in `all_queries`.
    query_map = {}

    def add_query(key, Y, X, Y_value, X_value):
        """Create the conditional query P(Y=Y_value | X=X_value) and index it under `key`."""
        query_map[key] = len(all_queries)
        all_queries.append(
            Query.createL1Conditional(Y=Y, X=X, Y_value=Y_value, X_value=X_value)
        )

    # (a) P(A=a | B=b, C=c). createL1Conditional yields a conditional, not a joint,
    #     so the joint is rebuilt below as P(A | B, C) * P(B | C).
    for c_tuple in all_c_combinations:
        for a_val in A_domain:
            for b_val in B_domain:
                add_query(
                    ("A|BC", c_tuple, a_val, b_val),
                    Y=var_A,
                    X=[var_B] + C_vars,
                    Y_value=a_val,
                    X_value=[b_val] + list(c_tuple),
                )

    # (b) P(A=a | C=c)
    for c_tuple in all_c_combinations:
        for a_val in A_domain:
            add_query(
                ("A|C", c_tuple, a_val),
                Y=var_A,
                X=list(C_vars),
                Y_value=a_val,
                X_value=list(c_tuple),
            )

    # (c) P(B=b | C=c)
    for c_tuple in all_c_combinations:
        for b_val in B_domain:
            add_query(
                ("B|C", c_tuple, b_val),
                Y=var_B,
                X=list(C_vars),
                Y_value=b_val,
                X_value=list(c_tuple),
            )

    # Evaluate the queries, either in bulk or one by one
    if use_multi_query:
        # Single call, same dataset
        raw_results = query_estimator.evaluate_queries(scm, all_queries)
    else:
        # Multiple calls, each might sample a different dataset
        # Bigger variance but less correlation between the evaluations (bias/variance tradeoff)
        raw_results = [query_estimator.evaluate_query(scm, q) for q in all_queries]

    # dtype=float turns None into NaN; np.maximum propagates NaN, so NaNs survive
    # into log space and are detected below.
    query_results = np.asarray(raw_results, dtype=float)
    eps = 1e-12
    log_results = np.log(np.maximum(query_results, eps))

    n_a, n_b = len(A_domain), len(B_domain)
    js_values = []
    skipped_c_values = []
    for c_tuple in all_c_combinations:
        log_a_given_bc = np.array(
            [
                [log_results[query_map[("A|BC", c_tuple, a, b)]] for b in B_domain]
                for a in A_domain
            ],
            dtype=float,
        ).reshape(n_a, n_b)
        log_pA = np.array(
            [log_results[query_map[("A|C", c_tuple, a)]] for a in A_domain],
            dtype=float,
        )
        log_pB = np.array(
            [log_results[query_map[("B|C", c_tuple, b)]] for b in B_domain],
            dtype=float,
        )

        # If any query is NaN (presumably an estimator's answer when the conditioning
        # event has no support -- confirm per estimator), skip the entire c_tuple.
        # TODO: is this the right behavior?
        if (
            np.isnan(log_a_given_bc).any()
            or np.isnan(log_pA).any()
            or np.isnan(log_pB).any()
        ):
            skipped_c_values.append(c_tuple)
            continue

        # Joint via the chain rule: P(a, b | c) = P(a | b, c) * P(b | c)
        joint_dist = np.exp(log_a_given_bc + log_pB[None, :])
        # Product of marginals: P(a | c) * P(b | c)
        product_dist = np.exp(log_pA[:, None] + log_pB[None, :])

        js_values.append(_js_divergence(joint_dist.ravel(), product_dist.ravel()))

    if js_values:
        aggregated_js = _AGGREGATION_FUNCS[aggregation_method](js_values)
        independence_accepted = bool(aggregated_js <= js_threshold)
    else:
        # If everything was skipped, aggregated_js is NaN, and we can't decide
        aggregated_js = np.nan
        independence_accepted = None

    return {
        "C_values": all_c_combinations,
        "skipped_c_values": skipped_c_values,
        "js_divergences": js_values,
        "aggregation": aggregated_js,
        "independence_accepted": independence_accepted,
        "threshold": js_threshold,
        "aggregation_method": aggregation_method,
        "num_c_values": len(all_c_combinations),
        "num_queries": len(all_queries),
    }


if __name__ == "__main__":
    # Demo: build a small random discrete SCM with CausalProfiler, then run the
    # conditional-independence check on its first two variables given the rest.
    # Statement order matters here: sys.path must be patched before the
    # top-level `query` / `space_of_interest` imports below can resolve.
    import os
    import random
    import sys

    import torch

    current_dir = os.path.dirname(os.path.realpath(__file__))
    # Directory of the `causal_profiler` package (a sibling of this file's
    # directory). Putting it on sys.path lets its modules be imported as
    # top-level modules (`query`, `space_of_interest`).
    project_root = os.path.abspath(
        os.path.join(current_dir, os.pardir, "causal_profiler")
    )
    # Add the package directory to the Python path
    sys.path.insert(0, project_root)

    # This binds `Query` at module scope (this `if` is at top level), which is
    # what the function above resolves as a global; the import at the top of
    # the file only runs when this module is NOT executed as a script.
    # NOTE(review): `query` / `space_of_interest` are loaded as top-level modules
    # while `causal_profiler` is loaded as a package. If the package internally
    # imports `causal_profiler.space_of_interest`, the enum classes here would be
    # distinct objects -- confirm they compare equal inside CausalProfiler.
    from query import Query
    from space_of_interest import (
        MechanismFamily,
        NoiseDistribution,
        QueryType,
        SpaceOfInterest,
        VariableDataType,
    )

    from causal_profiler import CausalProfiler, ErrorMetric

    # Set seeds for reproducibility (Python, NumPy and torch RNGs)
    SEED = 43  # TODO: big variance, seed = 42 rejects it
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Generate an SCM, data, and a query estimator
    space = SpaceOfInterest(
        number_of_nodes=(4, 4),
        expected_edges=0,  # 2 TODO
        variable_dimensionality=(1, 1),
        mechanism_family=MechanismFamily.TABULAR,
        noise_distribution=NoiseDistribution.UNIFORM,
        noise_args=[-1, 1],
        number_of_noise_regions="N",
        number_of_categories=2,
        variable_type=VariableDataType.DISCRETE,
        markovian=True,
        semi_markovian=False,
        causal_graph=None,
        control_positivity=False,
        number_of_queries=1,
        query_type=QueryType.CONDITIONAL,
        # The generated data_dict is only used below to read variable domains;
        # the estimator draws its own samples (n_samples keeps its default).
        number_of_data_points=50,
    )
    profiler = CausalProfiler(
        space_of_interest=space,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )

    (data_dict, (_, _), (_, index_to_var)) = profiler.generate_samples_and_queries()
    scm = profiler.sampler._scm
    query_estimator = profiler.sampler.query_estimator

    # index_to_var is something like ["X1", "X2", "X3", "X4"]
    # Let's pick A, B, and everything else in C
    A, B, *C = index_to_var

    # Call the function
    result = verify_query_estimator_conditional_independence(
        query_estimator=query_estimator,
        scm=scm,
        A=A,
        B=B,
        C_set=C,
        data_dict=data_dict,  # so we get discrete domains from the actual data
        use_multi_query=False,  # not a single dataset
        js_threshold=0.05,
        aggregation_method="max",
    )

    print("RESULT:")
    print("JS divergences:", result["js_divergences"])
    print("Aggregated JS:", result["aggregation"])
    print("Independence accepted?", result["independence_accepted"])
    print("Skipped c-values:", result["skipped_c_values"])
