import os
from cdt.metrics import SID, precision_recall
import pandas as pd
import numpy as np
import networkx as nx
from typing import Union
from castle.metrics import MetricsDAG
from causallearn.graph.GeneralGraph import GeneralGraph
### UTILITY FUNCTIONS ###


def eval_metrics(
    save_path: str,
    ground_truth_graph: Union[np.ndarray, nx.DiGraph, GeneralGraph],
    learned_causal_graph: Union[np.ndarray, nx.DiGraph, GeneralGraph],
    dataset_name: str,
    discovery_time: float,
    input_token_count: int,
    output_token_count: int,
    tool_calls: dict,
    model_name: str = "",
    tools: Union[list[str], None] = None,
) -> str:
    """
    Evaluate causal discovery metrics and save them as a one-row CSV file.

    When both graphs are ``GeneralGraph`` objects, the arrow confusion and SHD
    classes of causal-learn are used. Otherwise the learned graph is first
    passed through ``clean_causal_graph`` and the metrics come from
    ``cdt.metrics`` (SHD, AUPR, SID) and ``castle.metrics.MetricsDAG``.
    In that second case every metric is computed independently: a failure is
    printed and the metric is reported as ``None``.

    Args:
        save_path (str): Directory in which ``cd_metrics.csv`` is written.
            It is created if it does not exist yet.
        ground_truth_graph: The ground truth causal graph.
        learned_causal_graph: The learned causal graph from the discovery process.
            For ``GeneralGraph`` inputs, nodes missing from either graph are
            added to it in place.
        dataset_name (str): Name of the dataset used.
        discovery_time (float): Time taken for the causal discovery process.
        input_token_count (int): Number of input tokens used.
        output_token_count (int): Number of output tokens used.
        tool_calls (dict): Tool call record, stored as-is in the "Tool calls" column.
        model_name (str): Value stored in the "Model" column.
        tools (list[str] | None): Names of the tools in use. The "RAG" column is 1
            if it contains "rag_assistant" and the "HITL" column is 1 if it
            contains "human_in_the_loop". ``None`` is treated as an empty list.

    Returns:
        file name (str): Path to the saved evaluation results.
    """
    tools = [] if tools is None else tools

    # Every metric starts as None so that a failed computation can never
    # leave a name unbound when the result dictionary is assembled.
    precision = recall = f_score = fdr = tpr = fpr = None
    aupr = shd = sid_value = None

    # Columns shared by both evaluation paths (leading and trailing groups).
    run_info = {
        "Model": model_name,
        "RAG": 1 if "rag_assistant" in tools else 0,
        "HITL": 1 if "human_in_the_loop" in tools else 0,
        "Dataset": dataset_name,
    }
    usage_info = {
        "Input Token count": input_token_count,
        "Output Token count": output_token_count,
        "Tool calls": tool_calls,
    }

    if isinstance(ground_truth_graph, GeneralGraph) and isinstance(
        learned_causal_graph, GeneralGraph
    ):
        from causallearn.graph.ArrowConfusion import ArrowConfusion
        from causallearn.graph.SHD import SHD

        # Add missing nodes from the ground truth graph to the learned graph
        missing_nodes = set(ground_truth_graph.nodes) - set(learned_causal_graph.nodes)
        for node in missing_nodes:
            learned_causal_graph.add_node(node)
        print(
            f"Added {len(missing_nodes)} isolated missing nodes to the learned causal graph."
        )

        # Add missing nodes from the learned graph to the ground truth graph
        extra_nodes = set(learned_causal_graph.nodes) - set(ground_truth_graph.nodes)
        for node in extra_nodes:
            ground_truth_graph.add_node(node)
        print(
            f"Added {len(extra_nodes)} isolated missing nodes to the ground truth graph."
        )

        # For arrows
        arrow = ArrowConfusion(ground_truth_graph, learned_causal_graph)

        tp = arrow.get_arrows_tp()
        fp = arrow.get_arrows_fp()
        fn = arrow.get_arrows_fn()
        tn = arrow.get_arrows_tn()

        precision = arrow.get_arrows_precision()
        recall = arrow.get_arrows_recall()
        # Ratios fall back to 0 when their denominator is 0 (e.g. two graphs
        # without any arrow) instead of raising ZeroDivisionError.
        f1_denominator = 2 * tp + fp + fn
        f_score = 2 * tp / f1_denominator if f1_denominator > 0 else 0
        fdr = fp / (fp + tp) if (fp + tp) > 0 else 0
        tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0

        # Structural Hamming Distance
        shd = SHD(ground_truth_graph, learned_causal_graph).get_shd()

        # Set-up dictionary with metrics (this path reports neither AUPR nor SID)
        cd_metrics = {
            **run_info,
            "precision": precision,
            "recall": recall,
            "F1": f_score,
            "FDR": fdr,
            "TPR": tpr,
            "FPR": fpr,
            "SHD": shd,
            "Discovery Time": discovery_time,
            **usage_info,
        }

    else:
        # Clean learned causal graph and ensure all nodes in the ground truth graph are present
        learned_causal_graph = clean_causal_graph(ground_truth_graph, learned_causal_graph)

        # Each metric gets its own try-except block so that one failing
        # backend does not discard the metrics that could be computed.
        try:
            # Structural Hamming Distance (SHD)
            from cdt.metrics import SHD as cdt_SHD

            shd = cdt_SHD(ground_truth_graph, learned_causal_graph)
        except Exception as e:
            print(f"SHD computation failed: {e}")
            shd = None

        try:
            # Area under the precision-recall curve; the curve itself is not reported
            aupr, _ = precision_recall(ground_truth_graph, learned_causal_graph)
        except Exception as e:
            print(f"Precision-recall computation failed: {e}")
            aupr = None

        try:
            # MetricsDAG is called with the estimated adjacency matrix first
            # and the true adjacency matrix second.
            scores = MetricsDAG(
                nx.adjacency_matrix(learned_causal_graph).todense(),
                nx.adjacency_matrix(ground_truth_graph).todense(),
            ).metrics
            precision = scores["precision"]
            recall = scores["recall"]
            f_score = scores["F1"]
            fdr = scores["fdr"]
            tpr = scores["tpr"]
            fpr = scores["fpr"]
        except Exception as e:
            print(f"Confusion-based metrics computation failed: {e}")
            # Do not report a partially filled set of scores.
            precision = recall = f_score = fdr = tpr = fpr = None

        try:
            # Structural Intervention Distance (SID)
            sid_value = SID(ground_truth_graph, learned_causal_graph)
        except Exception as e:
            print(f"SID computation failed: {e}")
            sid_value = None

        # Set-up dictionary with metrics
        cd_metrics = {
            **run_info,
            "precision": precision,
            "recall": recall,
            "F1": f_score,
            "FDR": fdr,
            "TPR": tpr,
            "FPR": fpr,
            "AUPR": aupr,
            "SHD": shd,
            "Discovery Time": discovery_time,
            "SID": sid_value,
            **usage_info,
        }

    # Print metrics
    print(cd_metrics)

    # Convert to dataframe and save
    if save_path:
        # An empty save_path means "current directory"; nothing to create then.
        os.makedirs(save_path, exist_ok=True)
    output_file = os.path.join(save_path, "cd_metrics.csv")
    pd.DataFrame([cd_metrics]).to_csv(output_file, index=False)
    return output_file

def clean_causal_graph(ground_truth_graph, learned_causal_graph):
    """
    Align the learned graph with the node set of the ground truth graph.

    Two input forms are handled:

    * Two ``np.ndarray`` adjacency matrices: if the learned matrix has fewer
      rows than the ground truth, it is zero-padded at the bottom/right by the
      difference in row count and returned. A learned matrix that is already
      as large (or larger) is returned unchanged.
    * Graph objects exposing ``nodes()`` and ``edges()`` (e.g. ``nx.DiGraph``):
      a new ``nx.DiGraph`` is built whose nodes are exactly the ground truth
      nodes, in ground truth order, and whose edges are the learned edges that
      connect two ground truth nodes. Node and edge attributes are not copied.

    The input graphs are not modified.

    Args:
        ground_truth_graph: Reference graph (adjacency matrix or graph object).
        learned_causal_graph: Learned graph, of the same form as the reference.

    Returns:
        The padded adjacency matrix, or the newly built ``nx.DiGraph``.
    """
    # Adjacency matrices have no nodes()/edges() API, so they must be handled
    # (and returned) before any graph-object code runs.
    if isinstance(ground_truth_graph, np.ndarray) and isinstance(
        learned_causal_graph, np.ndarray
    ):
        diff = ground_truth_graph.shape[0] - learned_causal_graph.shape[0]
        if diff > 0:
            learned_causal_graph = np.pad(
                learned_causal_graph,
                ((0, diff), (0, diff)),
                mode="constant",
                constant_values=0,
            )
        return learned_causal_graph

    truth_nodes = list(ground_truth_graph.nodes())
    truth_node_set = set(truth_nodes)

    if isinstance(ground_truth_graph, nx.DiGraph) and isinstance(
        learned_causal_graph, nx.DiGraph
    ):
        # These nodes end up isolated in the returned graph.
        missing_nodes = truth_node_set - set(learned_causal_graph.nodes)
        print(
            f"Added {len(missing_nodes)} isolated missing nodes to the learned causal graph."
        )

    # Adapt the order of nodes to the ground truth and keep only edges between
    # ground truth nodes; DiGraph itself collapses duplicate edges.
    aligned_graph = nx.DiGraph()
    aligned_graph.add_nodes_from(truth_nodes)
    aligned_graph.add_edges_from(
        (source, target)
        for source, target in learned_causal_graph.edges()
        if source in truth_node_set and target in truth_node_set
    )
    return aligned_graph

def intermediate_eval_metrics(ground_truth_graph: Union[np.ndarray, nx.DiGraph, GeneralGraph], learned_causal_graph: Union[np.ndarray, nx.DiGraph, GeneralGraph]) -> dict:
    """
    Evaluate causal discovery metrics and return a dictionary containing
    precision, recall, F1, TPR, FPR, AUPR, SHD, and SID.

    For GeneralGraph objects, nodes missing from either graph are added to it
    in place, matched by node name, and the arrow confusion / SHD classes of
    causal-learn are used.
    For other graph types (nx.DiGraph, np.ndarray), ``clean_causal_graph`` is
    called first; SHD comes from ``cdt.metrics`` and the confusion-based scores
    from ``castle.metrics.MetricsDAG``.
    In both cases SID and AUPR are attempted via ``cdt.metrics``.

    Every optional metric is computed in its own try-except block: a failure
    is printed and that metric alone is reported as ``None``.

    Args:
        ground_truth_graph: The ground truth causal graph.
        learned_causal_graph: The learned causal graph.

    Returns:
        A dictionary with the keys "precision", "recall", "F1", "TPR", "FPR",
        "AUPR", "SHD" and "SID".
    """
    # Initialize default metric values
    precision = recall = f1 = tpr = fpr = aupr = shd = sid_value = None

    if isinstance(ground_truth_graph, GeneralGraph) and isinstance(learned_causal_graph, GeneralGraph):
        from causallearn.graph.ArrowConfusion import ArrowConfusion
        from causallearn.graph.SHD import SHD as causallearn_shd

        # Synchronize nodes using node names (computed before any node is added)
        learned_names = {node.get_name() for node in learned_causal_graph.get_nodes()}
        truth_names = {node.get_name() for node in ground_truth_graph.get_nodes()}

        # Add nodes missing in learned graph
        missing_node_names = truth_names - learned_names
        for node in ground_truth_graph.get_nodes():
            if node.get_name() in missing_node_names:
                learned_causal_graph.add_node(node)
        # Add nodes missing in ground truth graph
        extra_node_names = learned_names - truth_names
        for node in learned_causal_graph.get_nodes():
            if node.get_name() in extra_node_names:
                ground_truth_graph.add_node(node)

        # Compute arrow-based metrics
        arrow = ArrowConfusion(ground_truth_graph, learned_causal_graph)
        tp = arrow.get_arrows_tp()
        fp = arrow.get_arrows_fp()
        fn = arrow.get_arrows_fn()
        tn = arrow.get_arrows_tn()

        precision = arrow.get_arrows_precision()
        recall = arrow.get_arrows_recall()
        # Ratios with a zero denominator are reported as None
        f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else None
        tpr = tp / (tp + fn) if (tp + fn) > 0 else None
        fpr = fp / (fp + tn) if (fp + tn) > 0 else None

        # Compute SHD
        shd = causallearn_shd(ground_truth_graph, learned_causal_graph).get_shd()

    else:
        # For other graph types (nx.DiGraph or np.ndarray), perform minimal cleaning
        # so that all nodes in the ground truth graph exist in the learned graph.
        learned_causal_graph = clean_causal_graph(ground_truth_graph, learned_causal_graph)

        try:
            from cdt.metrics import SHD as cdt_shd

            shd = cdt_shd(ground_truth_graph, learned_causal_graph)
        except Exception as e:
            print(f"SHD computation failed: {e}")
            shd = None

        try:
            from castle.metrics import MetricsDAG

            # MetricsDAG is called with the estimated adjacency matrix first
            # and the true adjacency matrix second.
            m_dict = MetricsDAG(
                nx.adjacency_matrix(learned_causal_graph).todense(),
                nx.adjacency_matrix(ground_truth_graph).todense()
            )
            precision = m_dict.metrics["precision"]
            recall = m_dict.metrics["recall"]
            f1 = m_dict.metrics["F1"]
            tpr = m_dict.metrics["tpr"]
            fpr = m_dict.metrics["fpr"]
        except Exception as e:
            print(f"Precision-recall metrics failed: {e}")

    # SID and AUPR are attempted for every graph type. Each has its own import
    # and try-except so that a failing SID does not prevent AUPR (or vice versa).
    try:
        from cdt.metrics import SID as cdt_sid

        sid_value = cdt_sid(ground_truth_graph, learned_causal_graph)
    except Exception as e:
        print(f"SID computation failed: {e}")
        sid_value = None

    try:
        from cdt.metrics import precision_recall as cdt_precision_recall

        aupr, _ = cdt_precision_recall(ground_truth_graph, learned_causal_graph)
    except Exception as e:
        print(f"AUPR computation failed: {e}")
        aupr = None

    # Prepare and return metrics dictionary
    cd_metrics = {
        "precision": precision,
        "recall": recall,
        "F1": f1,
        "TPR": tpr,
        "FPR": fpr,
        "AUPR": aupr,
        "SHD": shd,
        "SID": sid_value
    }
    print(cd_metrics)
    return cd_metrics