from pipelines.prompta.learner.java_utils.dfa import ExtensibleLStarDFA
from prompta.utils.java_libs import Collections, Word, ObservationTableCEXHandlers, ClosingStrategies
from pipelines.prompta.oracle import BaseOracle
from pipelines.prompta.utils import save_dfa
from .base_learner import BaseLearner


class LStarLearner(BaseLearner):
    """Learner that drives an ``ExtensibleLStarDFA`` against an oracle.

    The wrapped learner is configured with the ``CLASSIC_LSTAR``
    counterexample handler and the ``CLOSE_FIRST`` closing strategy.
    """

    ID = "LStar"

    def __init__(self, oracle, exp_dir: str):
        """Initialise the base learner, then build the wrapped L* learner.

        Args:
            oracle: Passed to ``BaseLearner.__init__`` and, unchanged, to the
                ``ExtensibleLStarDFA`` constructor.
            exp_dir: Forwarded to ``BaseLearner.__init__``.
        """
        super().__init__(oracle, exp_dir)

        # NOTE(review): ``self.oracle`` is presumably set by
        # ``BaseLearner.__init__`` -- confirm against the base class.
        alphabet = self.oracle.jalphabet
        # Singleton list holding the empty word; presumably the initial
        # suffix set expected by ExtensibleLStarDFA -- TODO confirm.
        initial_words = Collections.singletonList(Word.epsilon())

        self.learner = ExtensibleLStarDFA(
            alphabet,
            oracle,
            initial_words,
            ObservationTableCEXHandlers.CLASSIC_LSTAR,
            ClosingStrategies.CLOSE_FIRST,
        )

    def learn(self):
        """Refine the hypothesis until the oracle returns no counterexample.

        Each round fetches the current hypothesis from the wrapped learner
        and submits it to ``self.oracle.check_conjecture``. A non-``None``
        result is fed back through ``refineHypothesis``; ``None`` ends the
        loop. The final hypothesis is then written with ``save_dfa`` to the
        path given by ``self.get_dfa_save_path()``.

        Returns:
            The hypothesis model reported by the wrapped learner after the
            loop has finished.
        """
        self.learner.startLearning()

        while True:
            conjecture = self.learner.getHypothesisModel()
            counterexample = self.oracle.check_conjecture(conjecture, 'DefaultQuery')
            if counterexample is not None:
                self.learner.refineHypothesis(counterexample)
            else:
                # The oracle found nothing to object to: learning is done.
                break

        save_path = self.get_dfa_save_path()
        final_model = self.learner.getHypothesisModel()
        save_dfa(save_path, final_model, self.oracle.jalphabet)
        return self.learner.getHypothesisModel()

