import statistics
import networkx as nx
import pandas as pd
import numpy as np
from scipy.stats import entropy, pearsonr, spearmanr

from causal_profiler.scm import SCM


def get_endogenous_perents(scm: SCM, var_id: str):
    """Return the parents of ``var_id`` in ``scm`` that are not exogenous.

    Args:
        scm (SCM): the SCM of interest
        var_id (str): the name of the variable of interest

    Returns:
        list: the names of the endogenous parents, in the order of ``scm.parents[var_id]``
    """
    return [
        parent_id
        for parent_id in scm.parents[var_id]
        if not scm.variables[parent_id].exogenous
    ]


def compute_marginal_distribution(samples: pd.DataFrame, var_id: str):
    """Estimate the empirical marginal distribution of one column of ``samples``.

    Args:
        samples (pd.DataFrame): the samples used for the estimation
        var_id (str): the name of the column (variable) of interest

    Returns:
        pd.DataFrame: one row per distinct observed value of ``var_id`` (in ``groupby``
                      order). Column ``var_id`` holds the value and column 'proba' holds
                      its relative frequency among the samples.
    """
    value_counts = samples.groupby(var_id).size()
    frequencies = value_counts / value_counts.sum()
    return frequencies.reset_index(name="proba")


def get_variables_discrete_values(scm: SCM, var_id: str):
    """Compute the values on which the variable var_id is evaluated in the SCM scm.

    For an endogenous variable, this is its discrete domain ``0 .. num_discrete_values - 1``.
    For an exogenous variable, ``noise_regions`` is read as the ordered cut points splitting
    ``[0, 1]`` into consecutive regions (lower bound 0, upper bound 1). One representative
    value per region is returned: its midpoint. With no cut points, the single region
    ``[0, 1]`` yields ``[0.5]``.

    Args:
        scm (SCM): the SCM of interest
        var_id (str): the name of the variable of interest

    Returns:
        list: the list of discrete values on which the variable is defined according to the SCM
    """
    var = scm.variables[var_id]
    if not var.exogenous:
        return [i for i in range(var.num_discrete_values)]
    # Region boundaries: 0, then every cut point, then 1. Each consecutive
    # (low, high) pair delimits one region. Previously the lower bound stayed
    # stuck at 0, so every region but the last got a wrong midpoint.
    boundaries = [0, *var.noise_regions, 1]
    return [(low + high) / 2 for low, high in zip(boundaries[:-1], boundaries[1:])]


def get_DAG_metrics(scm: SCM):
    """
    Computes several graph-based metrics on the input SCM causal graph, restricted to the
    endogenous variables (observed and unobserved) and the edges between them:
      - average and variance of degrees (undirected-degree, in-degree and out-degree)
      - average and variance number of ancestors
      - average and variance number of descendants
      - average, variance and max causal path length. The average and variance are taken
        over the shortest directed path length of every (node, descendant) pair. The max
        is the longest directed path of the DAG (-1 if the graph contains a cycle).
      - average, variance and max open path length (TODO)

    Variances are population variances. Averages and variances are 0 when there is
    nothing to aggregate, and variances are also 0 for a single value.

    Returns a dictionary with these metrics.
    """

    def _mean(values):
        # Arithmetic mean; 0.0 for an empty list (e.g. an SCM without endogenous variables)
        return float(sum(values)) / len(values) if values else 0.0

    def _variance(values):
        # Exact population variance (avoids E[x^2] - E[x]^2 cancellation); 0 below 2 values
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    # 1) Build the SCM's DAG over the endogenous variables
    G = nx.DiGraph()
    G.add_nodes_from(
        var_id for var_id in scm.variables if not scm.variables[var_id].exogenous
    )
    for parent_id, child_ids in scm.children.items():
        if scm.variables[parent_id].exogenous:
            continue
        for child_id in child_ids:
            if not scm.variables[child_id].exogenous:
                G.add_edge(parent_id, child_id)

    # 2) Degree analysis
    und_degrees = [degree for _, degree in G.degree()]
    in_degrees = [degree for _, degree in G.in_degree()]
    out_degrees = [degree for _, degree in G.out_degree()]

    # 3) Ancestors / Descendants
    num_ancestors_list = [len(nx.ancestors(G, node)) for node in G.nodes()]
    num_descendants_list = [len(nx.descendants(G, node)) for node in G.nodes()]

    # 4) Causal path length: BFS shortest path from each node to each of its descendants
    causal_path_lengths = [
        dist
        for node in G.nodes()
        for target, dist in nx.single_source_shortest_path_length(G, node).items()
        if target != node  # exclude the node itself (distance 0)
    ]
    try:
        longest_causal_path_len = nx.dag_longest_path_length(G)
    except nx.NetworkXUnfeasible:
        # If the graph has cycles, dag_longest_path_length() will fail
        longest_causal_path_len = -1

    # 5) Open path length TODO
    # list the colliders, for each collider and for each incoming arrow to this collider, hide this arrow and compute the open paths
    # store all these paths together AND remove duplicated paths
    # compute max, avg, var

    return {
        "avg_und_degrees": _mean(und_degrees),
        "variance_of_und_degrees": _variance(und_degrees),
        "avg_in_degrees": _mean(in_degrees),
        "variance_of_in_degrees": _variance(in_degrees),
        "avg_out_degrees": _mean(out_degrees),
        "variance_of_out_degrees": _variance(out_degrees),
        "avg_num_ancestors": _mean(num_ancestors_list),
        "var_num_ancestors": _variance(num_ancestors_list),
        "avg_num_descendants": _mean(num_descendants_list),
        "var_num_descendants": _variance(num_descendants_list),
        "longest_causal_path_length": longest_causal_path_len,
        "avg_causal_path_len": _mean(causal_path_lengths),
        "var_causal_path_len": _variance(causal_path_lengths),
    }


def get_ADMG_metrics(scm):
    """
    Computes several graph-based metrics on the input SCM causal graph projected on the
    observed variables only (ADMG):
      - number of maximal confounded components (c-components). These are the connected
        components of the bidirected part of the ADMG, i.e. maximal sets of variables
        linked by bidirected paths. Only non-trivial c-components (at least one
        bidirected edge, hence size >= 2) are counted.
      - average, variance and max size of these c-components
      - average and variance number of siblings over all observed variables (two
        variables are siblings if they are linked by a bidirected edge)

    ``scm.extract_adjacency_list()`` must return ``(adjacency_dict, _)``. Presumably
    ``adjacency_dict`` maps each observed variable to the variables it points to. Two
    variables pointing to each other are read as one bidirected edge (a hidden confounder).
    All metrics are 0 when there are fewer than two observed variables or no bidirected edge.
    Variances are population variances.

    Returns a dictionary with these metrics.
    """
    zero_metrics = {
        "num_c_components": 0,
        "largest_c_comp_size": 0,
        "avg_c_comp_size": 0,
        "var_c_comp_size": 0,
        "avg_num_siblings": 0,
        "var_num_siblings": 0,
    }
    adjacency_dict, _ = scm.extract_adjacency_list()
    if len(adjacency_dict) < 2:
        return zero_metrics

    # 1) Extract the bidirected edges. A directed edge already seen in the opposite
    # direction is consumed and turned into a bidirected edge. Using a set of tuples
    # avoids injecting synthetic confounder nodes whose names could collide with
    # observed variables.
    directed_edges = set()
    bidirected_graph = nx.Graph()
    for parent_id, child_ids in adjacency_dict.items():
        for child_id in child_ids:
            if (child_id, parent_id) in directed_edges:
                directed_edges.remove((child_id, parent_id))
                if child_id != parent_id:  # a self-confounded variable has no sibling
                    bidirected_graph.add_edge(parent_id, child_id)
            else:
                directed_edges.add((parent_id, child_id))
    if bidirected_graph.number_of_edges() < 1:
        return zero_metrics

    # 2) Maximal c-components = connected components of the bidirected graph
    # (maximal cliques would split e.g. A<->B<->C into {A,B} and {B,C})
    ccomps_sizes_list = [
        len(ccomp) for ccomp in nx.connected_components(bidirected_graph)
    ]

    # 3) Siblings: number of variables sharing a bidirected edge with each variable
    sibling_counts = [
        bidirected_graph.degree(node) if node in bidirected_graph else 0
        for node in adjacency_dict
    ]

    def _variance(values):
        # Exact population variance (avoids E[x^2] - E[x]^2 cancellation); 0 below 2 values
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    return {
        "num_c_components": len(ccomps_sizes_list),
        "largest_c_comp_size": max(ccomps_sizes_list),
        "avg_c_comp_size": float(sum(ccomps_sizes_list)) / len(ccomps_sizes_list),
        "var_c_comp_size": _variance(ccomps_sizes_list),
        "avg_num_siblings": float(sum(sibling_counts)) / len(sibling_counts),
        "var_num_siblings": _variance(sibling_counts),
    }


def get_distrib_metrics(scm: SCM, n_samples_for_computation: int = 1000000):
    """
    Computes several metrics on the observational distribution of the SCM. The metrics are
    estimated from ``n_samples_for_computation`` samples of ``scm.sample_data`` and
    restricted to the endogenous variables:
      - level of strong positivity: the min value of the joint probability over the
        cartesian product of the observed values of each variable
        ("strong_pos_measure_min"), and the proportion of events of that product with a
        non-null probability ("strong_pos_measure_prop")
      - average and variance level of weak positivity (i.e., the min value of the
        marginal probability of each variable)
      - average and variance level of marginal imbalanced distribution (i.e., mean
        absolute deviation of each marginal from the uniform distribution, which is the
        L1 distance to uniform divided by the number of values)
      - level of joint imbalanced distribution (same measure on the joint distribution
        over the cartesian product above)
      - average and variance of the conditional entropy H(X | endogenous parents of X)
        of each endogenous variable X
      - total entropy of the joint distribution

    Entropies use the natural logarithm. Variances are population variances (0 when
    fewer than two values are available).

    Returns a dictionary with these metrics.

    Raises:
        ValueError: if the samples contain no endogenous variable.
    """

    def _variance(values):
        # Exact population variance (avoids E[x^2] - E[x]^2 cancellation); 0 below 2 values
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    def _entropy_of(df, columns):
        # Entropy of the empirical joint distribution of `columns` (0 for no column).
        # scipy's entropy normalizes the counts itself.
        if not columns:
            return 0.0
        return float(entropy(df.groupby(columns).size().values))

    # 1) Sample observational data, keeping the (flattened) endogenous variables only.
    # The dict returned by sample_data is left untouched.
    sampled_data = scm.sample_data(n_samples_for_computation)
    endo_vars_list = [
        var_id for var_id in sampled_data if not scm.variables[var_id].exogenous
    ]
    if not endo_vars_list:
        raise ValueError("get_distrib_metrics requires at least one endogenous variable")
    sampled_data_df = pd.DataFrame(
        {var_id: sampled_data[var_id].flatten() for var_id in endo_vars_list}
    )

    # 2) Conditional entropy H(X | Pa) = H(X, Pa) - H(Pa) of each endogenous variable
    # given its endogenous parents (clipped at 0 against floating-point round-off)
    entropy_list = []
    for var_id in scm.variables:
        if scm.variables[var_id].exogenous:
            continue
        endo_parents_list = get_endogenous_perents(scm, var_id)
        entropy_list.append(
            max(
                0.0,
                _entropy_of(sampled_data_df, [var_id] + endo_parents_list)
                - _entropy_of(sampled_data_df, endo_parents_list),
            )
        )

    # 3) Marginal distributions, plus the cartesian product of the observed values.
    # NOTE(review): only values present in the samples are considered, so a domain
    # value that was never sampled does not lower the weak/strong positivity measures.
    weak_pos_list = []
    marginal_imbalanced_list = []
    cartesian_prod_table = None
    for var_id in endo_vars_list:
        marginal_distrib = compute_marginal_distribution(sampled_data_df, var_id)
        marginal_probas = marginal_distrib["proba"].values
        weak_pos_list.append(marginal_probas.min())
        marginal_imbalanced_list.append(
            np.mean(np.abs(marginal_probas - 1 / len(marginal_probas)))
        )
        observed_values = marginal_distrib.drop(columns=["proba"])
        if cartesian_prod_table is None:
            cartesian_prod_table = observed_values
        else:
            cartesian_prod_table = cartesian_prod_table.merge(
                observed_values, how="cross"
            )

    joint_proba_table = (
        sampled_data_df.groupby(endo_vars_list).size().reset_index(name="joint_prob")
    )
    joint_proba_table["joint_prob"] = (
        joint_proba_table["joint_prob"] / joint_proba_table["joint_prob"].sum()
    )
    # Events of the cartesian product never sampled get a null probability
    full_joint_proba_table = cartesian_prod_table.merge(
        joint_proba_table, how="left"
    ).fillna(0)

    # 4) Strong positivity and joint imbalanced distribution
    full_joint_probas = full_joint_proba_table["joint_prob"].values
    n_events = len(full_joint_probas)
    strong_pos_measure_min = full_joint_probas.min()
    strong_pos_measure_prop = 1 - np.count_nonzero(full_joint_probas == 0) / n_events
    joint_imbalanced_measure = np.mean(np.abs(full_joint_probas - 1 / n_events))

    return {
        "strong_pos_measure_min": strong_pos_measure_min,
        "strong_pos_measure_prop": strong_pos_measure_prop,
        "joint_imbalanced_measure": joint_imbalanced_measure,
        "avg_weak_pos": statistics.mean(weak_pos_list),
        "var_weak_pos": _variance(weak_pos_list),
        "avg_marginal_imbalanced": statistics.mean(marginal_imbalanced_list),
        "var_marginal_imbalanced": _variance(marginal_imbalanced_list),
        "avg_conditional_entropy": statistics.mean(entropy_list),
        "var_conditional_entropy": _variance(entropy_list),
        "total_entropy": entropy(joint_proba_table["joint_prob"]),
    }


def get_mechanism_metrics(scm: SCM):
    """
    Computes several metrics on the causal mechanisms of the SCM. Every endogenous
    variable with at least one parent has its mechanism evaluated on the cartesian
    product of its parents' values, as given by ``get_variables_discrete_values`` (the
    discrete domain for endogenous parents, one representative value per noise region
    for exogenous parents). Each grid point counts once:
      - average and variance level of linearity (i.e., pearson correlation between each
        parent-child pair including endogenous and exogenous variables)
      - average and variance level of monotonicity (i.e., spearman correlation between
        each parent-child pair including endogenous and exogenous variables)
      - average and variance level of injectivity w.r.t. the exogenous variables (TODO)
      - average and variance level of surjectivity w.r.t. the exogenous variables (TODO)
      - average and variance level of stochasticity (i.e., the conditional entropy
        H(child | endogenous parents) over that grid, natural log)

    The parents' ``value`` attributes are temporarily overwritten with the grid to
    evaluate each mechanism, and restored afterwards. Variances are population variances.

    Returns a dictionary with these metrics.
    """

    def _variance(values):
        # Exact population variance (avoids E[x^2] - E[x]^2 cancellation); 0 below 2 values
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    def _entropy_of(df, columns):
        # Entropy of the empirical joint distribution of `columns` (0 for no column).
        # scipy's entropy normalizes the counts itself.
        if not columns:
            return 0.0
        return float(entropy(df.groupby(columns).size().values))

    entropy_list = []
    linearity_list = []
    monotonicity_list = []
    for var_id in scm.variables:
        if scm.variables[var_id].exogenous:
            continue
        parents_list = list(scm.parents[var_id])
        if not parents_list:
            # No parent: there is no input grid to evaluate the mechanism on
            continue

        # 1) Cartesian product of the parents' ranges
        cartesian_prod_df = pd.DataFrame(
            {parents_list[0]: get_variables_discrete_values(scm, parents_list[0])}
        )
        for pa_id in parents_list[1:]:
            cartesian_prod_df = cartesian_prod_df.merge(
                pd.DataFrame({pa_id: get_variables_discrete_values(scm, pa_id)}),
                how="cross",
            )

        # 2) Apply the mechanism on the grid. The parents' values are set as column
        # vectors for the mechanism to read, then restored so the SCM state is unchanged.
        saved_values = {
            pa_id: getattr(scm.variables[pa_id], "value", None) for pa_id in parents_list
        }
        try:
            for pa_id in parents_list:
                scm.variables[pa_id].value = np.expand_dims(
                    cartesian_prod_df[pa_id].values, axis=1
                )
            mechanism = scm.mechanisms.get(var_id)
            cartesian_prod_df[var_id] = mechanism()
        finally:
            for pa_id, saved_value in saved_values.items():
                scm.variables[pa_id].value = saved_value

        # 3) Stochasticity: H(X | endogenous Pa) = H(X, Pa) - H(Pa)
        # (clipped at 0 against floating-point round-off)
        endo_parents_list = get_endogenous_perents(scm, var_id)
        entropy_list.append(
            max(
                0.0,
                _entropy_of(cartesian_prod_df, [var_id] + endo_parents_list)
                - _entropy_of(cartesian_prod_df, endo_parents_list),
            )
        )

        # 4) Linearity and monotonicity for every parent-child pair
        # NOTE(review): the correlation is undefined (nan) when the child or the parent is
        # constant over the grid, and a single nan propagates to the averages below.
        for pa_id in parents_list:
            linearity_list.append(
                pearsonr(cartesian_prod_df[var_id], cartesian_prod_df[pa_id]).statistic
            )
            monotonicity_list.append(
                spearmanr(cartesian_prod_df[var_id], cartesian_prod_df[pa_id]).statistic
            )

        # 5) Compute surjectivity measure TODO

        # 6) Compute injectivity measure TODO

    return {
        "avg_mechanisms_entropy": statistics.mean(entropy_list),
        "var_mechanisms_entropy": _variance(entropy_list),
        "avg_linearity": statistics.mean(linearity_list),
        "var_linearity": _variance(linearity_list),
        "avg_monotonicity": statistics.mean(monotonicity_list),
        "var_monotonicity": _variance(monotonicity_list),
    }
