from __future__ import annotations
from typing import List, Tuple, Iterable, Optional

from automata.Lstar_utils import (
    MQ, VanillaEQ, lenlex, as_word,
    find_closure_violation, split_counterexample, build_hypothesis_dfa
)
from automata.dfa import DFA
import pdb

class LStar:
    """
    Angluin's L* (access/test words) with MQ + EQ.

    - MQ: either a DFA teacher or a MembershipOracle-backed LM teacher.
    - EQ: 'VanillaEQ' over a held-out labeled set (returns first mismatch).
    - Sigma is fixed in length-lex order.
    """

    def __init__(self, sigma: Iterable[str], mq: MQ, eq: VanillaEQ):
        """
        Args:
            sigma: input alphabet; stored in length-lex order via ``lenlex``.
            mq: membership oracle, called as ``mq(word)``; its answer is
                compared against 0/1 hypothesis labels.
            eq: equivalence oracle, called as ``eq(hypothesis)``; returns
                ``None`` when it finds no counterexample.
        """
        self.sigma = lenlex(sigma)
        self.mq = mq
        self.eq = eq

    def _close(self, Q: List[Tuple[str, ...]], T: List[Tuple[str, ...]]) -> None:
        """
        Grow Q in place until ``find_closure_violation`` reports no violation.

        Raises:
            RuntimeError: if a reported violation is already an access word,
                i.e. the table cannot make progress (previously an infinite loop).
        """
        while True:
            q_new = find_closure_violation(Q, T, self.sigma, self.mq)
            if q_new is None:
                return
            if q_new in Q:
                # Same (Q, T) would yield the same violation again forever.
                raise RuntimeError(
                    f"find_closure_violation returned {q_new!r}, which is already "
                    "in Q; closure cannot make progress."
                )
            Q.append(q_new)

    def _refine(
        self,
        Q: List[Tuple[str, ...]],
        T: List[Tuple[str, ...]],
        w: Tuple[str, ...],
        A_hat: DFA,
    ) -> DFA:
        """
        Extend (Q, T) from counterexample ``w`` until the hypothesis labels
        ``w`` the same way the membership oracle does; return that hypothesis.

        Raises:
            RuntimeError: if ``split_counterexample`` adds nothing new to Q or T
                while ``w`` is still misclassified (previously an infinite loop).
        """
        # The teacher's label for w does not change between refinements; query once.
        y_true = self.mq(w)
        while (1 if A_hat.accepts(list(w)) else 0) != y_true:
            size_before = (len(Q), len(T))

            q_prime, t_prime = split_counterexample(Q, T, self.sigma, self.mq, w)
            if q_prime not in Q:
                Q.append(q_prime)
            if t_prime not in T:
                T.append(t_prime)

            if (len(Q), len(T)) == size_before:
                # An unchanged table rebuilds the same hypothesis, so w would
                # stay misclassified forever.
                raise RuntimeError(
                    f"split_counterexample returned ({q_prime!r}, {t_prime!r}), "
                    "already in (Q, T); refinement cannot make progress."
                )

            # Re-close before building the next hypothesis.
            self._close(Q, T)
            A_hat = build_hypothesis_dfa(Q, T, self.sigma, self.mq)
        return A_hat

    def learn(self) -> Tuple[DFA, int]:
        """
        Run L* until EQ returns no counterexample on its sample set.

        Returns:
            ``(A_hat, cx)``: the final hypothesis produced by
            ``build_hypothesis_dfa`` and the number of counterexamples that
            were processed to reach it.

        Raises:
            AssertionError: if EQ returns the empty word as a counterexample.
            RuntimeError: if the observation table stops growing while a
                counterexample is still misclassified.
        """
        Q: List[Tuple[str, ...]] = [()]   # access words, must include ε
        T: List[Tuple[str, ...]] = [()]   # test words, must include ε

        # Ensure closure before first hypothesis
        self._close(Q, T)
        A_hat = build_hypothesis_dfa(Q, T, self.sigma, self.mq)
        print("[L*] initial hypothesis constructed")

        # Main loop
        cx = 0
        while True:
            ce = self.eq(A_hat)  # None means "equivalent" on the provided set
            if ce is None:
                print(f"[L*] EQ found no counterexample after {cx} counterexamples.")
                return A_hat, cx

            w = as_word(ce)
            if len(w) == 0:
                raise AssertionError("EQ produced ε as a counterexample; this should be impossible.")

            cx += 1
            print(f"No. {cx} counterexample: " + " ".join(w))
            # Refine until w becomes consistent
            A_hat = self._refine(Q, T, w, A_hat)
