
# automata/dfa.py
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Iterable, Set
from collections import deque, defaultdict
import numpy as np


# Alphabet symbols are plain strings (ordered length-lex by `_lenlex`).
Symbol = str
# States are integer indices into `DFA.delta` rows.
State = int


def _lenlex(syms: Iterable[Symbol]) -> Tuple[Symbol, ...]:
    """Length-lex order: first by length, then lexical."""
    return tuple(sorted(syms, key=lambda s: (len(s), s)))


@dataclass(frozen=True)
class DFA:
    """
    Canonical minimized DFA.

    Invariants (enforced on construction):
      - sigma is length-lex sorted and contains no duplicate symbols.
      - DFA is trimmed, minimized (Hopcroft), and BFS-renumbered (sigma order),
        so the start state is always 0.
      - dead is None or the unique sink (non-accepting, self-looping on all symbols).

    The ``dead`` value passed to the constructor is not trusted: the sink is
    re-detected after minimization. Because the stored form is canonical, two
    instances built for the same language over the same sigma have identical
    fields and therefore compare equal.
    """

    sigma: Tuple[Symbol, ...]
    start: State
    finals: Tuple[State, ...]
    delta: Tuple[Tuple[State, ...], ...]
    dead: Optional[State] = None

    # ---------- lifecycle ----------
    def __post_init__(self) -> None:
        """
        Validate the raw fields, then overwrite them with the canonical minimal form.

        Raises:
            ValueError: no states, duplicate symbols in sigma, a delta row whose
                length differs from len(sigma), or an out-of-range start, final,
                or transition target.
        """
        raw_delta = tuple(tuple(row) for row in self.delta)
        n = len(raw_delta)
        if n == 0:
            raise ValueError("DFA must have ≥1 state")

        old_sigma = tuple(self.sigma)
        m = len(old_sigma)
        if len(set(old_sigma)) != m:
            # A repeated symbol would make two delta columns share one name.
            raise ValueError("sigma contains duplicate symbols")
        for row in raw_delta:
            if len(row) != m:
                raise ValueError("delta row length must equal len(sigma)")
            for r in row:
                # Reject bad targets here instead of failing later with an
                # IndexError/KeyError deep inside minimization.
                if not (0 <= r < n):
                    raise ValueError(f"transition target {r} out of range")

        if not (0 <= self.start < n):
            raise ValueError("start out of range")
        # Materialize once: `finals` is consumed twice below.
        raw_finals = tuple(self.finals)
        for q in raw_finals:
            if not (0 <= q < n):
                raise ValueError(f"final state {q} out of range")

        # Sort sigma to length-lex and permute the delta columns to match.
        sigma_sorted = _lenlex(old_sigma)
        old_col = {s: i for i, s in enumerate(old_sigma)}
        perm = [old_col[s] for s in sigma_sorted]
        delta_cols_sorted = tuple(tuple(row[j] for j in perm) for row in raw_delta)

        # Build a temporary DFA (normalized columns, bypassing __post_init__) and minimize it.
        tmp = object.__new__(DFA)
        object.__setattr__(tmp, "sigma", sigma_sorted)
        object.__setattr__(tmp, "start", self.start)
        object.__setattr__(tmp, "finals", tuple(sorted(set(raw_finals))))
        object.__setattr__(tmp, "delta", delta_cols_sorted)
        object.__setattr__(tmp, "dead", self.dead)
        minimized = tmp._minimize_canonical()

        # Commit canonical minimal fields (`dead` comes from re-detection, not the caller).
        for name in ("sigma", "start", "finals", "delta", "dead"):
            object.__setattr__(self, name, getattr(minimized, name))
        object.__setattr__(self, "_sym2idx", {s: i for i, s in enumerate(self.sigma)})

    def minimize(self) -> "DFA":
        """
        Return the trimmed, minimized, canonically renumbered DFA.
        (Calling this on an existing DFA is idempotent.)
        """
        return self._minimize_canonical()

    # ---------- core ops ----------
    @property
    def n_states(self) -> int:
        return len(self.delta)

    @property
    def alphabet_size(self) -> int:
        return len(self.sigma)

    def step(self, q: State, sym: Symbol) -> State:
        """Transition from `q` on `sym` (KeyError if `sym` is not in sigma)."""
        return self.delta[q][self._sym2idx[sym]]  # type: ignore[attr-defined]

    def run(self, seq: Iterable[Symbol]) -> State:
        """State reached from the start state after reading `seq`."""
        q = self.start
        for s in seq:
            q = self.step(q, s)
        return q

    def accepts(self, seq: Iterable[Symbol]) -> bool:
        """True iff the state reached after reading `seq` is accepting."""
        finals = set(self.finals)
        return self.run(seq) in finals

    # ---------- NSP labels ----------
    def continuation_bits(self, q: State) -> Tuple[int, ...]:
        """
        Per-symbol bits in sigma order: 0 if the transition from `q` on that symbol
        enters the dead state, else 1. All ones when the DFA has no dead state.
        """
        if self.dead is None:
            return tuple(1 for _ in range(self.alphabet_size))
        dead = self.dead
        return tuple(0 if self.delta[q][a] == dead else 1 for a in range(self.alphabet_size))

    def prefix_membership_labels(self, seq: List[Symbol]) -> List[int]:
        """Membership bit (1 = accepting) for each prefix of `seq`, empty prefix first."""
        finals = set(self.finals)
        labels = []
        q = self.start
        labels.append(1 if q in finals else 0)
        for s in seq:
            q = self.step(q, s)
            labels.append(1 if q in finals else 0)
        return labels

    def prefix_continuation_labels(self, seq: List[Symbol]) -> List[Tuple[int, ...]]:
        """`continuation_bits` of the state after each prefix of `seq`, empty prefix first."""
        out: List[Tuple[int, ...]] = []
        q = self.start
        out.append(self.continuation_bits(q))
        for s in seq:
            q = self.step(q, s)
            out.append(self.continuation_bits(q))
        return out

    def nsp_matrix(self, seq: List[Symbol]) -> np.ndarray:
        """(N+1) x (|Sigma|+1) int array: continuations (sigma len-lex order) then membership."""
        cont = self.prefix_continuation_labels(seq)
        mem = self.prefix_membership_labels(seq)
        N, M = len(mem), self.alphabet_size
        arr = np.zeros((N, M + 1), dtype=int)
        for i in range(N):
            arr[i, :M] = cont[i]
            arr[i, M] = mem[i]
        return arr

    # ---------- comparison ----------
    def is_isomorphic_to(self, other: "DFA") -> bool:
        """
        Canonical minimal DFAs over the same sigma are isomorphic iff (delta, finals) match.
        """
        return (self.sigma == other.sigma
                and self.start == other.start
                and self.delta == other.delta
                and self.finals == other.finals)

    def language_equivalent(self, other: "DFA") -> Tuple[bool, Optional[List[Symbol]]]:
        """
        Check L(self) == L(other); if not, return a shortest counterexample.

        Requires identical sigma (mismatched sigmas are a pipeline bug here).

        Returns:
            (True, None) when the languages are equal, otherwise (False, word) where
            `word` is a shortest string accepted by exactly one of the two DFAs.

        Raises:
            ValueError: if the two sigmas differ.
        """
        if self.sigma != other.sigma:
            raise ValueError("language_equivalent requires identical sigma.")
        # Fast path: canonical minimal DFAs for equal languages are field-identical.
        if (self.start == other.start
                and self.delta == other.delta
                and self.finals == other.finals):
            return True, None

        # BFS over the product automaton; the first pair that disagrees on
        # acceptance is reached by a shortest distinguishing word.
        finals_a, finals_b = set(self.finals), set(other.finals)
        m = self.alphabet_size
        start = (self.start, other.start)
        dq = deque([start])
        parent: Dict[Tuple[int, int], Optional[Tuple[Tuple[int, int], Symbol]]] = {start: None}

        while dq:
            qa, qb = dq.popleft()
            if (qa in finals_a) != (qb in finals_b):
                path: List[Symbol] = []
                cur = (qa, qb)
                while parent[cur] is not None:
                    prev, sym = parent[cur]  # type: ignore
                    path.append(sym)
                    cur = prev
                path.reverse()
                return False, path
            for a in range(m):
                pair = (self.delta[qa][a], other.delta[qb][a])
                if pair not in parent:
                    parent[pair] = ((qa, qb), self.sigma[a])
                    dq.append(pair)

        return True, None  # no reachable product pair disagrees: languages are equal

    # ---------- visualization ----------
    @staticmethod
    def _dot_escape(text: str) -> str:
        """Escape `text` for use inside a double-quoted DOT label (backslashes, then quotes)."""
        return text.replace("\\", "\\\\").replace('"', '\\"')

    def _dot_source(self, *, merge_labels: bool, drop_dead_edges: bool, hide_dead: bool) -> str:
        """
        Shared DOT generator behind `to_dot` and `to_dot_simple`.

        drop_dead_edges: omit every transition whose target is the dead state.
        hide_dead: omit the dead node (and any transitions into or out of it).
        Accepting states get a double border; a visible dead state is shaded grey.
        """
        dead = self.dead
        finals = set(self.finals)

        def is_dead(state: int) -> bool:
            return dead is not None and state == dead

        lines = [
            "digraph DFA {",
            "  rankdir=LR;",
            "  node [shape=circle];",
            "  __start [shape=point];",
            f"  __start -> s{self.start};",
        ]

        # Nodes
        for q in range(self.n_states):
            if hide_dead and is_dead(q):
                continue
            attrs = []
            if q in finals:
                attrs.append("peripheries=2")
            if is_dead(q):
                attrs += ['style=filled', 'fillcolor="#eeeeee"']
            attr = (" [" + ",".join(attrs) + "]") if attrs else ""
            lines.append(f"  s{q}{attr};")

        # Edges
        merged: Dict[Tuple[int, int], List[str]] = defaultdict(list)
        for q in range(self.n_states):
            if hide_dead and is_dead(q):
                continue
            for a, sym in enumerate(self.sigma):
                r = self.delta[q][a]
                if (drop_dead_edges or hide_dead) and is_dead(r):
                    continue
                label = self._dot_escape(sym)
                if merge_labels:
                    merged[(q, r)].append(label)
                else:
                    lines.append(f'  s{q} -> s{r} [label="{label}"];')
        for (q, r), labels in merged.items():
            lines.append(f'  s{q} -> s{r} [label="{",".join(labels)}"];')

        lines.append("}")
        return "\n".join(lines)

    def to_dot(self, merge_labels: bool = True) -> str:
        """
        DOT source for the full automaton (dead state shaded, if any).
        With merge_labels=True, parallel edges are merged into one comma-separated label.
        """
        return self._dot_source(merge_labels=merge_labels, drop_dead_edges=False, hide_dead=False)

    @staticmethod
    def _render_dot(dot: str, filepath_noext: str, fmt: str) -> str:
        """
        Render `dot` with the `graphviz` package and return the produced path.
        If the package is not installed, write `<filepath_noext>.dot` instead and return that path.
        """
        try:
            import graphviz  # type: ignore
        except ImportError:
            path = f"{filepath_noext}.dot"
            with open(path, "w", encoding="utf-8") as f:
                f.write(dot)
            return path
        return graphviz.Source(dot).render(filepath_noext, format=fmt, cleanup=True)

    def render(self, filepath_noext: str, fmt: str = "pdf") -> str:
        """Render `to_dot()` to a file; falls back to a .dot file without graphviz."""
        return self._render_dot(self.to_dot(), filepath_noext, fmt)

    def to_dot_simple(self, merge_labels: bool = True, hide_dead: bool = True) -> str:
        """
        DOT for a simplified view that omits ALL transitions INTO the dead state.
        If hide_dead=True, the dead state node is not shown either.
        """
        return self._dot_source(merge_labels=merge_labels, drop_dead_edges=True, hide_dead=hide_dead)

    def render_simple(self, filepath_noext: str, fmt: str = "pdf",
                      *, merge_labels: bool = True, hide_dead: bool = True) -> str:
        """
        Render the simplified DOT (no edges to dead; optionally hide the dead node) to a file.
        Returns the path to the produced file. If graphviz is missing, writes a .dot fallback.
        """
        dot = self.to_dot_simple(merge_labels=merge_labels, hide_dead=hide_dead)
        return self._render_dot(dot, filepath_noext, fmt)

    # ---------- internals ----------
    def _trim(self) -> Tuple[int, Tuple[Tuple[int, ...], ...], Tuple[int, ...]]:
        """
        Keep only states reachable from start, renumbered in BFS (sigma) order;
        return (start, delta, finals).
        """
        n, m = len(self.delta), len(self.sigma)
        start = self.start
        finals = set(self.finals)
        delta = self.delta
        seen = [False] * n
        dq = deque([start])
        seen[start] = True
        order: List[int] = []
        while dq:
            q = dq.popleft()
            order.append(q)
            for a in range(m):
                r = delta[q][a]
                if not seen[r]:
                    seen[r] = True
                    dq.append(r)
        remap = {q: i for i, q in enumerate(order)}
        mdelta = tuple(tuple(remap[delta[q][a]] for a in range(m)) for q in order)
        mfinals = tuple(sorted(remap[q] for q in order if q in finals))
        mstart = remap[start]
        return mstart, mdelta, mfinals

    def _hopcroft(self, start: int, delta: Tuple[Tuple[int, ...], ...],
                  finals: Tuple[int, ...]) -> Tuple[int, Tuple[Tuple[int, ...], ...], Tuple[int, ...]]:
        """
        Hopcroft's DFA minimization on a trimmed DFA, followed by BFS renumbering
        of the blocks in sigma order. Returns (start, delta, finals) of the result.
        """
        n, m = len(delta), len(self.sigma)
        F = set(finals)
        NF = set(range(n)) - F
        P: List[Set[int]] = []
        if F:
            P.append(set(F))
        if NF:
            P.append(set(NF))

        preds = [defaultdict(set) for _ in range(m)]
        for p in range(n):
            for a in range(m):
                preds[a][delta[p][a]].add(p)

        # Initial splitter: the smaller of F / NF (the other is implied by its complement).
        A0 = frozenset(F if (F and (not NF or len(F) <= len(NF))) else NF)
        W = deque((a, A0) for a in range(m))

        block_of: Dict[int, int] = {}
        for i, B in enumerate(P):
            for q in B:
                block_of[q] = i

        # Splitters are frozen snapshots of state sets; a snapshot that has since
        # been split is still a union of current blocks, so refining against it is sound.
        while W:
            a, A = W.popleft()
            X = set().union(*(preds[a][r] for r in A)) if A else set()

            newP: List[Set[int]] = []
            new_block_of: Dict[int, int] = {}
            splits: List[Tuple[Set[int], Set[int]]] = []

            for Y in P:
                Y1 = Y & X
                Y2 = Y - X
                if Y1 and Y2:
                    newP.append(Y1)
                    i1 = len(newP) - 1
                    for q in Y1:
                        new_block_of[q] = i1
                    newP.append(Y2)
                    i2 = len(newP) - 1
                    for q in Y2:
                        new_block_of[q] = i2
                    splits.append((Y1, Y2))
                else:
                    newP.append(Y)
                    idx = len(newP) - 1
                    for q in Y:
                        new_block_of[q] = idx

            P = newP
            block_of = new_block_of

            for Y1, Y2 in splits:
                smaller = Y1 if len(Y1) <= len(Y2) else Y2
                for b in range(m):
                    W.append((b, frozenset(smaller)))

        # Build block automaton (any member of a stable block is a valid representative).
        Bn = len(P)
        Bdelta = [[0] * m for _ in range(Bn)]
        Bfinals: Set[int] = set()
        for b, B in enumerate(P):
            if any(q in F for q in B):
                Bfinals.add(b)
            q0 = next(iter(B))
            for a in range(m):
                Bdelta[b][a] = block_of[delta[q0][a]]
        Bstart = block_of[start]

        # BFS canonicalization by sigma order
        seen = [False] * Bn
        order: List[int] = []
        dq = deque([Bstart])
        seen[Bstart] = True
        while dq:
            b = dq.popleft()
            order.append(b)
            for a in range(m):
                nb = Bdelta[b][a]
                if not seen[nb]:
                    seen[nb] = True
                    dq.append(nb)
        idmap = {b: i for i, b in enumerate(order)}
        Cdelta = tuple(tuple(idmap[Bdelta[b][a]] for a in range(m)) for b in order)
        Cfinals = tuple(sorted(idmap[b] for b in order if b in Bfinals))
        Cstart = idmap[Bstart]
        return Cstart, Cdelta, Cfinals

    def _detect_dead(self, delta: Tuple[Tuple[int, ...], ...], finals: Tuple[int, ...]) -> Optional[int]:
        """Return the first non-accepting state that self-loops on every symbol, or None."""
        m = len(self.sigma)
        finals_set = set(finals)
        for i, row in enumerate(delta):
            if i not in finals_set and all(row[a] == i for a in range(m)):
                return i
        return None

    def _minimize_canonical(self) -> "DFA":
        """Trim → Hopcroft → BFS-canonicalize → annotate dead → return new DFA."""
        start1, delta1, finals1 = self._trim()
        start2, delta2, finals2 = self._hopcroft(start1, delta1, finals1)
        dead2 = self._detect_dead(delta2, finals2)
        # Build the result without re-running __post_init__ (it is already canonical).
        out = object.__new__(DFA)
        object.__setattr__(out, "sigma", self.sigma)
        object.__setattr__(out, "start", start2)
        object.__setattr__(out, "finals", finals2)
        object.__setattr__(out, "delta", delta2)
        object.__setattr__(out, "dead", dead2)
        object.__setattr__(out, "_sym2idx", {s: i for i, s in enumerate(self.sigma)})
        return out

    def find_accepting_suffix(self, prefix: List[Symbol]) -> List[Symbol]:
        """
        Return a shortest accepted string that starts with `prefix`.
        - If `prefix` has a symbol not in sigma: ValueError.
        - If `prefix` leads to the dead state: ValueError.
        - If `prefix` already ends in an accepting state: returns `prefix` (as a new list).
        """
        # Step through prefix (validate symbols)
        q = self.start
        try:
            for s in prefix:
                q = self.step(q, s)
        except KeyError as e:
            raise ValueError(f"symbol {e.args[0]!r} not in DFA alphabet") from None

        if self.dead is not None and q == self.dead:
            raise ValueError("prefix is in dead state")

        finals = set(self.finals)
        if q in finals:
            return list(prefix)

        # BFS from q: the first accepting state discovered is at minimal distance.
        m = self.alphabet_size
        visited = [False] * self.n_states
        parent: List[Optional[Tuple[int, int]]] = [None] * self.n_states  # (prev_state, symbol_idx)
        visited[q] = True
        dq = deque([q])
        target: Optional[int] = None
        while dq and target is None:
            s = dq.popleft()
            for a in range(m):
                t = self.delta[s][a]
                if visited[t]:
                    continue
                visited[t] = True
                parent[t] = (s, a)
                if t in finals:
                    target = t
                    break
                dq.append(t)

        # In a canonical minimal DFA, a state with no accepting continuation is the
        # unique dead state, which was rejected above.
        if target is None:
            raise RuntimeError("no accepting continuation found from a non-dead state (should not happen)")

        # Reconstruct suffix symbols
        suffix: List[Symbol] = []
        cur = target
        while cur != q:
            prev, aidx = parent[cur]  # type: ignore
            suffix.append(self.sigma[aidx])
            cur = prev
        suffix.reverse()

        return list(prefix) + suffix

    # ---------- vocabulary-superset comparison ----------
    def language_equivalent_allow_superset(self, other: "DFA") -> Tuple[bool, Optional[List[Symbol]]]:
        """
        Compare languages allowing one alphabet to be a strict superset of the other.
        If sigma differ and one is a superset of the other, lift the smaller DFA to the larger
        by routing every extra symbol to a dead sink (from every state), then compare.
        If neither is a superset, raise ValueError.
        """
        S = set(self.sigma)
        T = set(other.sigma)

        if self.sigma == other.sigma:
            return self.language_equivalent(other)

        if S.issubset(T):
            lifted_self = self._lift_to_superset(tuple(other.sigma))
            return lifted_self.language_equivalent(other)
        if T.issubset(S):
            lifted_other = other._lift_to_superset(tuple(self.sigma))
            return self.language_equivalent(lifted_other)
        raise ValueError(
            "Alphabets are incomparable (neither is a superset). "
            "Unify your vocabularies before comparison."
        )

    def _lift_to_superset(self, target_sigma: Tuple[Symbol, ...]) -> "DFA":
        """
        Return a DFA over `target_sigma` (a superset of `self.sigma`; it is sorted to
        length-lex here), where every symbol in target_sigma minus self.sigma leads to
        a dead sink from every state. The result is canonicalized/minimized by the
        constructor.

        Raises:
            ValueError: if target_sigma is not a superset of self.sigma.
        """
        target_sigma = _lenlex(target_sigma)
        if not set(self.sigma).issubset(target_sigma):
            raise ValueError("target_sigma must be a superset of self.sigma")

        old_idx = {s: i for i, s in enumerate(self.sigma)}
        tgt_idx = {s: i for i, s in enumerate(target_sigma)}

        n = self.n_states
        m_tgt = len(target_sigma)

        # Reuse the existing sink if there is one; otherwise append a new state n.
        had_dead = self.dead is not None
        dead_idx = self.dead if had_dead else n
        n_lift = n if had_dead else (n + 1)

        new_delta = [[0] * m_tgt for _ in range(n_lift)]

        # Original states: known symbols keep their target, extra symbols go to the sink.
        for q in range(n):
            for s in target_sigma:
                j = tgt_idx[s]
                if s in old_idx:
                    new_delta[q][j] = self.delta[q][old_idx[s]]
                else:
                    new_delta[q][j] = dead_idx

        # A freshly added sink self-loops on every symbol.
        if not had_dead:
            for j in range(m_tgt):
                new_delta[dead_idx][j] = dead_idx

        # Finals unchanged (the sink is non-final).
        return DFA(
            sigma=target_sigma,
            start=self.start,
            finals=tuple(self.finals),
            delta=tuple(tuple(row) for row in new_delta),
            dead=dead_idx,
        )


def dfa_from_edges(
    num_states: int,
    sigma: List[Symbol],
    start: int,
    finals: Iterable[int],
    edges: Dict[Tuple[int, Symbol], int],
    *,
    complete: bool = True,
    add_dead: bool = True,
) -> DFA:
    """
    Build a DFA from explicit edges, then canonicalize on construction.

    - sigma is sorted to length-lex and must not contain duplicate symbols.
    - If complete=True, any missing transition goes to a dead sink appended as
      state `num_states`. With add_dead=True that sink is appended even when no
      transition is missing (the DFA constructor then trims it as unreachable).
    - If complete=False, a missing transition is an error and add_dead is ignored.

    Args:
        num_states: number of user states, numbered 0 .. num_states-1.
        sigma: alphabet symbols.
        start: start state.
        finals: accepting states.
        edges: mapping (state, symbol) -> target state.

    Returns:
        The canonical minimized DFA.

    Raises:
        ValueError: duplicate symbols, an out-of-range start/final/edge state,
            an edge symbol not in sigma, or a missing transition with complete=False.
    """
    sigma_list = list(sigma)
    if len(set(sigma_list)) != len(sigma_list):
        # Duplicates would map one symbol to two columns and silently lose transitions.
        raise ValueError("sigma contains duplicate symbols")
    if not (0 <= start < num_states):
        raise ValueError("start out of range")
    finals_set = set(finals)
    for f in finals_set:
        if not (0 <= f < num_states):
            raise ValueError(f"final {f} out of range")

    sigma_sorted = _lenlex(sigma_list)
    s2i = {s: i for i, s in enumerate(sigma_sorted)}
    m = len(sigma_sorted)

    # -1 marks "no transition given"; it is filled or rejected below.
    delta = [[-1] * m for _ in range(num_states)]
    for (q, s), r in edges.items():
        if s not in s2i:
            raise ValueError(f"symbol {s!r} not in sigma")
        if not (0 <= q < num_states) or not (0 <= r < num_states):
            raise ValueError("edge has out-of-range state")
        delta[q][s2i[s]] = r

    missing = any(-1 in row for row in delta)
    dead: Optional[int] = None
    if complete:
        if missing or add_dead:
            sink = num_states
            delta = [[sink if r == -1 else r for r in row] for row in delta]
            delta.append([sink] * m)  # the sink self-loops on every symbol
            dead = sink
    elif missing:
        raise ValueError("incomplete transition and complete=False")

    return DFA(
        sigma=sigma_sorted,
        start=start,
        finals=tuple(sorted(finals_set)),
        delta=tuple(tuple(row) for row in delta),
        dead=dead,
    )
