"""
Evaluation script for comparing causal graphs against ground truth.

This script calculates F1 score, Structural Hamming Distance (SHD) and
Structural Intervention Distance (SID) between generated causal graphs
and ground truth graphs.
"""

import os
import json
import numpy as np
import networkx as nx
from sklearn.metrics import precision_score, recall_score, f1_score

# The Causal Discovery Toolbox is optional: when it imports cleanly its SHD
# implementation is used; otherwise the manual calculate_shd is used instead.
try:
    import cdt
    from cdt.metrics import SHD
except ImportError:
    cdt_available = False
    print("Causal Discovery Toolbox not found. Using manual implementations.")
else:
    cdt_available = True


def load_ground_truth(voting_dir, idx):
    """
    Read the ground-truth adjacency matrix stored as ``<voting_dir>/<idx>.json``.

    The JSON document is expected to be the matrix itself (a nested list);
    it is returned converted to a numpy array.
    """
    gt_file = os.path.join(voting_dir, f"{idx}.json")
    with open(gt_file, 'r') as handle:
        parsed = json.loads(handle.read())
    return np.array(parsed)


def load_generated_graph(dag_data_dir, idx):
    """
    Read ``<dag_data_dir>/graph_<idx>.json`` and return the value stored under
    its ``"adjacency_matrix"`` key as a numpy array.

    Raises KeyError if the document has no ``"adjacency_matrix"`` entry.
    """
    graph_file = os.path.join(dag_data_dir, f"graph_{idx}.json")
    with open(graph_file, 'r') as fh:
        payload = json.loads(fh.read())
    adjacency = payload["adjacency_matrix"]
    return np.array(adjacency)


def matrix_to_networkx(matrix):
    """
    Build a ``nx.DiGraph`` from a square adjacency matrix.

    Nodes are the integers ``0 .. len(matrix) - 1``; an edge ``row -> col`` is
    added for every entry that compares equal to 1 (other values are ignored).
    """
    size = len(matrix)
    graph = nx.DiGraph()
    graph.add_nodes_from(range(size))
    graph.add_edges_from(
        (row, col)
        for row in range(size)
        for col in range(size)
        if matrix[row][col] == 1
    )
    return graph


def calculate_f1(true_matrix, pred_matrix):
    """
    Score a predicted adjacency matrix against the ground truth entry by entry.

    Both matrices are flattened and every entry is treated as one binary
    label; precision, recall and F1 are computed by scikit-learn with
    ``zero_division=0``.

    Returns:
        dict with keys "precision", "recall" and "f1".
    """
    y_true = np.array(true_matrix).ravel()
    y_pred = np.array(pred_matrix).ravel()

    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    return {name: scorer(y_true, y_pred, zero_division=0) for name, scorer in scorers}


def _to_adjacency_array(graph, nodelist=None):
    """Convert a DiGraph or matrix-like input into a signed numeric numpy adjacency matrix."""
    if isinstance(graph, nx.DiGraph):
        return nx.to_numpy_array(graph, nodelist=nodelist)
    matrix = np.asarray(graph)
    # bool and unsigned dtypes cannot represent the negative differences taken
    # in calculate_shd (bool subtraction raises, unsigned subtraction wraps).
    if matrix.dtype.kind in "bu":
        matrix = matrix.astype(np.int64)
    return matrix


def calculate_shd(true_graph, pred_graph, double_for_anticausal=False):
    """
    Manual implementation of Structural Hamming Distance.

    The SHD is the summed element-wise absolute difference of the two
    adjacency matrices.  When ``double_for_anticausal`` is False, a reversed
    edge counts once: the difference is symmetrised, clipped to 1 and halved.
    When True, a reversed edge counts twice.

    Args:
        true_graph: Ground-truth graph, as an ``nx.DiGraph`` or an adjacency
            matrix (nested list or numpy array).
        pred_graph: Predicted graph, in either of the same forms.  Each
            argument is converted independently, so the two forms may be mixed.
        double_for_anticausal: Count a reversed edge as two errors instead of one.

    Returns:
        The SHD as a numpy scalar.

    Raises:
        ValueError: If two ``DiGraph`` inputs have different node sets, or if
            the resulting adjacency matrices have different shapes.
    """
    nodelist = None
    if isinstance(true_graph, nx.DiGraph) and isinstance(pred_graph, nx.DiGraph):
        if set(true_graph.nodes()) != set(pred_graph.nodes()):
            raise ValueError("Graphs must have the same node set to compute SHD")
        # Use a single node order for both graphs so that row/column k refers
        # to the same node in each matrix.
        nodelist = list(true_graph.nodes())

    true_matrix = _to_adjacency_array(true_graph, nodelist)
    pred_matrix = _to_adjacency_array(pred_graph, nodelist)
    if true_matrix.shape != pred_matrix.shape:
        # Without this check numpy could silently broadcast mismatched shapes.
        raise ValueError(
            f"Adjacency matrices have different shapes: "
            f"{true_matrix.shape} vs {pred_matrix.shape}"
        )

    diff = np.abs(true_matrix - pred_matrix)
    if double_for_anticausal:
        return np.sum(diff)
    # A reversed edge differs at both (i, j) and (j, i); symmetrising and
    # clipping to 1 makes it contribute exactly 1 after halving.
    diff = diff + diff.transpose()
    diff[diff > 1] = 1
    return np.sum(diff) / 2


def get_ancestors(G, node):
    """Return a new set of every node that has a directed path into ``node`` in ``G``."""
    found = nx.ancestors(G, node)
    return {n for n in found}


def get_parents(G, node):
    """Return the direct predecessors of ``node`` in ``G`` as a list."""
    return [parent for parent in G.predecessors(node)]


def get_descendants(G, node):
    """Return a new set of every node reachable from ``node`` by a directed path in ``G``."""
    found = nx.descendants(G, node)
    return {n for n in found}


def is_valid_adjustment(G_true, i, j, adjustment_set):
    """
    Decide whether ``adjustment_set`` correctly identifies p(x_j | do(x_i)) in G_true.

    This is the per-pair check of the Structural Intervention Distance
    (Peters & Buehlmann, 2015), where ``adjustment_set`` is the parent set of
    ``i`` in the predicted graph:

    * If ``j`` is in the adjustment set, the predicted graph implies that
      intervening on ``i`` has no effect on ``j``.  That is correct exactly
      when ``j`` is not a descendant of ``i`` in ``G_true``.
    * Otherwise the set must satisfy the generalized back-door criterion:
      (a) it contains no node that is, or descends from, a node other than
      ``i`` lying on a directed path from ``i`` to ``j``; and
      (b) it d-separates ``i`` and ``j`` in ``G_true`` after removing the
      first edge of every directed path from ``i`` to ``j``.

    Args:
        G_true: Ground-truth DAG (``nx.DiGraph``).
        i: Intervened node.
        j: Target node (``j != i``).
        adjustment_set: Iterable of nodes of ``G_true`` to adjust for; must
            not contain ``i``.

    Returns:
        True if adjusting for ``adjustment_set`` yields the correct
        interventional distribution of ``j`` under do(``i``).
    """
    Z = set(adjustment_set)
    descendants_i = nx.descendants(G_true, i)

    if j in Z:
        # The predicted graph claims "no causal effect of i on j".
        return j not in descendants_i

    # Nodes other than i that lie on some directed path i -> ... -> j.
    if j in descendants_i:
        causal_nodes = descendants_i & (nx.ancestors(G_true, j) | {j})
    else:
        causal_nodes = set()

    # (a) Forbidden nodes: causal-path nodes and all of their descendants.
    forbidden = set()
    for w in causal_nodes:
        forbidden.add(w)
        forbidden |= nx.descendants(G_true, w)
    if Z & forbidden:
        return False

    # (b) Proper back-door graph: drop the first edge of every directed path
    # from i to j, i.e. the edges i -> c whose head c is a causal-path node.
    G_backdoor = G_true.copy()
    G_backdoor.remove_edges_from(
        [(i, c) for c in G_true.successors(i) if c in causal_nodes]
    )

    # Newer NetworkX releases replace `d_separated` with `is_d_separator`;
    # use whichever this installation provides.
    d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated
    return d_separated(G_backdoor, {i}, {j}, Z)


def _make_acyclic(G):
    """Remove edges from ``G`` in place, one per detected cycle, until ``G`` has no directed cycle."""
    while True:
        try:
            cycle = nx.find_cycle(G)
        except nx.NetworkXNoCycle:
            return
        # Every element of the returned cycle is an edge actually present in
        # G, so removing one is always valid and breaks this cycle.
        u, v = cycle[-1][:2]
        G.remove_edge(u, v)


def calculate_sid(G_true, G_pred):
    """
    Calculate the Structural Intervention Distance (SID) between two DAGs.

    SID counts the ordered pairs (i, j), i != j, for which the parent set of
    ``i`` in ``G_pred``, used as an adjustment set, is not valid for the
    causal effect of ``i`` on ``j`` in ``G_true``.  Validity is decided by
    ``is_valid_adjustment``.

    Args:
        G_true: Ground-truth graph, as an ``nx.DiGraph`` or an adjacency matrix.
        G_pred: Predicted graph, in either of the same forms.

    Returns:
        The SID as an int.

    Raises:
        ValueError: If the two graphs do not have the same node set.

    If either graph contains a directed cycle, a warning is printed and edges
    are removed from a private copy until it is acyclic.  The caller's graph
    objects are never modified.
    """
    # Copy DiGraph inputs so that the cycle-breaking below never mutates them.
    G_true = G_true.copy() if isinstance(G_true, nx.DiGraph) else matrix_to_networkx(G_true)
    G_pred = G_pred.copy() if isinstance(G_pred, nx.DiGraph) else matrix_to_networkx(G_pred)

    if set(G_true.nodes()) != set(G_pred.nodes()):
        raise ValueError("G_true and G_pred must have the same node set to compute SID")

    if not (nx.is_directed_acyclic_graph(G_true) and nx.is_directed_acyclic_graph(G_pred)):
        print("Warning: One of the graphs is not a DAG. Making it a DAG by removing cycles.")
        _make_acyclic(G_true)
        _make_acyclic(G_pred)

    sid_count = 0
    for i in G_true.nodes():
        # The adjustment set depends only on i, so compute it once per node.
        parents_i_pred = get_parents(G_pred, i)
        for j in G_true.nodes():
            if i != j and not is_valid_adjustment(G_true, i, j, parents_i_pred):
                sid_count += 1

    return sid_count


def _score_graph_pair(idx, true_matrix, pred_matrix):
    """Compute F1/precision/recall, SHD and SID for one ground-truth/predicted matrix pair."""
    f1_metrics = calculate_f1(true_matrix, pred_matrix)

    if cdt_available:
        shd_score = SHD(matrix_to_networkx(true_matrix), matrix_to_networkx(pred_matrix),
                        double_for_anticausal=False)
    else:
        shd_score = calculate_shd(true_matrix, pred_matrix)

    # A SID failure is reported and recorded as None; the other metrics for
    # this graph are still kept.
    try:
        sid_score = calculate_sid(matrix_to_networkx(true_matrix),
                                  matrix_to_networkx(pred_matrix))
    except Exception as e:
        print(f"Error calculating SID for graph {idx}: {e}")
        sid_score = None

    return {
        "f1": f1_metrics["f1"],
        "precision": f1_metrics["precision"],
        "recall": f1_metrics["recall"],
        "shd": shd_score,
        "sid": sid_score,
    }


def evaluate_graphs(ground_truth_dir="ground_truth/voting_results",
                    dag_data_dir="../data_generation/iTAG/dag_data",
                    num_graphs=99):
    """
    Evaluate the generated graphs against ground truth.

    For each ``idx`` in ``range(num_graphs)``, loads ``<ground_truth_dir>/<idx>.json``
    and ``<dag_data_dir>/graph_<idx>.json``.  Pairs with a missing file are
    skipped.  When the two matrices have a different number of rows, both are
    cropped to the smaller size.  Per-graph metrics are printed, followed by
    the averages.

    Args:
        ground_truth_dir: Directory holding the ground-truth matrices.
        dag_data_dir: Directory holding the generated graph files.
        num_graphs: Number of indices to try, starting at 0.

    Returns:
        A dict of averages with keys "f1", "precision", "recall", "shd", "sid"
        and "num_graphs".  "sid" is None if no SID could be computed.  The
        function returns None when no graph pair could be evaluated.
    """
    results = {
        "f1_scores": [],
        "precision_scores": [],
        "recall_scores": [],
        "shd_scores": [],
        "sid_scores": []
    }

    valid_graphs = 0
    for idx in range(num_graphs):
        try:
            gt_path = os.path.join(ground_truth_dir, f"{idx}.json")
            gen_path = os.path.join(dag_data_dir, f"graph_{idx}.json")

            if not os.path.exists(gt_path) or not os.path.exists(gen_path):
                print(f"Skipping graph {idx}: files not found")
                continue

            true_matrix = load_ground_truth(ground_truth_dir, idx)
            pred_matrix = load_generated_graph(dag_data_dir, idx)

            n_true, n_pred = len(true_matrix), len(pred_matrix)
            if n_true != n_pred:
                print(f"Graph {idx} has different dimensions (GT: {n_true}, Pred: {n_pred})")
                # Crop both matrices to the smaller size so they can be compared.
                n = min(n_true, n_pred)
                true_matrix = np.array(true_matrix)[:n, :n]
                pred_matrix = np.array(pred_matrix)[:n, :n]

            scores = _score_graph_pair(idx, true_matrix, pred_matrix)
        except Exception as e:
            print(f"Error processing graph {idx}: {e}")
            continue

        # Append only after every metric has been computed, so all per-metric
        # lists always describe the same set of graphs.
        results["f1_scores"].append(scores["f1"])
        results["precision_scores"].append(scores["precision"])
        results["recall_scores"].append(scores["recall"])
        results["shd_scores"].append(scores["shd"])
        results["sid_scores"].append(scores["sid"])

        print(f"Graph {idx} - F1: {scores['f1']:.4f}, SHD: {scores['shd']}, SID: {scores['sid']}")
        valid_graphs += 1

    if valid_graphs == 0:
        print("No valid graph comparisons were made. Please check file paths.")
        return None

    avg_f1 = np.mean(results["f1_scores"])
    avg_precision = np.mean(results["precision_scores"])
    avg_recall = np.mean(results["recall_scores"])
    avg_shd = np.mean(results["shd_scores"])

    # Average SID only over graphs where it could be computed.
    sid_values = [s for s in results["sid_scores"] if s is not None]
    avg_sid = np.mean(sid_values) if sid_values else None

    print("\nAverage Metrics:")
    print(f"F1 Score: {avg_f1:.4f}")
    print(f"Precision: {avg_precision:.4f}")
    print(f"Recall: {avg_recall:.4f}")
    print(f"SHD: {avg_shd:.4f}")
    if avg_sid is not None:
        print(f"SID: {avg_sid:.4f}")

    print(f"Successfully processed {valid_graphs} graphs")

    return {
        "f1": avg_f1,
        "precision": avg_precision,
        "recall": avg_recall,
        "shd": avg_shd,
        "sid": avg_sid,
        "num_graphs": valid_graphs,
    }


def main():
    """Run the evaluation over graph indices 0..999 using the default directory layout."""
    root = '.'
    gt_dir = os.path.join(root, "ground_truth", "voting_results")
    generated_dir = os.path.join(root, "..", "data_generation", "iTAG", "dag_data")
    total = 1000  # number of graph indices to try

    banner = (
        f"Evaluating {total} graphs...",
        f"Ground truth directory: {gt_dir}",
        f"Generated graphs directory: {generated_dir}",
    )
    for line in banner:
        print(line)

    evaluate_graphs(gt_dir, generated_dir, total)


if __name__ == "__main__":
    main()