import statistics
import networkx as nx
import pandas as pd
import numpy as np
from scipy.stats import entropy, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity

from causal_profiler.scm import SCM
from causal_profiler.constants import VariableDataType


def get_endogenous_perents(scm: SCM, var_id: str):
    """Return the non-exogenous parents of ``var_id``.

    Parents are returned in the order in which they appear in ``scm.parents[var_id]``.

    Args:
        scm (SCM): the SCM holding the causal graph
        var_id (str): name of the child variable

    Returns:
        list: names of the endogenous parents of ``var_id``
    """
    return [
        parent_name
        for parent_name in scm.parents[var_id]
        if not scm.variables[parent_name].exogenous
    ]


def compute_marginal_distribution(samples: pd.DataFrame, var_id: str):
    """Return the empirical marginal distribution of ``var_id`` in ``samples``.

    Args:
        samples (pd.DataFrame): the samples used to estimate the marginal distribution
        var_id (str): the column (variable) of interest

    Returns:
        pd.DataFrame: one row per distinct value of ``var_id``. The column ``var_id`` holds
                      the value and the column ``"proba"`` holds its relative frequency.
    """
    value_counts = samples.groupby(var_id).size()
    frequencies = value_counts / value_counts.sum()
    return frequencies.reset_index(name="proba")


def get_variables_discrete_values(scm: SCM, var_id: str):
    """Compute the discrete domain of the variable var_id in the SCM scm.

    Endogenous variables take the integer values ``0 .. num_discrete_values - 1``.
    For exogenous variables, the boundaries stored in ``noise_regions`` split ``[0, 1]``
    into consecutive intervals ``[0, b_1], [b_1, b_2], ..., [b_k, 1]``. Each interval is
    represented by its midpoint. An empty ``noise_regions`` yields the single
    region ``[0, 1]``.

    Args:
        scm (SCM): the SCM of interest
        var_id (str): the name of the variable of interest

    Returns:
        list: the list of discrete values on which the variable is defined according to the SCM
    """
    var = scm.variables[var_id]
    if not var.exogenous:
        return list(range(var.num_discrete_values))
    # The lower bound of each interval is the previous boundary. The former
    # implementation kept it at 0, which produced b_i / 2 instead of the midpoint.
    boundaries = [0, *var.noise_regions, 1]
    return [(low + high) / 2 for low, high in zip(boundaries[:-1], boundaries[1:])]


def get_DAG_metrics(scm: SCM):
    """
    Computes several graph-based metrics on the SCM causal graph restricted to its endogenous
    (observed and unobserved) variables, i.e. the DAG obtained by dropping exogenous variables:
      - average and variance of degrees (undirected-degree, in-degree and out-degree)
      - average and variance number of ancestors
      - average and variance number of descendants
      - average and variance of causal path lengths and length of the longest causal path.
        The path length is the shortest directed path length from each variable to every
        variable it reaches.

    Variances are population variances and are 0 when fewer than two values are available.
    Averages are 0.0 when the DAG has no node (no endogenous variable).

    Args:
        scm (SCM): the SCM of interest

    Returns:
        dict: the metrics above. ``longest_causal_path_length`` is -1 if the graph has a cycle.
    """

    def _mean(values):
        # Guarded mean: statistics.mean raises on an empty list.
        return float(sum(values)) / len(values) if values else 0.0

    def _variance(values):
        # Exact population variance for the integer counts used below.
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    # 1) Build the SCM's DAG over the endogenous variables only
    G = nx.DiGraph()
    for var_id in scm.variables:
        if not scm.variables[var_id].exogenous:
            G.add_node(var_id)
    for parent_id, child_ids in scm.children.items():
        if scm.variables[parent_id].exogenous:
            continue
        for child_id in child_ids:
            if not scm.variables[child_id].exogenous:
                G.add_edge(parent_id, child_id)

    # 2) Degree analysis
    und_degrees = [deg for _, deg in G.degree()]
    in_degrees = [deg for _, deg in G.in_degree()]
    out_degrees = [deg for _, deg in G.out_degree()]

    # 3) Ancestors / descendants
    num_ancestors_list = [len(nx.ancestors(G, node)) for node in G.nodes()]
    num_descendants_list = [len(nx.descendants(G, node)) for node in G.nodes()]

    # 4) Causal path lengths. BFS from each node gives the shortest directed distance to
    #    every reachable node; the node itself (distance 0) is excluded.
    causal_path_lengths = [
        dist
        for node in G.nodes()
        for target, dist in nx.single_source_shortest_path_length(G, node).items()
        if target != node
    ]
    try:
        longest_causal_path_len = nx.dag_longest_path_length(G)
    except nx.NetworkXUnfeasible:
        # dag_longest_path_length() fails if the graph has cycles
        longest_causal_path_len = -1

    return {
        "avg_und_degrees": _mean(und_degrees),
        "variance_of_und_degrees": _variance(und_degrees),
        "avg_in_degrees": _mean(in_degrees),
        "variance_of_in_degrees": _variance(in_degrees),
        "avg_out_degrees": _mean(out_degrees),
        "variance_of_out_degrees": _variance(out_degrees),
        "avg_num_ancestors": _mean(num_ancestors_list),
        "var_num_ancestors": _variance(num_ancestors_list),
        "avg_num_descendants": _mean(num_descendants_list),
        "var_num_descendants": _variance(num_descendants_list),
        "longest_causal_path_length": longest_causal_path_len,
        "avg_causal_path_len": _mean(causal_path_lengths),
        "var_causal_path_len": _variance(causal_path_lengths),
    }


def get_ADMG_metrics(scm):
    """
    Computes several graph-based metrics on the input SCM causal graph projected on the observed variables only (ADMG):
      - number of maximal confounded components (c-components). These are the connected components of the
        bidirected part of the ADMG; only components with at least two variables are counted.
      - average, variance and max size of those c-components
      - average and variance number of siblings (two variables are siblings if they are linked by a bidirected edge)

    A pair of observed variables listed as adjacent in both directions by
    ``scm.extract_adjacency_list()`` is interpreted as a bidirected edge (a shared hidden confounder).

    Args:
        scm (SCM): the SCM of interest

    Returns:
        dict: the metrics above. All metrics are 0 when there are fewer than two observed
        variables or no bidirected edge. Variances are population variances.
    """
    no_confounding_metrics = {
        "num_c_components": 0,
        "largest_c_comp_size": 0,
        "avg_c_comp_size": 0,
        "var_c_comp_size": 0,
        "avg_num_siblings": 0,
        "var_num_siblings": 0,
    }

    def _mean(values):
        return float(sum(values)) / max(len(values), 1)

    def _variance(values):
        return 0 if len(values) < 2 else float(statistics.pvariance(values))

    adjacency_dict, _ = scm.extract_adjacency_list()
    if len(adjacency_dict) < 2:
        return no_confounding_metrics

    # 1) Recover the bidirected edges. An edge a -> b that is later matched by b -> a
    #    becomes a <-> b. Directed edges are tracked only to detect such reversed pairs.
    #    Building the bidirected graph directly avoids inventing confounder node names,
    #    which could clash with observed variable names.
    observed_vars = list(adjacency_dict)
    directed_edges = set()
    bidirected = nx.Graph()
    for parent_id, child_ids in adjacency_dict.items():
        for child_id in child_ids:
            if (child_id, parent_id) in directed_edges:
                directed_edges.remove((child_id, parent_id))
                if child_id != parent_id:
                    bidirected.add_edge(child_id, parent_id)
            else:
                directed_edges.add((parent_id, child_id))

    if bidirected.number_of_edges() < 1:
        return no_confounding_metrics

    # 2) Maximal c-components: variables connected by a path of bidirected edges, i.e. the
    #    connected components of the bidirected graph (maximal cliques would split chains
    #    such as A <-> B <-> C). Variables without bidirected edges are not in `bidirected`.
    ccomps_sizes_list = [len(ccomp) for ccomp in nx.connected_components(bidirected)]

    # 3) Siblings: number of distinct observed variables sharing a bidirected edge with each variable
    sibling_counts = [
        bidirected.degree(node) if node in bidirected else 0 for node in observed_vars
    ]

    return {
        "num_c_components": len(ccomps_sizes_list),
        "largest_c_comp_size": max(ccomps_sizes_list),
        "avg_c_comp_size": _mean(ccomps_sizes_list),
        "var_c_comp_size": _variance(ccomps_sizes_list),
        "avg_num_siblings": _mean(sibling_counts),
        "var_num_siblings": _variance(sibling_counts),
    }


class SCMTypeError(ValueError, AssertionError):
    """Raised when an SCM is not fully discrete or fully continuous.

    It also derives from AssertionError, so callers written against the former
    ``assert``-based checks keep catching it.
    """


def get_scm_type(scm: SCM):
    """
    Test whether the SCM under study is fully discrete or fully continuous (i.e., mixed SCMs are not handled).

    Only endogenous variables are inspected; exogenous variables are ignored.

    Args:
        scm (SCM): the SCM of interest

    Returns:
        str: "continuous" if the endogenous variables are continuous, "discrete" otherwise
        (including when the SCM has no endogenous variable).

    Raises:
        SCMTypeError: if an endogenous variable is neither continuous nor discrete, or if the SCM
            mixes continuous and discrete endogenous variables. Explicit raises, unlike ``assert``,
            stay active under ``python -O``.
    """
    nb_continuous_vars = 0
    nb_discrete_vars = 0
    nb_non_typed_vars = 0
    for var_id in scm.variables:
        variable = scm.variables[var_id]
        if variable.exogenous:
            continue
        if variable.variable_type == VariableDataType.CONTINUOUS:
            nb_continuous_vars += 1
        elif variable.variable_type == VariableDataType.DISCRETE:
            nb_discrete_vars += 1
        else:
            nb_non_typed_vars += 1

    if nb_non_typed_vars != 0:
        raise SCMTypeError(
            f"The SCM contains {nb_non_typed_vars} non continuous or discrete variables"
        )
    if nb_continuous_vars > 0 and nb_discrete_vars > 0:
        raise SCMTypeError("The SCM contains continuous and discrete variables")

    return "continuous" if nb_continuous_vars > 0 else "discrete"


def get_distrib_metrics(
    scm: SCM,
    n_samples_for_computation: int = 1000000,
    sampled_data_df: pd.DataFrame = None,
):
    """
    Computes several metrics on the observational distribution of the SCM:
      - level of strong positivity: the min value of the joint probability over the cartesian product of the
        observed values, and the proportion of that cartesian product with a non-null joint probability
      - average and variance level of weak positivity (i.e., the min value of the marginal probability of each variable)
      - average and variance level of marginal imbalanced distribution (i.e., L1 distance to uniform distribution)
      - level of joint imbalanced distribution (i.e., L1 distance to uniform distribution)
      - average and variance of the conditional entropy H(X | Pa(X)) of each endogenous variable given its
        endogenous parents (in nats)
      - entropy of the joint distribution (in nats)

    For continuous SCMs, densities are estimated with a Gaussian KDE. Entropies are resubstitution
    estimates of the differential entropy (-mean log-density over the samples). Imbalance measures
    compare the normalized densities at the samples to uniform weights.

    Args:
        scm (SCM): the SCM of interest (must be fully discrete or fully continuous).
        n_samples_for_computation (int): number of samples drawn from the SCM when sampled_data_df is None.
        sampled_data_df (pd.DataFrame, optional): pre-sampled data, indexed below by endogenous variable name.

    Returns a dictionary with these metrics.
    """

    scm_type = get_scm_type(scm)

    def _variance(values):
        # Population variance E[x^2] - E[x]^2; 0 when fewer than two values are available.
        if len(values) < 2:
            return 0
        return float(
            statistics.mean([v**2 for v in values]) - statistics.mean(values) ** 2
        )

    def _kde_log_density(frame, bandwidth="silverman"):
        # KernelDensity.score_samples returns LOG-densities of each row of `frame`.
        kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(frame)
        return kde.score_samples(frame)

    def _imbalance_from_log_density(log_density):
        # Normalize the densities at the samples into weights summing to 1. Shifting by the max
        # log-density avoids underflow in exp. Return the mean L1 deviation from uniform weights;
        # a uniform distribution has constant density and therefore scores 0.
        weights = np.exp(log_density - np.max(log_density))
        weights = weights / weights.sum()
        return np.mean(np.abs(weights - 1 / weights.shape[0]))

    def _conditional_entropy(df, var_id, parents):
        if scm_type == "continuous":
            # h(X | Pa) = h(X, Pa) - h(Pa), each estimated as -mean(log KDE density)
            table = df[[var_id] + parents]
            log_p_joint = _kde_log_density(table)
            log_p_parents = _kde_log_density(table[parents])
            return float(np.mean(log_p_parents - log_p_joint))
        # H(X | Pa) = -sum_{x, pa} p(x, pa) * log p(x | pa), from empirical frequencies
        joint_counts = (
            df.groupby([var_id] + parents).size().reset_index(name="joint_count")
        )
        parent_counts = df.groupby(parents).size().reset_index(name="parent_count")
        table = joint_counts.merge(parent_counts, on=parents, how="left")
        joint = table["joint_count"].to_numpy(dtype=float)
        p_joint = joint / joint.sum()
        p_cond = joint / table["parent_count"].to_numpy(dtype=float)
        return float(-np.sum(p_joint * np.log(p_cond)))

    # 1) Sample observational data (flattened, endogenous columns only)
    if sampled_data_df is None:
        sampled_data = scm.sample_data(n_samples_for_computation)
        endo_vars_list = []
        for var_id in sampled_data.keys():
            sampled_data[var_id] = sampled_data[var_id].flatten()
            if not scm.variables[var_id].exogenous:
                endo_vars_list.append(var_id)
        sampled_data_df = pd.DataFrame.from_dict(sampled_data)[endo_vars_list]
    else:
        endo_vars_list = [
            var_id for var_id in scm.variables if not scm.variables[var_id].exogenous
        ]

    # 2) Conditional stochasticity metric
    entropy_list = []
    for var_id in scm.variables:
        if scm.variables[var_id].exogenous:
            continue
        endo_parents_list = get_endogenous_perents(scm, var_id)
        if len(endo_parents_list) > 0:
            entropy_list.append(
                _conditional_entropy(sampled_data_df, var_id, endo_parents_list)
            )

    marginal_imbalanced_list = []
    if scm_type == "continuous":
        # For continuous SCMs we consider strong and weak positivity assumptions impossible to check with finite data.
        # Indeed it would require additional assumptions notably about the support and the level of smoothness of the distribution.
        # As a result, for continuous SCMs we consider strong and weak positivity always violated.
        # NOTE(review): in the discrete branch strong_pos_measure_prop is the proportion of NON-null cells,
        # so the value 1 below reads as "fully positive" rather than "violated" -- confirm the intended value.
        strong_pos_measure_min = 0
        strong_pos_measure_prop = 1
        avg_weak_pos = 0
        var_weak_pos = 0

        # 3) Joint imbalanced distribution and joint (differential) entropy
        log_p_joint = _kde_log_density(sampled_data_df[endo_vars_list])
        joint_imbalanced_measure = _imbalance_from_log_density(log_p_joint)
        total_entropy = float(-np.mean(log_p_joint))

        # 4) Marginal imbalanced distribution
        # NOTE(review): marginals use a fixed bandwidth of 1 while the other KDEs use "silverman".
        for var_id in endo_vars_list:
            marginal_frame = pd.DataFrame({var_id: sampled_data_df[var_id].values})
            marginal_imbalanced_list.append(
                _imbalance_from_log_density(
                    _kde_log_density(marginal_frame, bandwidth=1)
                )
            )

    else:
        # 3) Empirical joint probability table over the observed value combinations
        joint_proba_table = (
            sampled_data_df.groupby(endo_vars_list)
            .size()
            .reset_index(name="joint_prob")
        )
        joint_proba_table["joint_prob"] = (
            joint_proba_table["joint_prob"] / joint_proba_table["joint_prob"].sum()
        )

        # Marginal distributions (weak positivity, marginal imbalance), plus the cartesian
        # product of the observed values of every endogenous variable
        weak_pos_list = []
        cartesian_prod_table = None
        for var_id in endo_vars_list:
            marginal_distrib = compute_marginal_distribution(sampled_data_df, var_id)
            weak_pos_list.append(marginal_distrib["proba"].min())
            marginal_imbalanced_list.append(
                np.mean(
                    np.abs(
                        marginal_distrib["proba"].values - 1 / marginal_distrib.shape[0]
                    )
                )
            )
            values_table = marginal_distrib.drop(columns=["proba"])
            cartesian_prod_table = (
                values_table
                if cartesian_prod_table is None
                else cartesian_prod_table.merge(values_table, how="cross")
            )
        # Value combinations never observed get a null probability
        full_joint_proba_table = cartesian_prod_table.merge(
            joint_proba_table, how="left"
        ).fillna(0)

        avg_weak_pos = statistics.mean(weak_pos_list)
        var_weak_pos = _variance(weak_pos_list)
        total_entropy = entropy(joint_proba_table["joint_prob"])

        # 4) Strong positivity and joint imbalanced distribution
        n_cells = full_joint_proba_table.shape[0]
        n_null_cells = int((full_joint_proba_table["joint_prob"] == 0).sum())
        strong_pos_measure_min = full_joint_proba_table["joint_prob"].min()
        strong_pos_measure_prop = 1 - n_null_cells / n_cells
        joint_imbalanced_measure = np.mean(
            np.abs(full_joint_proba_table["joint_prob"].values - 1 / n_cells)
        )

    return {
        "strong_pos_measure_min": strong_pos_measure_min,
        "strong_pos_measure_prop": strong_pos_measure_prop,
        "avg_weak_pos": avg_weak_pos,
        "var_weak_pos": var_weak_pos,
        "avg_marginal_imbalanced": statistics.mean(marginal_imbalanced_list),
        "var_marginal_imbalanced": _variance(marginal_imbalanced_list),
        "joint_imbalanced_measure": joint_imbalanced_measure,
        "avg_conditional_entropy": (
            statistics.mean(entropy_list) if len(entropy_list) > 0 else np.nan
        ),
        "var_conditional_entropy": _variance(entropy_list),
        "total_entropy": total_entropy,
    }


def get_mechanism_metrics(scm: SCM, nb_samples_for_discretization: int = 10):
    """
    Computes several metrics on the causal mechanisms of the SCM:
      - average and variance level of stochasticity (i.e., the conditional entropy H(X | Pa(X)), in nats,
        of a child given its endogenous parents)
      - average and variance level of linearity (i.e., pearson correlation between each parent-child pair
        including endogenous and exogenous variables)
      - average and variance level of monotonicity (i.e., spearman correlation between each parent-child pair
        including endogenous and exogenous variables)

    Each mechanism is evaluated on the cartesian product of candidate values of all its parents.
    This makes the metrics independent of the parents' observational distribution. Candidate
    values are:
      - for discrete SCMs, the domain returned by get_variables_discrete_values;
      - for continuous SCMs, nb_samples_for_discretization values sampled from the SCM.
    Endogenous variables without parents are skipped.

    Side effect: the ``value`` attribute of each parent variable is overwritten with the evaluation grid.

    Args:
        scm (SCM): the SCM of interest (must be fully discrete or fully continuous).
        nb_samples_for_discretization (int): number of sampled values per parent for continuous SCMs.

    Returns a dictionary with these metrics. Averages are NaN when no value is available.
    """

    scm_type = get_scm_type(scm)

    # If continuous, sample realizations of the SCM to define the discretized domains
    var_distinct_values_dict = (
        scm.sample_data(total_samples=nb_samples_for_discretization)
        if scm_type == "continuous"
        else None
    )

    def _candidate_values(pa_id):
        if scm_type == "continuous":
            # reshape(-1) (unlike squeeze) keeps a 1-D array even for a single sample
            return np.asarray(var_distinct_values_dict[pa_id]).reshape(-1)
        return get_variables_discrete_values(scm, pa_id)

    def _mean_or_nan(values):
        return statistics.mean(values) if len(values) > 0 else np.nan

    def _variance(values):
        # Population variance E[x^2] - E[x]^2; 0 when fewer than two values are available.
        if len(values) < 2:
            return 0
        return float(
            statistics.mean([v**2 for v in values]) - statistics.mean(values) ** 2
        )

    def _kde_log_density(frame):
        # KernelDensity.score_samples returns LOG-densities of each row of `frame`.
        kde = KernelDensity(kernel="gaussian", bandwidth="silverman").fit(frame)
        return kde.score_samples(frame)

    def _conditional_entropy(df, var_id, parents):
        if scm_type == "continuous":
            # h(X | Pa) = h(X, Pa) - h(Pa), each estimated as -mean(log KDE density)
            table = df[[var_id] + parents]
            log_p_joint = _kde_log_density(table)
            log_p_parents = _kde_log_density(table[parents])
            return float(np.mean(log_p_parents - log_p_joint))
        # H(X | Pa) = -sum_{x, pa} p(x, pa) * log p(x | pa), from frequencies over the grid
        joint_counts = (
            df.groupby([var_id] + parents).size().reset_index(name="joint_count")
        )
        parent_counts = df.groupby(parents).size().reset_index(name="parent_count")
        table = joint_counts.merge(parent_counts, on=parents, how="left")
        joint = table["joint_count"].to_numpy(dtype=float)
        p_joint = joint / joint.sum()
        p_cond = joint / table["parent_count"].to_numpy(dtype=float)
        return float(-np.sum(p_joint * np.log(p_cond)))

    entropy_list = []
    linearity_list = []
    monotonicity_list = []
    batchsize = 100  # the cartesian product can be very big, so mechanisms run by batches
    for var_id in scm.variables:
        if scm.variables[var_id].exogenous:
            continue
        parents = scm.parents[var_id]
        if len(parents) == 0:
            # Nothing to vary: the mechanism cannot be profiled
            continue

        # 1) Cartesian product of the candidate values of all parents
        cartesian_prod_df = None
        for pa_id in parents:
            pa_values_df = pd.DataFrame({pa_id: _candidate_values(pa_id)})
            cartesian_prod_df = (
                pa_values_df
                if cartesian_prod_df is None
                else cartesian_prod_df.merge(pa_values_df, how="cross")
            )

        # 2) Apply the mechanism batch by batch
        mechanism = scm.mechanisms.get(var_id)
        var_id_values = []
        for start in range(0, cartesian_prod_df.shape[0], batchsize):
            batch = cartesian_prod_df.iloc[start : start + batchsize].copy()
            for pa_id in parents:
                scm.variables[pa_id].value = np.expand_dims(
                    batch[pa_id].values, axis=1
                )
            var_id_values.append(mechanism())
        # Flatten so (n, 1) mechanism outputs become a proper 1-D column
        cartesian_prod_df[var_id] = np.concatenate(var_id_values).reshape(-1)

        # 3) Stochasticity metric
        endo_parents_list = get_endogenous_perents(scm, var_id)
        if len(endo_parents_list) > 0:
            entropy_list.append(
                _conditional_entropy(cartesian_prod_df, var_id, endo_parents_list)
            )

        # 4) Linearity and monotonicity metrics
        for pa_id in parents:
            linearity_list.append(
                pearsonr(cartesian_prod_df[var_id], cartesian_prod_df[pa_id]).statistic
            )
            monotonicity_list.append(
                spearmanr(cartesian_prod_df[var_id], cartesian_prod_df[pa_id]).statistic
            )

    return {
        "avg_mechanisms_entropy": _mean_or_nan(entropy_list),
        "var_mechanisms_entropy": _variance(entropy_list),
        "avg_linearity": _mean_or_nan(linearity_list),
        "var_linearity": _variance(linearity_list),
        "avg_monotonicity": _mean_or_nan(monotonicity_list),
        "var_monotonicity": _variance(monotonicity_list),
    }
