"""
Evaluation script for comparing causal graphs against ground truth.

This script calculates F1 score, Structural Hamming Distance (SHD) and
Structural Intervention Distance (SID) between generated causal graphs
and ground truth graphs.
"""

import os
import json
import re

import numpy as np
import networkx as nx
from sklearn.metrics import precision_score, recall_score, f1_score

# Optional dependency: the Causal Discovery Toolbox provides an SHD metric.
# `cdt_available` records whether the import succeeded so that the evaluation
# code can choose between `cdt.metrics.SHD` and the manual `calculate_shd`.
try:
    import cdt
    from cdt.metrics import SHD
    cdt_available = True
except ImportError:
    print("Causal Discovery Toolbox not found. Using manual implementations.")
    cdt_available = False


def load_ground_truth(voting_dir, idx):
    """
    Read the ground-truth matrix stored for graph ``idx``.

    Args:
        voting_dir: Directory that holds the ``{idx}.json`` ground-truth files.
        idx: Index of the graph; used to build the file name.

    Returns:
        The parsed JSON content wrapped in a NumPy array.
    """
    gt_path = os.path.join(voting_dir, f"{idx}.json")
    with open(gt_path, 'r') as handle:
        return np.array(json.load(handle))


def _looks_like_matrix(value):
    """Return True when ``value`` is a non-empty list whose items are all lists."""
    return (
        isinstance(value, list)
        and bool(value)
        and all(isinstance(row, list) for row in value)
    )


def load_generated_graph(llm_results_dir, idx, original_file_map=None):
    """
    Load generated graph from LLM results directory.

    The file is located by trying, in order:
      1. ``{original_file_map[idx]}.json`` when the mapping has an entry for
         ``idx`` and that file exists,
      2. ``{idx}.json``,
      3. the alphabetically first ``.json`` file whose name starts with
         ``text_{idx}.``, ``doc_{idx}.`` or ``file_{idx}.``.

    The adjacency matrix is then read from the key ``"adjacency matrix"``,
    else ``"adjacency_matrix"``, else the first key whose value is a
    non-empty list of lists. A JSON document that is itself a list of lists
    is used directly.

    Args:
        llm_results_dir: Directory containing LLM inference results
        idx: Index of the graph
        original_file_map: Optional mapping from index to original filename

    Returns:
        The adjacency matrix as a NumPy array.

    Raises:
        FileNotFoundError: No candidate file exists for ``idx``.
        ValueError: The file was read but no adjacency matrix was found in it.
    """
    candidates = []
    if original_file_map and idx in original_file_map:
        # Use the mapped filename if available
        candidates.append(
            os.path.join(llm_results_dir, f"{original_file_map[idx]}.json")
        )
    # Default naming pattern; also the fallback when the mapped file is absent
    candidates.append(os.path.join(llm_results_dir, f"{idx}.json"))

    filepath = next((path for path in candidates if os.path.exists(path)), None)

    if filepath is None:
        print(f"File {candidates[-1]} not found, searching for matching file...")
        prefixes = (f"text_{idx}.", f"doc_{idx}.", f"file_{idx}.")
        # Sorted so the chosen file does not depend on directory listing order
        matching_files = sorted(
            name for name in os.listdir(llm_results_dir)
            if name.endswith('.json') and name.startswith(prefixes)
        )
        if not matching_files:
            raise FileNotFoundError(f"Could not find a matching file for index {idx}")
        filepath = os.path.join(llm_results_dir, matching_files[0])
        print(f"Found matching file: {filepath}")

    with open(filepath, 'r') as f:
        data = json.load(f)

    if isinstance(data, dict):
        # Our LLM output format uses "adjacency matrix" (with a space);
        # the original format uses "adjacency_matrix".
        for known_key in ("adjacency matrix", "adjacency_matrix"):
            if known_key in data:
                return np.array(data[known_key])

        # Fall back to any key that holds something shaped like a matrix.
        # Empty lists are skipped so an unrelated empty field is not mistaken
        # for the adjacency matrix.
        for key, value in data.items():
            if _looks_like_matrix(value):
                print(f"Using key '{key}' as adjacency matrix")
                return np.array(value)
    elif _looks_like_matrix(data):
        # The file contains the bare matrix with no wrapping object
        return np.array(data)

    raise ValueError(f"Could not find adjacency matrix in file {filepath}")


def matrix_to_networkx(matrix):
    """
    Build a NetworkX DiGraph from an adjacency matrix.

    Nodes are the integers ``0 .. len(matrix) - 1``. An edge ``i -> j`` is
    added for every cell ``matrix[i][j]`` that equals 1; only the first
    ``len(matrix)`` columns of each row are inspected.
    """
    size = len(matrix)
    graph = nx.DiGraph()
    graph.add_nodes_from(range(size))
    graph.add_edges_from(
        (src, dst)
        for src in range(size)
        for dst in range(size)
        if matrix[src][dst] == 1
    )
    return graph


def calculate_f1(true_matrix, pred_matrix):
    """
    Score a predicted adjacency matrix against the true one.

    Both matrices are flattened to 1-D label vectors and passed to the
    scikit-learn scorers with ``zero_division=0``.

    Returns:
        A dict with the keys ``"precision"``, ``"recall"`` and ``"f1"``.
    """
    y_true = np.array(true_matrix).flatten()
    y_pred = np.array(pred_matrix).flatten()

    scorers = {
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }
    return {
        metric: scorer(y_true, y_pred, zero_division=0)
        for metric, scorer in scorers.items()
    }


def _is_nx_graph(obj):
    """Return True when ``obj`` is a NetworkX graph rather than an array-like."""
    # Array-likes are ruled out first so plain matrices never touch networkx.
    return not isinstance(obj, (np.ndarray, list, tuple)) and isinstance(obj, nx.Graph)


def _as_adjacency(graph, nodelist=None):
    """Return ``graph`` (NetworkX graph or array-like) as a float NumPy matrix."""
    if _is_nx_graph(graph):
        return nx.to_numpy_array(graph, nodelist=nodelist)
    return np.asarray(graph, dtype=float)


def calculate_shd(true_graph, pred_graph, double_for_anticausal=False):
    """
    Manual implementation of Structural Hamming Distance.

    Args:
        true_graph: Ground-truth graph, as a NetworkX graph or an adjacency
            matrix (NumPy array or nested lists).
        pred_graph: Predicted graph, in either of the same forms.
        double_for_anticausal: When True, every differing matrix cell counts
            once, so a reversed edge contributes 2. When False, differences
            are symmetrised so each differing node pair counts once.

    Returns:
        The distance as a NumPy scalar.

    Raises:
        ValueError: The two adjacency matrices do not have the same shape.
    """
    # When the truth is a graph, its node order is reused for the prediction
    # so that row/column k refers to the same node in both matrices.
    nodelist = list(true_graph.nodes()) if _is_nx_graph(true_graph) else None
    true_matrix = _as_adjacency(true_graph, nodelist)
    pred_matrix = _as_adjacency(pred_graph, nodelist)

    # Guard against NumPy silently broadcasting mismatched inputs
    if true_matrix.shape != pred_matrix.shape:
        raise ValueError(
            f"Adjacency matrices have different shapes: "
            f"{true_matrix.shape} vs {pred_matrix.shape}"
        )

    diff = np.abs(true_matrix - pred_matrix)
    if double_for_anticausal:
        return np.sum(diff)

    diff = diff + diff.transpose()
    diff[diff > 1] = 1  # Ignoring the double edges
    return np.sum(diff) / 2


def get_ancestors(G, node):
    """
    Return the ancestors of ``node`` in the directed graph ``G``.

    The result is a new ``set`` built from ``nx.ancestors(G, node)``.
    """
    upstream = nx.ancestors(G, node)
    return set(upstream)


def get_parents(G, node):
    """
    Return the direct predecessors of ``node`` in ``G`` as a list.
    """
    return [*G.predecessors(node)]


def get_descendants(G, node):
    """
    Return the descendants of ``node`` in the directed graph ``G``.

    The result is a new ``set`` built from ``nx.descendants(G, node)``.
    """
    downstream = nx.descendants(G, node)
    return set(downstream)


def is_valid_adjustment(G_true, i, j, adjustment_set):
    """
    Determine if adjustment_set is valid for estimating the causal effect
    from node i to node j in the true graph G_true.

    This implements the backdoor criterion: an adjustment set is valid if it
    contains no descendant of i and blocks every backdoor path from i to j
    (every path that starts with an arrow into i).

    Only pairs joined by a direct edge i->j in ``G_true`` are examined; all
    other pairs are reported as valid.
    NOTE(review): the published SID definition also accounts for indirect
    effects - confirm whether this direct-edge scope is intended.

    Args:
        G_true: The true graph as a NetworkX DiGraph (expected to be a DAG).
        i: Cause node.
        j: Effect node.
        adjustment_set: Iterable of nodes used as the adjustment set.

    Returns:
        True when the adjustment set is valid (or the pair is not examined),
        False otherwise.
    """
    # If there's no edge i->j in the true graph, any adjustment works
    if not G_true.has_edge(i, j):
        return True

    adjustment = set(adjustment_set)

    # Condition 1: the adjustment set must not contain descendants of i
    if adjustment & get_descendants(G_true, i):
        return False

    # Condition 2: block all backdoor paths. Removing EVERY edge leaving i
    # leaves only paths that enter i through an incoming arrow. Removing just
    # i->j would keep indirect causal paths such as i->k->j open and wrongly
    # reject valid adjustment sets.
    G_backdoor = G_true.copy()
    G_backdoor.remove_edges_from(list(G_backdoor.out_edges(i)))

    # `d_separated` was renamed to `is_d_separator` in newer NetworkX
    # releases; prefer the new name and fall back to the old one.
    d_separation_test = getattr(nx, "is_d_separator", None) or nx.d_separated
    return d_separation_test(G_backdoor, {i}, {j}, adjustment)


def _break_cycles(graph):
    """Return an acyclic copy of ``graph``, deleting one edge per cycle found."""
    dag = graph.copy()
    while not nx.is_directed_acyclic_graph(dag):
        # find_cycle reports the cycle as a list of edges that exist in the
        # graph; removing its last edge opens that cycle.
        u, v = nx.find_cycle(dag)[-1][:2]
        dag.remove_edge(u, v)
    return dag


def calculate_sid(G_true, G_pred):
    """
    Calculate the Structural Intervention Distance (SID) between two DAGs.

    SID counts the number of pairs (i,j) where the parent adjustment in G_pred
    is not valid for estimating the causal effect from i to j in G_true.

    Args:
        G_true: True graph, as a NetworkX DiGraph or an adjacency matrix.
        G_pred: Predicted graph, as a NetworkX DiGraph or an adjacency matrix.

    Returns:
        The number of ordered node pairs whose adjustment set is invalid.

    Graphs containing cycles are made acyclic on a copy, so the caller's
    graph objects are never modified.
    """
    if not isinstance(G_true, nx.DiGraph):
        G_true = matrix_to_networkx(G_true)

    if not isinstance(G_pred, nx.DiGraph):
        G_pred = matrix_to_networkx(G_pred)

    # Check if graphs are DAGs
    true_is_dag = nx.is_directed_acyclic_graph(G_true)
    pred_is_dag = nx.is_directed_acyclic_graph(G_pred)
    if not (true_is_dag and pred_is_dag):
        print("Warning: One of the graphs is not a DAG. Making it a DAG by removing cycles.")
        if not true_is_dag:
            G_true = _break_cycles(G_true)
        if not pred_is_dag:
            G_pred = _break_cycles(G_pred)

    nodes = list(G_true.nodes())
    sid_count = 0

    for i in nodes:
        # Parents of i in the predicted graph form the adjustment set; they
        # do not depend on j, so look them up once per i.
        parents_i_pred = get_parents(G_pred, i)
        for j in nodes:
            if i != j and not is_valid_adjustment(G_true, i, j, parents_i_pred):
                sid_count += 1

    return sid_count


def _index_names_by_number(directory, extension):
    """Map the first integer in each ``extension`` file name of ``directory`` to the name's stem."""
    mapping = {}
    # Sorted so that, when two names carry the same number, the winner does
    # not depend on the order the OS lists the directory in.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(extension):
            continue
        base_name = os.path.splitext(filename)[0]
        # Try to extract an index pattern if present
        match = re.search(r'(\d+)', base_name)
        if match:
            mapping[int(match.group(1))] = base_name
    return mapping


def build_index_to_filename_map(datasets_dir, llm_results_dir):
    """
    Build a mapping from index to filename by scanning the datasets directory
    and the results directory.

    This helps match ground truth indices with the files processed by the LLM.

    Every ``.txt`` file found one level below ``datasets_dir`` contributes an
    entry keyed by the first run of digits in its name. If that yields
    nothing (including when ``datasets_dir`` is not a directory), the
    ``.json`` files directly inside ``llm_results_dir`` are used instead.

    Args:
        datasets_dir: Directory whose sub-directories hold the original
            ``.txt`` files.
        llm_results_dir: Directory holding the LLM ``.json`` results.

    Returns:
        A dict mapping integer index to file name without extension. When
        several names share an index, the alphabetically last one wins.
        The dict is empty when no numbered file is found.
    """
    index_map = {}

    # Scan datasets directory for original files
    if os.path.isdir(datasets_dir):
        for subdir in sorted(os.listdir(datasets_dir)):
            subdir_path = os.path.join(datasets_dir, subdir)
            if os.path.isdir(subdir_path):
                index_map.update(_index_names_by_number(subdir_path, '.txt'))

    # If no mapping found, try to infer from the results directory
    if not index_map and os.path.isdir(llm_results_dir):
        index_map = _index_names_by_number(llm_results_dir, '.json')

    return index_map


def evaluate_graphs(ground_truth_dir="ground_truth/voting_results",
                   llm_results_dir="../causal_discovery_methods/LLMs/results",
                   datasets_dir="../../datasets",
                   num_graphs=99):
    """
    Evaluate the LLM-generated graphs against ground truth.

    For every index in ``range(num_graphs)`` that has a ground-truth file and
    a matching prediction, precision/recall/F1, SHD and SID are computed.
    Per-graph values and their averages are printed, and a summary is written
    to ``evaluation_results.json`` in the parent directory of
    ``llm_results_dir``.

    Args:
        ground_truth_dir: Directory with the ``{idx}.json`` ground-truth files.
        llm_results_dir: Directory with the LLM-generated graph files.
        datasets_dir: Directory scanned to map indices to original file names.
        num_graphs: Number of indices to try, starting at 0.

    Returns:
        None. Results are reported via stdout and the JSON summary file.
    """

    results = {
        "f1_scores": [],
        "precision_scores": [],
        "recall_scores": [],
        "shd_scores": [],
        "sid_scores": []
    }

    # Try to build a mapping from indices to filenames
    try:
        index_map = build_index_to_filename_map(datasets_dir, llm_results_dir)
    except Exception as e:
        print(f"Error building index mapping: {e}")
        index_map = None
    else:
        if index_map:
            print(f"Built index to filename mapping with {len(index_map)} entries")
        else:
            print("Could not build index mapping, will try to match files directly")
            index_map = None

    valid_graphs = 0
    for idx in range(num_graphs):
        try:
            gt_path = os.path.join(ground_truth_dir, f"{idx}.json")

            if not os.path.exists(gt_path):
                print(f"Skipping graph {idx}: ground truth file not found")
                continue

            # Load ground truth
            true_matrix = load_ground_truth(ground_truth_dir, idx)

            # Try to load LLM-generated prediction
            try:
                pred_matrix = load_generated_graph(llm_results_dir, idx, index_map)
            except FileNotFoundError as e:
                print(f"Skipping graph {idx}: {e}")
                continue

            # Ensure matrices have the same shape
            n_true = len(true_matrix)
            n_pred = len(pred_matrix)

            if n_true != n_pred:
                print(f"Graph {idx} has different dimensions (GT: {n_true}, Pred: {n_pred})")
                # Use the smaller dimension
                n = min(n_true, n_pred)
                true_matrix = np.array(true_matrix)[:n, :n]
                pred_matrix = np.array(pred_matrix)[:n, :n]

            # Calculate F1 score
            f1_metrics = calculate_f1(true_matrix, pred_matrix)

            # Graph forms are shared by the SHD (cdt) and SID computations
            true_graph = matrix_to_networkx(true_matrix)
            pred_graph = matrix_to_networkx(pred_matrix)

            # Calculate SHD
            if cdt_available:
                shd_score = SHD(true_graph, pred_graph, double_for_anticausal=False)
            else:
                shd_score = calculate_shd(true_matrix, pred_matrix)

            # Calculate SID; a failure here is recorded as None, not fatal
            try:
                sid_score = calculate_sid(true_graph, pred_graph)
            except Exception as e:
                print(f"Error calculating SID for graph {idx}: {e}")
                sid_score = None

            # Record all metrics together, only after every one of them has
            # been computed, so the per-metric lists always stay aligned.
            results["f1_scores"].append(f1_metrics["f1"])
            results["precision_scores"].append(f1_metrics["precision"])
            results["recall_scores"].append(f1_metrics["recall"])
            results["shd_scores"].append(shd_score)
            results["sid_scores"].append(sid_score)

            print(f"Graph {idx} - F1: {f1_metrics['f1']:.4f}, SHD: {shd_score}, SID: {sid_score}")
            valid_graphs += 1

        except Exception as e:
            print(f"Error processing graph {idx}: {e}")

    if valid_graphs == 0:
        print("No valid graph comparisons were made. Please check file paths.")
        return

    avg_f1 = np.mean(results["f1_scores"])
    avg_precision = np.mean(results["precision_scores"])
    avg_recall = np.mean(results["recall_scores"])
    avg_shd = np.mean(results["shd_scores"])

    # Calculate average SID only for valid values
    sid_values = [s for s in results["sid_scores"] if s is not None]
    avg_sid = np.mean(sid_values) if sid_values else None

    print("\nAverage Metrics:")
    print(f"F1 Score: {avg_f1:.4f}")
    print(f"Precision: {avg_precision:.4f}")
    print(f"Recall: {avg_recall:.4f}")
    print(f"SHD: {avg_shd:.4f}")
    if avg_sid is not None:
        print(f"SID: {avg_sid:.4f}")

    print(f"Successfully processed {valid_graphs} graphs")

    # Save results to file
    try:
        results_summary = {
            "metrics": {
                "f1": float(avg_f1),
                "precision": float(avg_precision),
                "recall": float(avg_recall),
                "shd": float(avg_shd),
                "sid": float(avg_sid) if avg_sid is not None else None
            },
            "processed_graphs": valid_graphs,
            "detailed_results": {
                "f1_scores": [float(x) for x in results["f1_scores"]],
                "precision_scores": [float(x) for x in results["precision_scores"]],
                "recall_scores": [float(x) for x in results["recall_scores"]],
                "shd_scores": [float(x) for x in results["shd_scores"]],
                "sid_scores": [float(x) if x is not None else None for x in results["sid_scores"]]
            }
        }

        results_file = os.path.join(os.path.dirname(llm_results_dir), "evaluation_results.json")
        with open(results_file, 'w') as f:
            json.dump(results_summary, f, indent=2)

        print(f"Results saved to {results_file}")
    except Exception as e:
        # Saving is best-effort: the metrics have already been printed above
        print(f"Error saving results to file: {e}")


def main():
    """Resolve the default directories, announce them, and run the evaluation."""
    root = '.'

    # Number of graphs to evaluate
    num_graphs = 1000

    ground_truth_dir = os.path.join(root, "ground_truth", "voting_results")
    llm_results_dir = os.path.join(root, "..", "causal_discovery_methods", "LLMs", "results")
    datasets_dir = os.path.join(root, "..", "..", "datasets")

    banner = (
        f"Evaluating {num_graphs} graphs...",
        f"Ground truth directory: {ground_truth_dir}",
        f"LLM results directory: {llm_results_dir}",
        f"Datasets directory: {datasets_dir}",
    )
    for line in banner:
        print(line)

    # Run evaluation
    evaluate_graphs(ground_truth_dir, llm_results_dir, datasets_dir, num_graphs)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()