# This file provides code that implements various topological ordering baselines for ANMs.
# Methods are imported from the libraries below, then a custom function is written for each method, which outputs the topological ordering (cause comes before effect).
# The input for each function is a numpy array of data (rows are observations, columns are variables), and the output is a list of indices in topological order.
# (Some pairwise helpers instead return a dict that carries a "correct" flag.)
# Note that the implementation is not method-specific per se - each method is emblematic of a general approach/class of methods.

# Import libraries
import lingam as lingam
import dodiscover
from CausalDisco.analytics import r2coeff
from sklearn.ensemble import RandomForestRegressor
import networkx as nx
import numpy as np
import pandas as pd

from scipy.spatial import KDTree
from scipy.special import digamma

from plot_training_curve import additional_information
from pnl import PNL


def knn_entropy_estimator(samples, k=100, base=2):
    """
    Estimates the differential entropy of a continuous distribution using the
    Kozachenko-Leonenko k-nearest-neighbors estimator.

    Parameters:
        samples (array-like): Empirical samples. A 1-D input is treated as n
            scalar observations; an input of shape (n, d) is treated as n
            observations of a d-dimensional variable (one row per observation).
        k (int): Number of nearest neighbors to use.
        base (float): Logarithm base, default is 2 (for bits).

    Returns:
        float: Estimated entropy, expressed in the requested log base.

    Notes:
        The estimate is ``psi(n) - psi(k) + log(c_d) + d * mean(log(eps_i))``
        where ``eps_i`` is the Euclidean distance from observation i to its
        k-th nearest neighbor and ``c_d`` is the volume of the d-dimensional
        unit ball (``c_1 = 2``). Repeated observations can make ``eps_i`` zero,
        which drives the estimate to -inf. KDTree.query reports missing
        neighbors as infinite distances, so k >= n yields an infinite estimate.
    """
    from math import lgamma  # function-scope import keeps the block self-contained

    points = np.asarray(samples, dtype=float)
    if points.ndim <= 1:
        # A flat vector is n observations of a scalar variable.
        points = points.reshape(-1, 1)
    else:
        # Keep one row per observation; never flatten rows into extra points.
        points = points.reshape(len(points), -1)
    n, dim = points.shape

    tree = KDTree(points)  # Build KDTree for fast nearest-neighbor lookup
    # Ask for k + 1 neighbors: every query point is its own neighbor at distance 0.
    dists, _ = tree.query(points, k + 1)
    avg_log_dist = np.mean(np.log(dists[:, -1]))  # log distance to the k-th neighbor

    # log of the unit-ball volume pi^(d/2) / Gamma(d/2 + 1); equals log(2) when d == 1
    log_unit_ball = 0.5 * dim * np.log(np.pi) - lgamma(0.5 * dim + 1)

    # Kozachenko-Leonenko entropy estimator (in nats)
    entropy = digamma(n) - digamma(k) + log_unit_ball + dim * avg_log_dist
    return entropy / np.log(base)  # Convert to desired log base


def conditional_entropy_knn(A, B, k=100):
    """
    Estimate h(A | B) as the difference of two k-NN entropy estimates.

    Parameters:
        A, B: 1-D sample arrays of equal length (they are stacked as the two
            columns of one array).
        k (int): Number of nearest neighbors forwarded to knn_entropy_estimator.

    Returns:
        The estimate for the stacked (A, B) samples minus the estimate for B
        alone, both in knn_entropy_estimator's default log base.
    """
    # One row per observation, columns (A, B).
    stacked = np.vstack([A, B]).T
    joint_estimate = knn_entropy_estimator(stacked, k)

    # B on its own, shaped as a single column.
    b_column = B.reshape(-1, 1)
    marginal_estimate = knn_entropy_estimator(b_column, k)

    # Chain rule: h(A | B) = h(A, B) - h(B)
    return joint_estimate - marginal_estimate


def entropy_knn(data):
    """
    Compare the two conditional-entropy estimates of a two-column dataset.

    Parameters:
        data: 2-D array; column 0 is taken as X and column 1 as Y.

    Returns:
        dict with
            "XY": estimate of h(X | Y),
            "YX": estimate of h(Y | X),
            "entropy_knn": XY - YX,
            "correct": 1 if XY is strictly smaller than YX, otherwise 0.
    """
    col_x, col_y = data[:, 0], data[:, 1]

    h_x_given_y = float(conditional_entropy_knn(col_x, col_y))
    h_y_given_x = float(conditional_entropy_knn(col_y, col_x))

    gap = h_x_given_y - h_y_given_x
    is_correct = int(h_x_given_y < h_y_given_x)
    return {
        "correct": is_correct,
        "entropy_knn": gap,
        "XY": h_x_given_y,
        "YX": h_y_given_x,
    }


# linear ANM (DirectLiNGAM)
def linear_ANM(data):
    """
    Fit DirectLiNGAM on `data` and return its estimated causal order as a list.

    `data` is wrapped in a pandas DataFrame before fitting.
    """
    frame = pd.DataFrame(data)
    fitted = lingam.DirectLiNGAM().fit(frame)
    causal_order = fitted.causal_order_
    return list(causal_order)


# additive ANM Residual Independence (CAM-UV)
def additive_ANM_UV(data):
    """
    Fit CAM-UV on `data` and return a topological order of the directed edges
    it finds (causes before effects).

    Parameters:
        data: array of observations (rows) by variables (columns); wrapped in a
            pandas DataFrame before fitting.

    Returns:
        list of variable indices in topological order. Only variables that take
        part in at least one directed edge appear; if no directed edge is
        found the list is empty.

    Notes:
        lingam stores its adjacency matrices as B[i, j] != 0 meaning
        x_j -> x_i (row = effect, column = cause), so an entry equal to 1 at
        (i, j) is added as the graph edge j -> i.
        NOTE(review): confirm against the installed lingam version.
    """
    adj_mat = lingam.CAMUV().fit(pd.DataFrame(data)).adjacency_matrix_

    graph = nx.DiGraph()
    for effect, row in enumerate(adj_mat):
        for cause, entry in enumerate(row):
            # Only entries equal to 1 are directed edges; anything else
            # (0, or NaN which never compares equal to 1) is skipped.
            if entry == 1:
                graph.add_edge(cause, effect)

    return list(nx.topological_sort(graph))


# additive ANM MLE (CAM)


def additive_ANM(data):
    """
    Run dodiscover's CAM (pruning disabled) on `data` and return the column
    labels listed in the order given by the fitted model's `order_`.
    """
    frame = pd.DataFrame(data)
    model = dodiscover.toporder.CAM(prune=False)
    ctx = dodiscover.make_context().variables(data=frame).build()
    model.learn_graph(frame, ctx)

    # Translate the positions in order_ into the DataFrame's column labels.
    labels = frame.columns
    ordering = []
    for position in model.order_:
        ordering.append(labels[position])
    return ordering


# nonlinear ANM (RESIT)
def nonlinear_ANM(data):
    """
    Fit RESIT with a depth-4 random-forest regressor on `data` and return the
    fitted model's causal order as a list.
    """
    regressor = RandomForestRegressor(max_depth=4)
    model = lingam.RESIT(regressor)
    model.fit(data)
    causal_order = model.causal_order_
    return list(causal_order)


# Gaussian SCORE-matching approach (SCORE)
def score_ANM(data):
    """
    Run dodiscover's SCORE (pruning disabled) on `data` and return the column
    labels listed in the order given by the fitted model's `order_`.
    """
    model = dodiscover.toporder.SCORE(prune=False)
    frame = pd.DataFrame(data)
    ctx = dodiscover.make_context().variables(data=frame).build()
    model.learn_graph(frame, ctx)

    # Translate the positions in order_ into the DataFrame's column labels.
    labels = frame.columns
    return [labels[position] for position in model.order_]


# Non-Gaussian SCORE-matching Approach (NoGAM)
def nogam_ANM(data):
    """
    Run dodiscover's NoGAM (2 cross-validation folds, pruning disabled) on
    `data` and return the column labels listed in the order given by the
    fitted model's `order_`.
    """
    model = dodiscover.toporder.NoGAM(n_crossval=2, prune=False)
    frame = pd.DataFrame(data)
    ctx = dodiscover.make_context().variables(data=frame).build()
    model.learn_graph(frame, ctx)

    # Translate the positions in order_ into the DataFrame's column labels.
    labels = frame.columns
    return [labels[position] for position in model.order_]


# Variance Heuristic Baseline (VarSort)
def var_sort(data):
    """
    Order the columns of `data` by increasing variance.

    Returns a list of column indices; columns with equal variance keep their
    original relative order because the sort is stable.
    """
    column_variances = np.var(data, axis=0)
    ranked = sorted(enumerate(column_variances), key=lambda pair: pair[1])
    ordering = []
    for column_index, _variance in ranked:
        ordering.append(column_index)
    return ordering


# R^2 Heuristic Baseline (R2Sort)
def r2_sort(data):
    """
    Order the variables of `data` by increasing value of CausalDisco's
    `r2coeff`, which is called on the transposed data.

    Returns a list of indices; equal scores keep their original relative order.
    """
    scores = r2coeff(data.T)
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1])
    ordering = []
    for variable_index, _score in ranked:
        ordering.append(variable_index)
    return ordering


def pnl(data):
    """
    Run PNL.cause_or_effect on the first two columns of `data`.

    Column 0 and column 1 are each passed as an (n, 1) column vector.

    Returns:
        dict with
            "p_value": the first value returned by cause_or_effect,
            "p_value_backward": the second value returned by cause_or_effect,
            "correct": 1 if the first value is strictly smaller than the
                second, otherwise 0.

    NOTE(review): confirm that "smaller forward p-value" is the intended
    success criterion for the p-values PNL.cause_or_effect reports.
    """
    model = PNL()
    first_column = data[:, 0].reshape(-1, 1)
    second_column = data[:, 1].reshape(-1, 1)
    forward_p, backward_p = model.cause_or_effect(first_column, second_column)

    outcome = {"correct": int(forward_p < backward_p)}
    outcome["p_value"] = forward_p
    outcome["p_value_backward"] = backward_p
    return outcome


import sys

sys.path.append("causal-score-matching")
from causal_discovery.adascore import AdaScore


def causal_score_matching(data):
    """
    Fit AdaScore (all three alpha thresholds set to 0.05) on `data` and check
    the direction of the edge between variables 0 and 1.

    Returns:
        dict with "correct" = 1 when the fitted graph has the edge 0 -> 1 and
        does not have the edge 1 -> 0, otherwise 0.
    """
    frame = pd.DataFrame(data)
    model = AdaScore(
        alpha_orientation=0.05,
        alpha_confounded_leaf=0.05,
        alpha_separations=0.05,
    )
    fitted_graph = model.fit(frame)

    forward_edge = fitted_graph.has_edge(0, 1)
    backward_edge = fitted_graph.has_edge(1, 0)
    only_forward = forward_edge and not backward_edge
    return {"correct": int(only_forward)}


from dagma.linear import DagmaLinear


def get_dagmalinear_order(X):
    """
    Fit DAGMA (linear model, l2 loss) on X and report whether the estimated
    graph contains the edge 0 -> 1.

    Parameters:
    - X: Input data (n x d matrix), passed unchanged to DagmaLinear.fit.
    Returns:
    - dict with
        "correct": 1 if the estimated weight W_est[0, 1] is nonzero, else 0.
        "W_est": the weight matrix returned by DagmaLinear.fit.

    Notes:
    - DAGMA's W[i, j] is the weight of the edge i -> j, and that weight may be
      negative, so edge presence is tested with `!= 0` rather than `> 0`.
      NOTE(review): confirm the W[i, j] convention for the installed dagma.
    """
    model = DagmaLinear(loss_type="l2")
    W_est = model.fit(
        X, lambda1=0.05, T=4, mu_init=1, s=[1, 0.9, 0.8, 0.7], mu_factor=0.1
    )
    # A nonzero entry of either sign is an edge 0 -> 1.
    correct = W_est[0, 1] != 0
    return {
        "correct": int(correct),
        "W_est": W_est,
    }


# Example implementation:


def neural_network_transform(
    parent_data: np.ndarray, num_hidden: int = 10
) -> np.ndarray:
    """
    Push the parent data through a randomly initialised one-hidden-layer
    tanh network and return its scalar output per sample.

    All weights and biases are drawn fresh from Uniform(-5, 5) on every call
    using NumPy's global random state, in the order: input weights, hidden
    biases, output weights.

    Args:
        parent_data (np.ndarray): The data from parent nodes, shape (n_samples, num_parents).
        num_hidden (int): Number of hidden units in the neural network.

    Returns:
        np.ndarray: Transformed data with shape (n_samples,).
    """
    low, high = -5, 5  # weight range used for aistats
    num_parents = parent_data.shape[1]

    # Draw order matters for reproducibility under a fixed seed.
    input_weights = np.random.uniform(low, high, (num_parents, num_hidden))
    hidden_bias = np.random.uniform(low, high, num_hidden)
    output_weights = np.random.uniform(low, high, num_hidden)

    # (n_samples, num_hidden) pre-activations, squashed with tanh
    pre_activation = np.dot(parent_data, input_weights) + hidden_bias
    activations = np.tanh(pre_activation)

    # Weighted sum over hidden units -> (n_samples,)
    return np.dot(activations, output_weights)


# Next, define the data-generating models (DGMs)


# No unmeasured mediator
def dgm_pairwise(n=1000):
    """
    Sample n observations of a cause/effect pair with no mediator.

    x is standard normal and z = neural_network_transform(x) + standard normal
    noise. Both variables are standardized (zero mean, unit standard
    deviation) before being returned.

    Returns:
        np.ndarray of shape (n, 2) with columns (x, z).
    """

    def standardize(values):
        return (values - np.mean(values)) / np.std(values)

    cause = np.random.normal(0, 1, n)
    # The transform draws its random weights before the noise is drawn.
    signal = neural_network_transform(cause.reshape(-1, 1))
    noise = np.random.normal(0, 1, n)
    effect = signal + noise

    # return only input, output
    return np.array([standardize(cause), standardize(effect)]).T


# 3 unmeasured mediators
def dgm_unmeasured_mediator(n=1000, num_mediators=3):
    """
    Sample n observations of a cause/effect pair linked through a chain of
    unmeasured mediators: x -> y1 -> ... -> y_m -> z.

    Every link is neural_network_transform(parent) + standard normal noise.
    Only x and z are returned, each standardized to zero mean and unit
    standard deviation; the mediators are discarded.

    Parameters:
        n (int): Number of observations.
        num_mediators (int): Number of hidden variables between x and z.
            The default of 3 draws the same random stream as the original
            fixed y1, y2, y3 chain; 0 gives a direct x -> z link.

    Returns:
        np.ndarray of shape (n, 2) with columns (x, z).

    Raises:
        ValueError: if num_mediators is negative.
    """
    if num_mediators < 0:
        raise ValueError("num_mediators must be non-negative")

    x = np.random.normal(0, 1, n)

    # Walk the chain: num_mediators hidden links plus the final link into z.
    z = x
    for _ in range(num_mediators + 1):
        z = neural_network_transform(z.reshape(-1, 1)) + np.random.normal(0, 1, n)

    # standardize data (only the observed variables need it)
    x = (x - np.mean(x)) / np.std(x)
    z = (z - np.mean(z)) / np.std(z)
    # return only input, output
    return np.array([x, z]).T


# Experiment that runs each method on k freshly generated datasets from a dgm (default k=100) and tabulates the results (a trial is accurate if the output is [0,1], inaccurate otherwise, e.g. [1,0])
def run_exp(dgm, n=100, k=100):
    """
    Evaluate every ordering baseline on k datasets drawn from `dgm`.

    Parameters:
        dgm: callable taking the sample size and returning a data array.
        n (int): Sample size passed to `dgm` for each trial.
        k (int): Number of trials; a new dataset is generated per trial and
            shared by all methods within that trial.

    A trial counts as accurate (1) when the method returns exactly [0, 1] and
    as inaccurate (0) otherwise. The per-method mean and variance of these
    scores are written to "experiment_results.csv" and printed. Returns None.
    """
    methods = [
        linear_ANM,
        additive_ANM,
        additive_ANM_UV,
        nonlinear_ANM,
        score_ANM,
        nogam_ANM,
        var_sort,
        r2_sort,
    ]
    names = [method.__name__ for method in methods]

    # Per-method list of 0/1 outcomes, one entry per trial
    outcomes = {name: [] for name in names}

    for _ in range(k):
        # Fresh dataset for this trial, reused across methods
        trial_data = dgm(n)
        for name, method in zip(names, methods):
            predicted_order = method(trial_data)
            outcomes[name].append(1 if predicted_order == [0, 1] else 0)

    # Aggregate across trials
    summary = {
        "method": list(names),
        "accuracy": [np.mean(outcomes[name]) for name in names],
        "variance": [np.var(outcomes[name]) for name in names],
    }

    # Persist and display the results table
    results_df = pd.DataFrame(summary)
    results_df.to_csv("experiment_results.csv", index=False)
    print(f"\n{dgm.__name__} Results Table:")
    print(results_df)


# run_exp(dgm_pairwise)
# run_exp(dgm_unmeasured_mediator)


# As you can see from the table, when there is no unmeasured mediator, the nonlinear methods perform well, especially additive_ANM and nonlinear_ANM (above the heuristic baselines).
# When there are unmeasured mediators, the performance of all methods drops significantly. They all perform worse than the heuristic baseline R2Sort.


# Ignore this for now
# Non-Gaussian SCORE Approach for ANM (Adascore)
