"""
Structural Hamming Distance (SHD) and related metrics.

SHD measures the number of edge operations needed to transform
one graph into another.
"""

from __future__ import annotations

from typing import Tuple, Dict, Set, FrozenSet
import numpy as np

from ..core.dag import DAG
from ..core.mec import CPDAG


def structural_hamming_distance(
    cpdag1: CPDAG,
    cpdag2: CPDAG
) -> int:
    """
    Compute Structural Hamming Distance between two CPDAGs.

    Each unordered node pair {i, j} whose edge differs between the two
    graphs contributes exactly 1 to the distance:
    - Missing edge: +1
    - Extra edge: +1
    - Wrong direction (i→j vs j→i): +1
    - Directed vs undirected: +1

    Args:
        cpdag1: First CPDAG
        cpdag2: Second CPDAG

    Returns:
        SHD value (0 means identical)

    Raises:
        ValueError: If the CPDAGs do not have the same number of nodes.
    """
    d = cpdag1.num_nodes()
    if d != cpdag2.num_nodes():
        raise ValueError("CPDAGs must have same number of nodes")

    def pair_state(cpdag, i, j):
        # One code per unordered pair, so a reversed edge is a single
        # difference instead of one mismatch per direction:
        # 0 = no edge, 1 = directed i->j, 2 = undirected, 3 = directed j->i
        if cpdag.has_directed_edge(i, j):
            return 1
        if cpdag.has_directed_edge(j, i):
            return 3
        # Both argument orders are queried so the result does not depend
        # on has_undirected_edge being symmetric.
        if cpdag.has_undirected_edge(i, j) or cpdag.has_undirected_edge(j, i):
            return 2
        return 0

    shd = 0
    for i in range(d):
        # j > i: every unordered pair is visited exactly once
        for j in range(i + 1, d):
            if pair_state(cpdag1, i, j) != pair_state(cpdag2, i, j):
                shd += 1

    return shd


def _get_edge_status(cpdag: CPDAG, i: int, j: int) -> int:
    """
    Classify the edge between ``i`` and ``j`` as seen from ``i``.

    The directed lookup takes precedence; the undirected lookup is only
    consulted when no directed edge i -> j is reported.

    Returns:
        1: Directed edge i -> j
        2: Undirected edge i - j
        0: Neither of the above (a reverse edge j -> i is not checked here)
    """
    if cpdag.has_directed_edge(i, j):
        return 1
    return 2 if cpdag.has_undirected_edge(i, j) else 0


def skeleton_hamming_distance(
    cpdag1: CPDAG,
    cpdag2: CPDAG
) -> int:
    """
    Count the adjacencies on which two CPDAGs disagree.

    Orientation plays no role: only the skeletons returned by
    ``skeleton()`` are compared.

    Args:
        cpdag1: First CPDAG
        cpdag2: Second CPDAG

    Returns:
        Number of skeleton edges present in exactly one of the two graphs
    """
    first = cpdag1.skeleton()
    second = cpdag2.skeleton()

    # Symmetric difference = edges missing from one side or the other
    return len(first ^ second)


def dag_structural_hamming_distance(dag1: DAG, dag2: DAG) -> int:
    """
    Compute SHD between two DAGs by comparing their CPDAGs.

    Each DAG is converted with ``CPDAG.from_dag`` and the resulting
    CPDAGs are passed to ``structural_hamming_distance``.

    Args:
        dag1: First DAG
        dag2: Second DAG

    Returns:
        SHD value of the two converted graphs
    """
    return structural_hamming_distance(
        CPDAG.from_dag(dag1),
        CPDAG.from_dag(dag2),
    )


def edge_accuracy(
    cpdag_true: CPDAG,
    cpdag_learned: CPDAG
) -> float:
    """
    Score a learned CPDAG as one minus its normalised SHD.

    The SHD to the ground truth is divided by the number of unordered
    node pairs, d * (d - 1) / 2, where d is the node count of
    ``cpdag_true``.

    Args:
        cpdag_true: Ground truth CPDAG
        cpdag_learned: Learned CPDAG

    Returns:
        1.0 - SHD / pairs, or 1.0 when there are no node pairs
    """
    node_count = cpdag_true.num_nodes()

    # Evaluated before the empty-graph shortcut so that any error raised
    # by the SHD computation still surfaces.
    distance = structural_hamming_distance(cpdag_true, cpdag_learned)

    pair_count = node_count * (node_count - 1) // 2
    return 1.0 if pair_count == 0 else 1.0 - distance / pair_count


def edge_precision_recall(
    cpdag_true: CPDAG,
    cpdag_learned: CPDAG
) -> Tuple[float, float, float]:
    """
    Compute skeleton-level precision, recall, and F1.

    An edge counts as recovered when it appears in both skeletons;
    orientation is ignored. Any score whose denominator would be zero
    is reported as 0.0.

    Args:
        cpdag_true: Ground truth CPDAG
        cpdag_learned: Learned CPDAG

    Returns:
        Tuple of (precision, recall, f1)
    """
    truth = cpdag_true.skeleton()
    found = cpdag_learned.skeleton()

    hits = len(truth & found)
    spurious = len(found - truth)
    missed = len(truth - found)

    predicted_total = hits + spurious
    actual_total = hits + missed

    precision = 0.0
    if predicted_total > 0:
        precision = hits / predicted_total

    recall = 0.0
    if actual_total > 0:
        recall = hits / actual_total

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return precision, recall, f1


def orientation_accuracy(
    cpdag_true: CPDAG,
    cpdag_learned: CPDAG
) -> float:
    """
    Compute accuracy of edge orientations.

    Only considers edges present in both skeletons. A common edge is
    scored as correct when both graphs orient it the same way: i -> j in
    both, j -> i in both, or directed in neither (which, for an edge
    taken from the skeleton, is presumably the undirected case).

    Args:
        cpdag_true: Ground truth CPDAG
        cpdag_learned: Learned CPDAG

    Returns:
        Fraction of common edges with matching orientation, or 1.0 when
        the skeletons share no edge
    """
    common_edges = cpdag_true.skeleton() & cpdag_learned.skeleton()

    if not common_edges:
        return 1.0

    correct_orientations = 0

    for edge in common_edges:
        i, j = edge

        # Compare the directed marks in BOTH directions. Which endpoint a
        # set-like edge yields first is arbitrary, and a one-sided status
        # cannot tell "j -> i" apart from other non-matches.
        true_marks = (
            bool(cpdag_true.has_directed_edge(i, j)),
            bool(cpdag_true.has_directed_edge(j, i)),
        )
        learned_marks = (
            bool(cpdag_learned.has_directed_edge(i, j)),
            bool(cpdag_learned.has_directed_edge(j, i)),
        )

        if true_marks == learned_marks:
            correct_orientations += 1

    return correct_orientations / len(common_edges)


def v_structure_precision_recall(
    dag_true: DAG,
    cpdag_learned: CPDAG
) -> Tuple[float, float, float]:
    """
    Compute precision, recall, and F1 for v-structure detection.

    A learned v-structure is a node k with two directed parents i and j
    that are not adjacent in the learned CPDAG. It is recorded as the
    triple (min(i, j), k, max(i, j)) and matched by equality against the
    elements of ``dag_true.v_structures()``.
    NOTE(review): this assumes v_structures() yields triples in the same
    (smaller parent, collider, larger parent) layout - confirm.

    Args:
        dag_true: True DAG
        cpdag_learned: Learned CPDAG

    Returns:
        Tuple of (precision, recall, f1); a score whose denominator would
        be zero is reported as 0.0
    """
    # True v-structures from DAG
    v_struct_true = set(dag_true.v_structures())

    # Group directed parents by child in one pass over the edges instead
    # of rescanning the whole edge collection for every node.
    parents_of = {}
    for parent, child in cpdag_learned.directed_edges:
        parents_of.setdefault(child, []).append(parent)

    v_struct_learned = set()

    for k in range(cpdag_learned.num_nodes()):
        directed_parents = parents_of.get(k, [])

        # Every unordered pair of parents is a candidate v-structure
        for idx_i, i in enumerate(directed_parents):
            for j in directed_parents[idx_i + 1:]:
                # Parents must be non-adjacent (unshielded collider)
                if not cpdag_learned.is_adjacent(i, j):
                    v_struct_learned.add((min(i, j), k, max(i, j)))

    true_positives = len(v_struct_true & v_struct_learned)
    false_positives = len(v_struct_learned - v_struct_true)
    false_negatives = len(v_struct_true - v_struct_learned)

    precision = true_positives / (true_positives + false_positives) \
        if (true_positives + false_positives) > 0 else 0.0

    recall = true_positives / (true_positives + false_negatives) \
        if (true_positives + false_negatives) > 0 else 0.0

    f1 = 2 * precision * recall / (precision + recall) \
        if (precision + recall) > 0 else 0.0

    return precision, recall, f1


def comprehensive_metrics(
    dag_true: DAG,
    cpdag_learned: CPDAG
) -> Dict:
    """
    Evaluate a learned CPDAG against a ground-truth DAG with every metric.

    The true DAG is converted with ``CPDAG.from_dag`` for the CPDAG-level
    metrics; the v-structure scores use the DAG itself.

    Args:
        dag_true: Ground truth DAG
        cpdag_learned: Learned CPDAG

    Returns:
        Dict with keys 'shd', 'skeleton_shd', 'edge_accuracy',
        'edge_precision', 'edge_recall', 'edge_f1',
        'orientation_accuracy', 'v_structure_precision',
        'v_structure_recall' and 'v_structure_f1'
    """
    reference = CPDAG.from_dag(dag_true)
    metrics = {}

    # Whole-graph distances and the accuracy derived from SHD
    metrics['shd'] = structural_hamming_distance(reference, cpdag_learned)
    metrics['skeleton_shd'] = skeleton_hamming_distance(reference, cpdag_learned)
    metrics['edge_accuracy'] = edge_accuracy(reference, cpdag_learned)

    # Skeleton-level precision / recall / F1
    (
        metrics['edge_precision'],
        metrics['edge_recall'],
        metrics['edge_f1'],
    ) = edge_precision_recall(reference, cpdag_learned)

    metrics['orientation_accuracy'] = orientation_accuracy(reference, cpdag_learned)

    # V-structure scores are measured against the DAG, not its CPDAG
    (
        metrics['v_structure_precision'],
        metrics['v_structure_recall'],
        metrics['v_structure_f1'],
    ) = v_structure_precision_recall(dag_true, cpdag_learned)

    return metrics
