from itertools import product
from typing import Dict, List, Union

import numpy as np
import pandas as pd

if __name__ != "__main__":
    from causal_profiler.query import Query


def _js_divergence(p, q, eps: float = 1e-12):
    """
    Return the Jensen-Shannon divergence between two discrete distributions.

    Both inputs are clipped from below at ``eps`` (so zero-probability entries do
    not produce ``log(0)``) and renormalized before the divergence is computed with
    the natural logarithm. The caller's arrays are never modified.
    """
    p = np.maximum(np.asarray(p, dtype=float), eps)
    q = np.maximum(np.asarray(q, dtype=float), eps)
    p = p / p.sum()
    q = q / q.sum()
    mixture = 0.5 * (p + q)

    def kl_div(x, y):
        return np.sum(x * np.log(x / y))

    return 0.5 * kl_div(p, mixture) + 0.5 * kl_div(q, mixture)


def verify_query_estimator_conditionals(
    query_estimator,
    scm,
    A: str,
    C: Union[str, List[str]],
    data_dict: Dict[str, np.ndarray],
    use_multi_query: bool = False,
    js_threshold: float = 0.05,
    aggregation_method: str = "mean",
    n_samples: int = 10000,
    min_count: int = 1,
):
    """
    Compare the empirical P(A | C) from data_dict with the query-estimator's P(A | C).

    For each c-value (or c-tuple) in the Cartesian product of the C domains, we:
      - Compute the empirical distribution of A given c from data_dict.
      - Query the estimator for P(A=a | C=c) for every a in the domain of A.
      - Compute the JS divergence between the two distributions.

    Finally we aggregate these JS values across all tested c-values using 'mean',
    'median', or 'max', and compare to js_threshold to decide whether "agreement"
    is accepted.

    Side effect: ``query_estimator.n_samples`` is overwritten with ``n_samples``.

    Parameters
    ----------
    query_estimator : object
        An estimator with methods:
            - evaluate_query(scm, query) -> float
            - evaluate_queries(scm, list_of_queries) -> List[float or np.nan]
        The latter uses the same sampled dataset for all queries, while the former
        resamples for each query.

    scm : object
        The structural causal model containing variables in `scm.variables[var_name]`.

    A : str
        Name of the (categorical) variable whose conditional distribution we are checking.

    C : str or List[str]
        Name(s) of the conditioning variable(s). May be a single string or a list of strings.

    data_dict : dict
        Dictionary of {variable_name: np.ndarray}. Each array is flattened before use
        and all used arrays must have the same flattened length.

    use_multi_query : bool, optional (default=False)
        If True, calls `evaluate_queries` once for all queries (faster, single shared
        dataset). Marked TODO by the original author: batch evaluation may not be
        implemented by the estimator yet.
        If False, calls `evaluate_query` repeatedly for each query (more independent draws).

    js_threshold : float, optional (default=0.05)
        Jensen-Shannon divergence threshold to decide if the distributions "agree".

    aggregation_method : {'mean', 'median', 'max'}, optional (default='mean')
        How to aggregate the JS divergences across all tested c-values.

    n_samples : int, optional (default=10000)
        Number of samples used by the query estimator.

    min_count : int, optional (default=1)
        Minimum number of samples in data_dict for a particular c-value combination
        to consider that distribution valid. If fewer than `min_count` samples appear,
        the combination is skipped (and reported in "skipped_c_values").

    Returns
    -------
    result_dict : dict
        {
            "C_values": List of all c-tuples (even if single variable, always stored as tuples),
            "skipped_c_values": c-tuples skipped due to no/insufficient data or because
                                the estimator returned NaN for one of their queries,
            "js_divergences": list of JS divergences for each tested c-tuple,
            "aggregation": float, aggregated JS (NaN if nothing could be tested),
            "agreement_accepted": bool, or None if nothing could be tested,
            "threshold": float,
            "aggregation_method": str,
            "num_c_values": int,
            "num_queries": int
        }

    Raises
    ------
    KeyError
        If `A` or a variable in `C` is missing from `data_dict`, or if
        `aggregation_method` is unknown and at least one c-tuple was tested.
    ValueError
        If the flattened data arrays for `A` and `C` differ in length.
    """
    # Normalize C to a list without touching the caller's object
    C_vars = [C] if isinstance(C, str) else list(C)

    query_estimator.n_samples = n_samples

    # -----------------------------
    # Get domain for A and for each variable in C
    # -----------------------------
    def get_domain(var_name: str):
        if var_name not in data_dict:
            # fallback: use the SCM's declared domain
            print(f"WARNING: {var_name} not in data_dict (this is probably wrong)")
            num_vals = scm.variables[var_name].num_discrete_values
            return list(range(num_vals))
        # else extract the observed values from the data
        return sorted(np.unique(np.asarray(data_dict[var_name]).flatten()))

    A_domain = get_domain(A)
    C_domains = [get_domain(c_var) for c_var in C_vars]
    # All possible c-tuples
    all_c_combinations = list(product(*C_domains))

    # -----------------------------
    # Build empirical distribution P(A=a | C=c) from data_dict
    # -----------------------------
    a_values = np.asarray(data_dict[A]).flatten()
    c_columns = [np.asarray(data_dict[c_var]).flatten() for c_var in C_vars]
    n_rows = len(a_values)
    for c_var, column in zip(C_vars, c_columns):
        if len(column) != n_rows:
            raise ValueError(
                f"Length of data for {c_var} ({len(column)}) does not match "
                f"length of data for {A} ({n_rows})"
            )

    # emp_cond[c_tuple] -> np.array of probabilities aligned with A_domain
    emp_cond = {}
    for c_tuple in all_c_combinations:
        # Select rows where every conditioning variable equals its value in c_tuple
        mask = np.ones(n_rows, dtype=bool)
        for column, c_val in zip(c_columns, c_tuple):
            mask &= column == c_val

        if np.count_nonzero(mask) < min_count:
            # not enough data for a stable distribution
            continue

        matched_a = a_values[mask]
        counts = np.array(
            [np.count_nonzero(matched_a == a_val) for a_val in A_domain],
            dtype=float,
        )
        total = counts.sum()
        if total > 0:
            # P(A|C) = P(A, C) / P(C)
            emp_cond[c_tuple] = counts / total

    # -----------------------------
    # Query the estimator for P(A=a | C=c)
    # -----------------------------
    # Queries for one c-tuple are stored contiguously, in A_domain order;
    # query_offset[c_tuple] is the index of the first of them.
    all_queries = []
    query_offset = {}
    for c_tuple in all_c_combinations:
        # Only bother if we have an empirical distribution
        if c_tuple not in emp_cond:
            continue
        query_offset[c_tuple] = len(all_queries)
        for a_val in A_domain:
            all_queries.append(
                Query.createL1Conditional(
                    Y=scm.variables[A],
                    X=[scm.variables[c] for c in C_vars],
                    Y_value=a_val,
                    X_value=list(c_tuple),
                )
            )

    # Evaluate queries
    if use_multi_query:
        query_results = query_estimator.evaluate_queries(scm, all_queries)
    else:
        # Potentially higher-variance, but each query is from an independent sample
        query_results = [query_estimator.evaluate_query(scm, q) for q in all_queries]

    # -----------------------------
    # Compare the two distributions (JS) for each c_tuple
    # -----------------------------
    js_values = []
    skipped_c_values = []
    n_a_values = len(A_domain)

    for c_tuple in all_c_combinations:
        emp_dist = emp_cond.get(c_tuple)
        if emp_dist is None:
            # No (or too little) data for this combination: report it as skipped
            skipped_c_values.append(c_tuple)
            continue

        # Reconstruct the estimator's distribution over A_domain
        start = query_offset[c_tuple]
        est_dist = np.array(
            [float(query_results[start + i]) for i in range(n_a_values)],
            dtype=float,
        )
        if np.isnan(est_dist).any():
            # The estimator could not evaluate at least one query for this c-tuple
            skipped_c_values.append(c_tuple)
            continue

        js_values.append(_js_divergence(emp_dist, est_dist))

    # -----------------------------
    # Aggregate across all c-tuples
    # -----------------------------
    aggregation_func = {"max": max, "mean": np.mean, "median": np.median}

    if not js_values:
        # If we have no valid c-tuples or everything was skipped, we can't decide
        aggregated_js = np.nan
        agreement_accepted = None
    else:
        aggregated_js = float(aggregation_func[aggregation_method](js_values))
        # Cast so callers get a plain Python bool rather than np.bool_
        agreement_accepted = bool(aggregated_js <= js_threshold)

    # -----------------------------
    # Construct and return result dict
    # -----------------------------
    result_dict = {
        "C_values": all_c_combinations,  # all possible c-tuples
        "skipped_c_values": skipped_c_values,  # c-tuples we skipped
        "js_divergences": js_values,  # one per tested c-tuple
        "aggregation": aggregated_js,  # aggregated JS
        "agreement_accepted": agreement_accepted,  # True/False/None
        "threshold": js_threshold,
        "aggregation_method": aggregation_method,
        "num_c_values": len(all_c_combinations),
        "num_queries": len(all_queries),
    }
    return result_dict


if __name__ == "__main__":
    # Demo / smoke test: sample a small discrete SCM with the CausalProfiler and
    # check its query estimator against the observed data it generated.
    import os
    import random
    import sys

    import numpy as np
    import torch

    from causal_profiler.query import Query
    from causal_profiler.space_of_interest import (
        MechanismFamily,
        NoiseDistribution,
        QueryType,
        SpaceOfInterest,
        VariableDataType,
    )

    from causal_profiler import CausalProfiler, ErrorMetric

    # Seed every random number generator in use so the run is reproducible
    SEED = 43
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Describe the space of SCMs to sample from: 4 discrete variables
    # with 3 categories each and 500 observed data points
    space_settings = {
        "number_of_nodes": (4, 4),
        "expected_edges": 4,
        "variable_dimensionality": (1, 1),
        "mechanism_family": MechanismFamily.TABULAR,
        "noise_distribution": NoiseDistribution.UNIFORM,
        "noise_args": [-1, 1],
        "number_of_noise_regions": "N",
        "number_of_categories": 3,
        "variable_type": VariableDataType.DISCRETE,
        "number_of_queries": 1,
        "query_type": QueryType.CONDITIONAL,
        "number_of_data_points": 500,
    }
    soi = SpaceOfInterest(**space_settings)
    causal_profiler = CausalProfiler(
        space_of_interest=soi,
        metric=ErrorMetric.L2,
        return_adjacency_matrix=False,
    )

    # Generate an SCM, data, and a query estimator
    generated = causal_profiler.generate_samples_and_queries()
    (observed_data, (_, _), (_, variable_names)) = generated
    sampled_scm = causal_profiler.sampler._scm
    estimator = causal_profiler.sampler.query_estimator

    # variable_names is something like ["X1", "X2", "X3", "X4"] (per the original
    # author's note): the first name becomes A, all remaining names form C
    target_var, *conditioning_vars = variable_names

    # Run the verification
    outcome = verify_query_estimator_conditionals(
        query_estimator=estimator,
        scm=sampled_scm,
        A=target_var,
        C=conditioning_vars,
        data_dict=observed_data,  # Use observed data to get discrete domains
        use_multi_query=False,  # Batch evaluation isn't implemented yet. TODO
        js_threshold=0.05,
        aggregation_method="mean",
        n_samples=10000,
        min_count=1,  # Require at least 1 sample per C-tuple
    )

    # Report the outcome
    print("\nRESULT:")
    report_lines = (
        ("JS divergences:", "js_divergences"),
        ("Aggregated JS:", "aggregation"),
        ("Agreement accepted?", "agreement_accepted"),
        ("Skipped c-values:", "skipped_c_values"),
    )
    for label, key in report_lines:
        print(label, outcome[key])
