
# automata/lstar_nsp.py

from __future__ import annotations
from typing import List, Tuple, Iterable, Optional
import numpy as np

from automata.dfa import DFA
from automata.Lstar_utils import (
    MQ, lenlex, as_word,
    find_closure_violation,
    split_counterexample,
    build_hypothesis_dfa,
    RowCache, assert_table_invariants,
)




# A single alphabet symbol (an element of Σ).
Symbol = str
# A word over Σ, stored as a tuple of symbols; the empty tuple () is ε.
Word   = Tuple[Symbol, ...]

# -------------------------------
# NSP Example/Equivalence Oracle
# -------------------------------

class NspEQ:
    """
    Example-backed NSP equivalence oracle over (x, labels) pairs.

    The oracle stores a fixed collection of examples. Calling it with a
    hypothesis DFA returns the first stored pair whose labels differ from
    the DFA's own NSP matrix, or None when every example agrees.

    Expected label layout for a Σ-only word x of length N:
      - shape: (N+1, |Σ|+1)
      - columns: [continuations in length-lex order of Σ, membership bit (last col)]

    The constructor receives the expected Σ (to pin the column order); each
    call additionally requires the hypothesis' sigma to equal that Σ.
    """

    def __init__(self, examples: List[Tuple[List[Symbol], np.ndarray]], sigma: Iterable[Symbol]):
        # Normalise every example: the word becomes a list, the labels an int array.
        stored = []
        for word, label in examples:
            stored.append((list(word), np.asarray(label, dtype=int)))
        self.examples = stored
        self.sigma: Tuple[Symbol, ...] = lenlex(sigma)

    def __call__(self, dfa: DFA) -> Optional[Tuple[List[Symbol], np.ndarray]]:
        """
        Check `dfa` against the stored examples, in storage order.

        Returns the first (x, labels) pair the DFA disagrees with, else None.
        Raises ValueError when the DFA's sigma differs from the oracle's, or
        when a visited example's label array has the wrong shape.
        """
        # Column order is only meaningful when both sides use the same Σ.
        hyp_sigma = tuple(dfa.sigma)
        if hyp_sigma != self.sigma:
            raise ValueError(
                "NspEQ: DFA sigma does not match EQ sigma. "
                f"DFA: {dfa.sigma} vs EQ: {self.sigma}"
            )

        n_symbols = len(self.sigma)
        for word, target in self.examples:
            expected_shape = (len(word) + 1, n_symbols + 1)
            if target.shape != expected_shape:
                raise ValueError(
                    f"NspEQ: label shape {target.shape} incompatible with (N+1, |Σ|+1) "
                    f"for N={len(word)}, |Σ|={n_symbols}"
                )
            if np.array_equal(dfa.nsp_matrix(word), target):
                continue
            return word, target  # NSP counterexample
        return None


# -------------------------------
# Internal helpers for NSP mismatch handling
# -------------------------------

def _nsp_first_mismatch(dfa: DFA, x: List[Symbol], labels: np.ndarray) -> Tuple[str, int, Optional[int]]:
    """
    Locate the earliest prefix index at which the DFA's NSP matrix for `x`
    differs from `labels`.

    Within a row the membership bit (last column) is inspected first, then
    the continuation columns from left to right.

    Returns (kind, r, j) where:
      - kind is "membership" or "continuation"
      - r is the prefix index (row) of the first disagreement
      - j is the continuation column when kind == "continuation", else None

    Raises ValueError if the two matrices differ in shape, and
    AssertionError if they agree everywhere.
    """
    predicted = dfa.nsp_matrix(x)
    if predicted.shape != labels.shape:
        raise ValueError("nsp_first_mismatch: shape mismatch")

    _, width = labels.shape
    member_col = width - 1  # membership bit lives in the last column

    for row_idx, (hyp_row, tgt_row) in enumerate(zip(predicted, labels)):
        # Membership disagreement takes priority inside a row.
        if hyp_row[member_col] != tgt_row[member_col]:
            return ("membership", row_idx, None)
        # Otherwise report the left-most disagreeing continuation column.
        bad_col = next(
            (col for col in range(member_col) if hyp_row[col] != tgt_row[col]),
            None,
        )
        if bad_col is not None:
            return ("continuation", row_idx, bad_col)

    raise AssertionError("Called on a supposed NSP counterexample but no mismatch found.")


# -------------------------------
# L*‑NSP learner
# -------------------------------

class LStarNSP:
    """
    L*-NSP learner (observation table of access words and test words) with:
      - MQ: DFA teacher or LM MembershipOracle (same MQ as L*)
      - EQ: NspEQ over (x, NSP labels) pairs
      - prefix_sampler: required for case B2 (e.g., model_src.oracles.LMPrefixEQ);
        it must expose ``sample(prefix, max_len=..., max_steps=..., verbose=...)``

    Σ is fixed in length-lex order and must match EQ.sigma.
    """

    def __init__(
        self,
        sigma: Iterable[Symbol],
        mq: MQ,
        eq: NspEQ,
        *,
        prefix_sampler: Optional[object] = None,
        gen_max_len: int = 100,
        prefix_max_steps: int = 256,
    ):
        """
        Args:
            sigma: alphabet; normalised through ``lenlex``.
            mq: membership oracle handed to the table utilities.
            eq: NSP equivalence oracle; its ``sigma`` must equal the learner's.
            prefix_sampler: object used in case B2 to obtain a word with a
                given prefix; may be None if B2 never occurs.
            gen_max_len: forwarded to the sampler as ``max_len``.
            prefix_max_steps: forwarded to the sampler as ``max_steps``.

        Raises:
            ValueError: if the learner's Σ differs from ``eq.sigma``.
        """
        self.sigma = lenlex(sigma)
        self.mq = mq
        self.eq = eq
        self.prefix_sampler = prefix_sampler
        self.gen_max_len = gen_max_len
        self.prefix_max_steps = prefix_max_steps

        # Hard check that learner Σ == EQ Σ (both sides normalised to tuples
        # so that an equal list/tuple pair is not rejected spuriously).
        if tuple(self.sigma) != tuple(eq.sigma):
            raise ValueError(f"LStarNSP: learner sigma {self.sigma} != EQ sigma {eq.sigma}")

    def _consistent(self, dfa: DFA, x: List[Symbol], labels: np.ndarray) -> bool:
        """Return True when the DFA's NSP matrix for `x` equals `labels`."""
        return np.array_equal(dfa.nsp_matrix(x), labels)

    def _close_table(self, Q: List[Word], T: List[Word], rows: RowCache, verbose: bool = False) -> None:
        """Append access words to `Q` (in place) until no closure violation remains."""
        while True:
            q_new = find_closure_violation(Q, T, self.sigma, self.mq, rows=rows)
            if q_new is None:
                return
            if q_new not in Q:
                Q.append(q_new)
                if verbose:
                    print(f"[L*NSP] Found new access word no. {len(Q)}")

    def _membership_counterexample(
        self, A_hat: DFA, x: List[Symbol], labels: np.ndarray, verbose: bool
    ) -> Tuple[str, object]:
        """
        Convert the first NSP mismatch on `x` into a membership counterexample.

        Returns (case, w_prime) with case in {"membership", "b1", "b2"}.

        Raises:
            RuntimeError: case B2 is hit but no prefix_sampler was supplied.
            AssertionError: the continuation bits do not form a (0,1)/(1,0) pair.
        """
        kind, r, j = _nsp_first_mismatch(A_hat, x, labels)

        if kind == "membership":
            if verbose:
                print("Found membership mismatch")
            return "membership", x[:r]

        # Continuation case: extend the prefix with the disputed symbol.
        assert j is not None
        prefix = x[:r] + [A_hat.sigma[j]]
        tgt_bit = int(labels[r, j])
        hyp_bit = int(A_hat.nsp_matrix(x)[r, j])

        # NOTE(review): both branches below assume the callee returns a word;
        # confirm how find_accepting_suffix / sample report a failed search.
        if tgt_bit == 0 and hyp_bit == 1:
            # B1: hypothesis admits, teacher forbids -> take a word accepted by Â.
            if verbose:
                print("Found continuation mismatch, case B1")
            return "b1", A_hat.find_accepting_suffix(prefix)

        if tgt_bit == 1 and hyp_bit == 0:
            # B2: teacher admits, hypothesis forbids -> ask the prefix sampler.
            if verbose:
                print("Found continuation mismatch, case B2")
            if self.prefix_sampler is None:
                raise RuntimeError(
                    "B2 encountered but no prefix_sampler provided. "
                    "Pass model_src.oracles.LMPrefixEQ (or equivalent)."
                )
            w_prime = self.prefix_sampler.sample(
                prefix,
                max_len=self.gen_max_len,
                max_steps=self.prefix_max_steps,
                verbose=verbose,
            )
            return "b2", w_prime

        raise AssertionError("Continuation mismatch must be one of (tgt,hyp)=(0,1) or (1,0).")

    @staticmethod
    def _summarize(cx: int, mem_d: int, b1_d: int, b2_d: int, verbose: bool) -> dict:
        """Build the run-statistics dict; print the summary only when `verbose`."""
        total_d = mem_d + b1_d + b2_d
        if total_d != 0:
            mem_frac = mem_d / total_d
            b1_frac = b1_d / total_d
            b2_frac = b2_d / total_d
        else:
            # No refinements at all: report zero fractions instead of dividing by 0.
            mem_frac = b1_frac = b2_frac = 0.0

        if verbose:
            print(f"[L*NSP]Terminating: No counterexamples found after {cx} counterexamples.")
            # Counts (not fractions) so the printed sum actually adds up.
            print(f"  Total refinements: {total_d} = {mem_d} (membership) + {b1_d} (B1) + {b2_d} (B2)")

        return {
            "counterexamples": cx,
            "total_refinements": total_d,
            "membership_refinements": mem_d,
            "b1_refinements": b1_d,
            "b2_refinements": b2_d,
            "b1_fraction": b1_frac,
            "b2_fraction": b2_frac,
            "membership_fraction": mem_frac,
        }

    def learn(self, verbose: bool = True) -> Tuple[DFA, dict]:
        """
        Run the learner until the equivalence oracle returns no counterexample.

        Args:
            verbose: when True, print progress and the final summary; when
                False, nothing is printed by this method.

        Returns:
            (hypothesis, meta): the final hypothesis DFA and a dict with the
            number of counterexamples and the per-case refinement counts and
            fractions (membership / B1 / B2).

        Raises:
            RuntimeError: case B2 is encountered without a prefix_sampler.
        """
        # Observation table: access words Q and test words T, both seeded with ε.
        Q: List[Word] = [()]
        T: List[Word] = [()]
        rows = RowCache(self.sigma, self.mq)

        counts = {"membership": 0, "b1": 0, "b2": 0}

        # Close the table before building the first hypothesis.
        self._close_table(Q, T, rows)
        assert_table_invariants(Q, T, self.sigma, self.mq, rows=rows)
        A_hat = build_hypothesis_dfa(Q, T, self.sigma, self.mq, rows=rows)
        if verbose:
            print("[L*NSP] initial hypothesis constructed")

        cx = 0
        while True:
            ce = self.eq(A_hat)  # -> (x, NSP labels) or None
            if ce is None:
                meta = self._summarize(
                    cx, counts["membership"], counts["b1"], counts["b2"], verbose
                )
                return A_hat, meta

            x, labels = ce
            x = list(x)  # Σ-only

            cx += 1
            if verbose:
                print(f"No. {cx} counterexample: " + " ".join(x))

            # Refine until the original NSP counterexample is consistent.
            while not self._consistent(A_hat, x, labels):
                case, w_prime = self._membership_counterexample(A_hat, x, labels, verbose)
                counts[case] += 1

                # Standard L* split on the membership counterexample w'.
                q_prime, t_prime = split_counterexample(Q, T, self.sigma, self.mq, as_word(w_prime))
                if q_prime not in Q:
                    Q.append(q_prime)
                if t_prime not in T:
                    T.append(t_prime)

                # Re-close and rebuild.
                if verbose:
                    print("[L*NSP] Closure process...")
                self._close_table(Q, T, rows, verbose=verbose)
                if verbose:
                    print("[L*NSP] closure process finished")

                assert_table_invariants(Q, T, self.sigma, self.mq, rows=rows)
                A_hat = build_hypothesis_dfa(Q, T, self.sigma, self.mq, rows=rows)

                if verbose:
                    print(f"[L*NSP] refined hypothesis constructed with {len(Q)} access words and {len(T)} test words")
