from __future__ import annotations
from typing import List, Tuple, Dict, Optional, Iterable
from collections import OrderedDict
import pdb
from automata.dfa import DFA, dfa_from_edges
import numpy as np

# A single alphabet symbol (token), represented as a string.
Symbol = str
# A word over the alphabet: an immutable, hashable tuple of symbols, so it can
# be used directly as a dict/cache key.
Word   = Tuple[Symbol, ...]


# -------------------------------
# Helpers
# -------------------------------

def lenlex(syms: Iterable[Symbol]) -> Tuple[Symbol, ...]:
    """Return the symbols as a tuple sorted in length-lexicographic order.

    Shorter symbols come first; symbols of equal length are ordered by
    ordinary string comparison.
    """
    def _rank(sym):
        # Primary key: length; secondary key: the symbol itself.
        return len(sym), sym

    ordered = list(syms)
    ordered.sort(key=_rank)
    return tuple(ordered)

def as_word(seq: Iterable[Symbol]) -> Word:
    """Freeze any iterable of symbols into a hashable word (a tuple)."""
    word = tuple(seq)
    return word

def cat(u: Word, v: Word) -> Word:
    """Concatenate two words: the result is u followed by v."""
    joined = u + v
    return joined




class RowCache:
    """
    Incrementally maintained row vectors for the access words in Q.

    rows_Q maps an access word q to a list of membership bits, one per test
    word in T (in T's order). When T grows, only the missing trailing columns
    are queried, via MQ(q·t) for each new t.
    """

    def __init__(self, sigma: Tuple[Symbol, ...], mq: "MQ"):
        self.sigma = sigma
        self.mq = mq
        # access word -> list of bits over the columns of T seen so far
        self.rows_Q: Dict[Word, List[int]] = {}

    def ensure_rows_for_Q(self, Q: List[Word], T: List[Word]) -> None:
        """Bring the cached row of every q in Q up to len(T) columns.

        A word seen for the first time gets its whole row computed; a word
        already cached is extended in place with the columns it is missing.
        """
        for q in Q:
            cached = self.rows_Q.get(q)
            if cached is None:
                # Never seen: query every column of the current T.
                self.rows_Q[q] = [self.mq(cat(q, t)) for t in T]
                continue
            # Already cached: only the columns past the cached length are new.
            for t in T[len(cached):]:
                cached.append(self.mq(cat(q, t)))

    def build_row_index_from_cache(self, Q: List[Word]) -> Dict[Tuple[int, ...], Word]:
        """
        Return {row_tuple -> q} for the words in Q, read from the cache.

        Raises RuntimeError when two words of Q share a row (the table is not
        separable). A word of Q absent from the cache raises KeyError.
        """
        rep_of_row: Dict[Tuple[int, ...], Word] = {}
        for q in Q:
            signature = tuple(self.rows_Q[q])
            if signature in rep_of_row:
                raise RuntimeError("Observation table not separable; duplicate rows in Q.")
            rep_of_row[signature] = q
        return rep_of_row




# -------------------------------
# Membership and Example/Equivalence Oracles
# -------------------------------


class MQ:
    """
    Membership-query oracle wrapper with a bounded LRU cache.

    Exactly one labelling backend must be supplied:
      - teacher_dfa:       MQ(x) is 1 if teacher_dfa.accepts(x), else 0
      - membership_oracle: MQ(x) is int(membership_oracle.label(x))

    When teacher_dfa is used, its sigma must equal the length-lex ordering of
    the `sigma` passed here (same symbols, same order).

    Caching:
      - Answers are memoised in an LRU keyed by tuple(tokens), holding at most
        `cache_size` entries.
      - On by default; turn off with use_cache=False or cache_size=0.
    """

    def __init__(
        self,
        sigma: Iterable[Symbol],
        *,
        teacher_dfa: Optional["DFA"] = None,
        membership_oracle: Optional[object] = None,
        cache_size: int = 100_000,
        use_cache: bool = True,
    ):
        has_dfa = teacher_dfa is not None
        has_oracle = membership_oracle is not None
        # Both given or both missing is an error.
        if has_dfa == has_oracle:
            raise ValueError("Provide exactly one of teacher_dfa or membership_oracle.")
        self.sigma = lenlex(sigma)

        # Select the backend that actually answers a query.
        if has_dfa:
            if tuple(teacher_dfa.sigma) != self.sigma:
                raise ValueError("Teacher DFA sigma must match MQ.sigma (same symbols, same order).")

            def _ask_dfa(seq: Iterable[Symbol]) -> int:
                accepted = teacher_dfa.accepts(list(seq))
                return 1 if accepted else 0

            self._label = _ask_dfa
        else:
            if not hasattr(membership_oracle, "label"):
                raise TypeError("membership_oracle must expose a .label(tokens)->int method.")

            def _ask_oracle(seq: Iterable[Symbol]) -> int:
                return int(membership_oracle.label(list(seq)))  # type: ignore[attr-defined]

            self._label = _ask_oracle

        # LRU bookkeeping
        self._use_cache: bool = True if (use_cache and cache_size > 0) else False
        self._cap: int = max(0, int(cache_size))
        self._cache: "OrderedDict[Tuple[Symbol, ...], int]" = OrderedDict()
        self._hits: int = 0
        self._misses: int = 0
        self._evictions: int = 0

    def __call__(self, tokens: Iterable[Symbol]) -> int:
        """
        Return the membership label of `tokens`.

        With caching disabled the backend is asked directly. Otherwise the
        answer is served from the LRU when present (and promoted to most
        recently used), or computed, stored, and the oldest entry evicted if
        the cache exceeds its capacity.
        """
        if not self._use_cache:
            return self._label(tokens)

        key = tuple(tokens)  # hashable cache key
        cached = self._cache.get(key)
        if cached is None:
            # Miss: ask the backend (it converts the tuple to a list itself).
            answer = self._label(key)
            self._cache[key] = answer
            self._misses += 1
            if len(self._cache) > self._cap:
                # Drop the least-recently-used entry (front of the ordering).
                self._cache.popitem(last=False)
                self._evictions += 1
            return answer

        # Hit: mark as most recently used.
        self._cache.move_to_end(key, last=True)
        self._hits += 1
        return cached

    # ---------- cache utilities ----------
    def clear_cache(self) -> None:
        """Empty the cache and reset the hit/miss/eviction counters."""
        self._cache.clear()
        self._hits = 0
        self._misses = 0
        self._evictions = 0

    def cache_info(self) -> Dict[str, int]:
        """
        Report cache diagnostics:
          - size: number of entries currently stored
          - capacity: maximum entries kept before eviction
          - hits, misses, evictions: running counters
        """
        return dict(
            size=len(self._cache),
            capacity=self._cap,
            hits=self._hits,
            misses=self._misses,
            evictions=self._evictions,
        )





class VanillaEQ:
    """
    Equivalence oracle backed by a fixed list of labeled examples.

    Calling the oracle with a hypothesis DFA returns the first stored example
    the hypothesis mislabels, or None when it agrees with every example.
    """

    def __init__(self, examples: List[Tuple[List[Symbol], int]]):
        # Normalize each example to (list of symbols, int label).
        normalized = []
        for tokens, label in examples:
            normalized.append((list(tokens), int(label)))
        self.examples = normalized

    def __call__(self, dfa: DFA) -> Optional[List[Symbol]]:
        """Return the first example word whose label differs from dfa's verdict."""
        for tokens, label in self.examples:
            predicted = int(bool(dfa.accepts(tokens)))
            if predicted != label:
                return tokens
        return None



# -------------------------------
# NSP Equivalence Oracle
# -------------------------------

class NspEQ:
    """
    Equivalence oracle backed by NSP-labeled examples.

    Each example is a pair (x, labels):
      - x is a Σ-only list[str]
      - labels is an (N+1) x (|Σ|+1) numpy array whose columns are
            [continuation bits in length-lex order of Σ, membership bit (last col)]

    Calling the oracle with a hypothesis A_hat returns:
      - (x, labels) for the first example whose labels differ from
        A_hat.nsp_matrix(x) (an NSP counterexample), or
      - None when every example matches.
    """

    def __init__(self, examples: List[Tuple[List[Symbol], np.ndarray]]):
        # Store words as lists and labels as integer arrays.
        stored = []
        for tokens, labels in examples:
            stored.append((list(tokens), np.asarray(labels, dtype=int)))
        self.examples = stored

    def __call__(self, dfa: DFA) -> Optional[Tuple[List[Symbol], np.ndarray]]:
        """Return the first disagreeing (x, labels) pair, or None.

        Raises ValueError when the hypothesis' matrix for an example has a
        different shape from the stored labels.
        """
        for tokens, expected in self.examples:
            predicted = dfa.nsp_matrix(tokens)
            if predicted.shape != expected.shape:
                raise ValueError(
                    f"NspEQ: shape mismatch for example of length {len(tokens)}: "
                    f"pred {predicted.shape} vs given {expected.shape}"
                )
            if np.array_equal(predicted, expected):
                continue
            return tokens, expected
        return None


# -------------------------------
# Access/Test-word table utilities
# -------------------------------

def compute_row(w: Word, T: List[Word], mq: MQ) -> Tuple[int, ...]:
    """Return row(w): the tuple of MQ(w·t) for each test word t, in T's order."""
    bits = [mq(cat(w, suffix)) for suffix in T]
    return tuple(bits)

def build_row_index(Q: List[Word], T: List[Word], mq: MQ) -> Dict[Tuple[int, ...], Word]:
    """Return {row -> access word} over Q, computing each row with membership queries.

    Raises RuntimeError as soon as two access words turn out to have the same
    row, i.e. the table is not separable.
    """
    rep_of_row: Dict[Tuple[int, ...], Word] = {}
    for access in Q:
        signature = compute_row(access, T, mq)
        if signature in rep_of_row:
            raise RuntimeError("Observation table not separable; two access words have identical rows.")
        rep_of_row[signature] = access
    return rep_of_row

def rows_equal(u: Word, v: Word, T: List[Word], mq: MQ) -> bool:
    """Return True when u and v are T-equivalent, i.e. row(u) == row(v)."""
    row_u = compute_row(u, T, mq)
    row_v = compute_row(v, T, mq)
    return row_u == row_v




def find_closure_violation(
    Q: List[Word],
    T: List[Word],
    sigma: Tuple[Symbol, ...],
    mq: "MQ",
    rows: Optional[RowCache] = None,
) -> Optional[Word]:
    """
    Look for a one-symbol extension q·a (q in Q, a in sigma) whose row matches
    no row of Q. Return the first such word, scanning Q and then sigma in
    order, or None when the table (Q,T) is closed.

    The rows of Q come from `rows` when a RowCache is supplied (it is brought
    up to date first); otherwise they are recomputed with membership queries.
    Rows of the extensions q·a are always computed on the fly. Either way a
    RuntimeError is raised if Q itself contains duplicate rows.
    """
    # Index of the rows already represented in Q.
    if rows is not None:
        rows.ensure_rows_for_Q(Q, T)
        known_rows = rows.build_row_index_from_cache(Q)  # {row(q): q}
    else:
        known_rows = build_row_index(Q, T, mq)

    for q in Q:
        for a in sigma:
            extended = cat(q, (a,))
            if compute_row(extended, T, mq) not in known_rows:
                return extended
    return None


# -------------------------------
# Build hypothesis A_(Q,T)
# -------------------------------


def build_hypothesis_dfa(
    Q: List[Word],
    T: List[Word],
    sigma: Tuple[Symbol, ...],
    mq: "MQ",
    rows: Optional[RowCache] = None,
) -> DFA:
    """
    Build the hypothesis automaton A_(Q,T) from the observation table.

    States are the access words of Q, numbered by their position in Q. The
    transition from q on a goes to the word of Q whose row equals row(q·a).
    The start state is the empty word (which must be in Q), and q is final
    when MQ(q) == 1.

    With a RowCache, the rows of Q are taken from it (after bringing it up to
    date); otherwise they are recomputed. Raises RuntimeError when Q has
    duplicate rows or some q·a has no matching row in Q.
    """
    if rows is not None:
        # Reuse the cached rows of Q instead of re-querying them.
        rows.ensure_rows_for_Q(Q, T)
        rep_of_row = rows.build_row_index_from_cache(Q)
    else:
        rep_of_row = build_row_index(Q, T, mq)

    state_of = {word: number for number, word in enumerate(Q)}

    transitions: Dict[Tuple[int, Symbol], int] = {}
    for q in Q:
        source = state_of[q]
        for a in sigma:
            successor_row = compute_row(cat(q, (a,)), T, mq)
            if successor_row not in rep_of_row:
                raise RuntimeError("Observation table not closed; missing representative for q·a.")
            transitions[(source, a)] = state_of[rep_of_row[successor_row]]

    initial = state_of[()]
    accepting = []
    for q in Q:
        if mq(q) == 1:
            accepting.append(state_of[q])

    return dfa_from_edges(
        len(Q), list(sigma), initial, accepting, transitions, complete=True, add_dead=False
    )




# -------------------------------
# Lemma (Counterexample split): add q' and t'
# -------------------------------

def split_counterexample(
    Q: List[Word],
    T: List[Word],
    sigma: Tuple[Symbol, ...],
    mq: MQ,
    w: Word,
) -> Tuple[Word, Word]:
    """
    Split a counterexample w into a new access word q' and a new test word t'.

    Representatives along w are obtained by *stepping from Q* (0-based):
      reps[0]   := rep(row(ε))
      reps[i+1] := rep(row(reps[i] · w[i]))
    A position i is "correct" when MQ(reps[i] · w[i:]) == MQ(w). The function
    finds the first i where position i is correct and position i+1 is not,
    and returns
      q' = reps[i] · w[i],   t' = w[i+1:].

    Parameters:
      Q, T   -- the observation table (access words, test words).
      sigma  -- unused here; kept so the signature matches the other table
                utilities.
      mq     -- membership oracle, called as mq(word) -> int.
      w      -- the counterexample. Any sequence of symbols is accepted and
                normalized to a tuple, so list-valued counterexamples work.

    Returns (q', t') as tuples; ((), ()) for an empty w.

    Raises:
      RuntimeError   -- Q has duplicate rows, or a row met along the trace has
                        no representative in Q (table not closed).
      AssertionError -- no split position exists, which means the equivalence
                        and membership oracles disagree about w.
    """
    # Tuple concatenation below requires a tuple; lists would raise TypeError.
    w = tuple(w)
    n = len(w)
    if n == 0:
        return (), ()

    # Row -> representative map, built once (raises if Q is not separable).
    row_idx = build_row_index(Q, T, mq)

    # reps[0]: representative of the empty word's row.
    r0 = compute_row((), T, mq)
    if r0 not in row_idx:
        raise RuntimeError("Table not closed at ε (unexpected).")
    current = row_idx[r0]
    reps: List[Word] = [current]

    # reps[i+1]: follow the table transition on w[i].
    for sym in w:
        r_next = compute_row(cat(current, (sym,)), T, mq)
        if r_next not in row_idx:
            # Some q·a with q∈Q, a∈Σ has no representative: table not closed.
            raise RuntimeError("Observation table not closed for some q·a during trace.")
        current = row_idx[r_next]
        reps.append(current)

    # Scan for the first correct -> incorrect boundary. The verdict for
    # position i+1 is carried over to the next iteration, so each boundary
    # word is queried once rather than twice.
    label_w = mq(w)
    here_ok = mq(cat(reps[0], w)) == label_w
    for i in range(n):
        rest = w[i + 1:]
        next_ok = mq(cat(reps[i + 1], rest)) == label_w
        if here_ok and not next_ok:
            return cat(reps[i], (w[i],)), rest
        here_ok = next_ok

    # Reaching here means EQ/MQ are inconsistent (w was not a true counterexample).
    raise AssertionError("No split index found — check EQ/MQ consistency.")






def _assert_rows_synced(rows: "RowCache", Q: List[Word], T: List[Word]) -> None:
    """Check that each q in Q has a cached row with exactly len(T) entries.

    Raises AssertionError naming the first q whose row is missing or has the
    wrong length.
    """
    expected = len(T)
    for q in Q:
        cached = rows.rows_Q.get(q)
        if cached is None:
            raise AssertionError(f"RowCache missing row for q={q}")
        actual = len(cached)
        if actual == expected:
            continue
        raise AssertionError(f"RowCache row length {actual} != len(T)={expected} for q={q}")


def assert_table_invariants(
    Q: List[Word],
    T: List[Word],
    sigma: Tuple[Symbol, ...],
    mq: "MQ",
    rows: Optional["RowCache"] = None,
) -> None:
    """
    Verify that the table (Q,T) is both separable and closed:
      - separable: no two words of Q share a row (RuntimeError otherwise),
      - closed:    every q·a has a row that already occurs in Q
                   (AssertionError otherwise).

    When a RowCache is given it is brought up to date, checked for being in
    sync with (Q,T), and used for the rows of Q; without one, rows are
    recomputed through membership queries.
    """
    # Separability: building the row index raises on duplicate rows.
    if rows is not None:
        rows.ensure_rows_for_Q(Q, T)
        _assert_rows_synced(rows, Q, T)
        rows.build_row_index_from_cache(Q)
    else:
        build_row_index(Q, T, mq)

    # Closure: there must be no extension q·a lacking a representative.
    violation = find_closure_violation(Q, T, sigma, mq, rows=rows)
    if violation is not None:
        raise AssertionError("(Q,T) not closed.")