# automata/dfa_diff.py

from __future__ import annotations
from typing import List, Tuple, Optional, Dict
from collections import deque

from automata.dfa import DFA, Symbol, _lenlex


def symmetric_difference_dfa(A: DFA, A_star: DFA) -> DFA:
    """
    Return a DFA B recognizing the symmetric difference L(A) Δ L(A').

    - Both automata are lifted to the union of their alphabets before the
      product is taken (the lifted alphabet order is whatever
      DFA._lift_to_superset produces; the original note describes it as
      length-lex).
    - Only the part of the product automaton reachable from the start pair is
      built, with product states numbered in BFS discovery order (start = 0).
    - A product state is accepting iff exactly one of its two components is
      accepting.
    - dead=None is passed so the DFA constructor re-detects the dead state
      itself (per the original note, the constructor is expected to
      canonicalize/minimize the result -- confirm against DFA.__init__).
    """
    # Reuse the shared product construction instead of duplicating the BFS;
    # it yields the same state numbering this function produced before.
    sigma, delta, _pairs, finals_mask, _finals_a, _finals_b = _build_reachable_product(
        A, A_star
    )

    # enumerate() walks ids in increasing order, so the tuple is already sorted.
    finals = tuple(i for i, accepting in enumerate(finals_mask) if accepting)

    return DFA(
        sigma=sigma,
        start=0,
        finals=finals,
        delta=tuple(tuple(row) for row in delta),
        dead=None,  # canonical re-detection during DFA ctor
    )


# ---------- New, robust symmetric-difference enumeration ----------

def _build_reachable_product(A: DFA, A_star: DFA):
    """
    Construct the reachable part of the product automaton A x A_star.

    Both inputs are first lifted to the union of their alphabets, then the
    product is explored breadth-first from the pair of start states. Product
    states are numbered in discovery order, so id 0 is the start pair.

    Returns a 6-tuple:
      sigma       -- alphabet of the lifted DFAs (taken from the lifted A)
      delta       -- delta[i][a] is the product id reached from i on symbol a
      pairs       -- pairs[i] is the (state of A, state of A_star) behind id i
      finals_mask -- finals_mask[i] is True iff exactly one component accepts
      finals_a    -- set of accepting states of the lifted A
      finals_b    -- set of accepting states of the lifted A_star
    """
    combined = tuple(set(A.sigma) | set(A_star.sigma))
    left = A._lift_to_superset(combined)
    right = A_star._lift_to_superset(combined)
    sigma = left.sigma
    width = len(sigma)

    origin = (left.start, right.start)
    index_of: Dict[Tuple[int, int], int] = {origin: 0}
    pairs: List[Tuple[int, int]] = [origin]
    delta: List[List[int]] = []

    # Newly discovered pairs are appended to `pairs` in discovery order, so
    # walking the list with a cursor visits states in exactly FIFO/BFS order.
    cursor = 0
    while cursor < len(pairs):
        state_a, state_b = pairs[cursor]
        transitions: List[int] = []
        for sym in range(width):
            target = (left.delta[state_a][sym], right.delta[state_b][sym])
            target_id = index_of.get(target)
            if target_id is None:
                target_id = len(pairs)
                index_of[target] = target_id
                pairs.append(target)
            transitions.append(target_id)
        delta.append(transitions)
        cursor += 1

    finals_a = set(left.finals)
    finals_b = set(right.finals)
    # XOR of the two membership tests: accepting in exactly one automaton.
    finals_mask = [(sa in finals_a) != (sb in finals_b) for sa, sb in pairs]

    return sigma, delta, pairs, finals_mask, finals_a, finals_b


def _min_to_accept(delta: List[List[int]], finals_mask: List[bool]) -> List[int]:
    """h[u] = shortest distance from u to ANY accepting product state (∞ if none)."""
    n = len(delta)
    m = len(delta[0]) if n else 0
    rev: List[List[int]] = [[] for _ in range(n)]
    for u in range(n):
        for a in range(m):
            v = delta[u][a]
            rev[v].append(u)
    INF = 10**12
    h = [INF] * n
    dq = deque([i for i, acc in enumerate(finals_mask) if acc])
    for i in dq:
        h[i] = 0
    while dq:
        v = dq.popleft()
        for u in rev[v]:
            if h[u] > h[v] + 1:
                h[u] = h[v] + 1
                dq.append(u)
    return h


def _reachable_from_start(delta: List[List[int]], start: int = 0) -> List[bool]:
    n = len(delta)
    seen = [False] * n
    dq = deque([start])
    seen[start] = True
    while dq:
        u = dq.popleft()
        for v in delta[u]:
            if not seen[v]:
                seen[v] = True
                dq.append(v)
    return seen


def _has_live_cycle(delta: List[List[int]], live: List[bool], start: int = 0) -> bool:
    """
    Detect a cycle in the subgraph of nodes that are both reachable from start and can
    reach an accepting state (live = True).
    """
    n = len(delta)
    m = len(delta[0]) if n else 0

    reach = _reachable_from_start(delta, start)
    in_sub = [live[i] and reach[i] for i in range(n)]

    color = [0] * n  # 0=unvisited, 1=visiting, 2=done

    def dfs(u: int) -> bool:
        color[u] = 1
        for a in range(m):
            v = delta[u][a]
            if not in_sub[v]:
                continue
            if color[v] == 1:
                return True
            if color[v] == 0 and dfs(v):
                return True
        color[u] = 2
        return False

    for u in range(n):
        if in_sub[u] and color[u] == 0:
            if dfs(u):
                return True
    return False


def _enumerate_exact_length(
    sigma: Tuple[Symbol, ...],
    delta: List[List[int]],
    finals_mask: List[bool],
    h: List[int],
    L: int,
    want: int,
) -> List[List[Symbol]]:
    """
    Enumerate up to 'want' accepted strings of EXACT length L,
    in lex order, using the h-based pruning h[next] <= L - depth - 1.
    """
    m = len(sigma)
    out: List[List[Symbol]] = []
    buf: List[Symbol] = []

    def dfs(u: int, depth: int) -> None:
        if len(out) >= want:
            return
        if depth == L:
            if finals_mask[u]:
                out.append(list(buf))
            return
        # explore in sigma order for lex order
        rem = L - (depth + 1)
        for a in range(m):
            v = delta[u][a]
            if h[v] <= rem:  # only if completion is possible within L
                buf.append(sigma[a])
                dfs(v, depth + 1)
                buf.pop()
                if len(out) >= want:
                    return

    dfs(0, 0)
    return out


def symmetric_difference_examples(
    A: DFA, A_star: DFA, N: int
) -> Optional[Tuple[List[List[Symbol]], List[Tuple[int, int]], Optional[int]]]:
    """
    - If L(A) == L(A'), return None.
    - Else:
        * Enumerate up to N shortest strings in the symmetric difference (length-lex).
        * For each string, provide labels (A(w), A'(w)) as (0/1, 0/1).
        * If the symmetric-difference language is finite, also return its total size.
    Returns:
        None
        or (strings, labels, total_if_finite)

    total_if_finite is None when the symmetric difference is infinite. When it
    is finite and holds fewer than N words, all of them are returned and the
    enumeration stops (it does not keep searching longer lengths forever).
    """
    eq, _ = A.language_equivalent_allow_superset(A_star)
    if eq:
        return None

    # Build reachable product over the union alphabet (no minimization).
    sigma, delta, pairs, finals_mask, finals_a, finals_b = _build_reachable_product(A, A_star)

    # Precompute shortest distance-to-accept for pruning & min length.
    h = _min_to_accept(delta, finals_mask)
    INF = 10**12  # must match the "unreachable" sentinel used by _min_to_accept
    if h[0] >= INF:
        # Shouldn't happen because eq==False implies some witness exists,
        # but guard anyway.
        return [], [], 0

    # The language is finite iff there is no cycle among states that are both
    # reachable from the start and able to reach an accepting state.
    live = [dist < INF for dist in h]
    infinite = _has_live_cycle(delta, live, start=0)

    # In the finite case every accepted word follows a path through live states
    # without repeating one, so no word is longer than len(delta) - 1. Without
    # this bound the loop below would never end when fewer than N words exist.
    max_len: Optional[int] = None if infinite else len(delta) - 1

    sym_index = {sym: a for a, sym in enumerate(sigma)}
    strings: List[List[Symbol]] = []
    labels: List[Tuple[int, int]] = []

    # Length-by-length enumeration (short-lex).
    L = h[0]  # minimal possible length
    while len(strings) < N and (max_len is None or L <= max_len):
        need = N - len(strings)
        chunk = _enumerate_exact_length(sigma, delta, finals_mask, h, L, need)
        for w in chunk:
            # Walk the word to find its end product state (|w| steps), then
            # read each automaton's verdict from the component states.
            u = 0
            for sym in w:
                u = delta[u][sym_index[sym]]
            qa, qb = pairs[u]
            labels.append((1 if qa in finals_a else 0, 1 if qb in finals_b else 0))
        strings.extend(chunk)
        L += 1

    if infinite:
        total_if_finite = None
    else:
        # DAG: count accepted words = number of paths from the start to an
        # accepting state through live nodes. Done as an iterative post-order
        # pass so deep automata cannot exhaust the recursion limit.
        memo: Dict[int, int] = {}
        stack = [0]
        while stack:
            u = stack[-1]
            if u in memo:
                stack.pop()
                continue
            pending = [v for v in delta[u] if live[v] and v not in memo]
            if pending:
                stack.extend(pending)
                continue
            # One term per transition, so parallel edges are counted separately.
            memo[u] = (1 if finals_mask[u] else 0) + sum(
                memo[v] for v in delta[u] if live[v]
            )
            stack.pop()
        total_if_finite = memo[0]

    return strings, labels, total_if_finite


