from typing import Any, List, Dict
from collections import defaultdict
from pipelines.prompta.utils import tuple2word, word2tuple
from prompta.core.language import BaseLanguage
from .base_oracle import BaseOracle


class ProbabilisticAbstractOracle(BaseOracle):
    """Abstract membership oracle that records its queries and its mistakes.

    Subclasses supply the actual answer by implementing
    ``_get_membership_query_result``. This base class memoises answers in
    ``self.query_history``, remembers which answers disagree with
    ``self.language.in_language`` and groups the queries made between
    conjecture checks under the counterexample that preceded them.

    Attributes set by this class:
        prob: stored as given; not read in this class (presumably the error
            probability used by subclasses -- confirm against them).
        ce2queries: maps ``word2tuple(ce.getInput())`` to the list of
            ``(input_str, result)`` pairs buffered after that counterexample.
        query_buffer: ``(input_str, result)`` pairs since the last flush.
        last_ce: value returned by the most recent ``check_conjecture``.
        mistake_cache: maps a wrongly answered input to ``not res``.
        query_history_backup: snapshot taken by ``update_query_history``.
        need_recovery: True while a substituted history is installed.
    """

    def __init__(self, language: BaseLanguage, prob: float = 0.1, *args: Any, **kwargs: Any) -> None:
        """Initialise the base oracle, store ``prob`` and clear bookkeeping.

        Args:
            language: forwarded to ``BaseOracle.__init__``.
            prob: stored on ``self.prob``.
            *args: forwarded to ``BaseOracle.__init__``.
            **kwargs: forwarded to ``BaseOracle.__init__``.
        """
        super().__init__(language, *args, **kwargs)
        self.prob = prob
        self._reset_data()

    def reset(self, language: BaseLanguage, exp_dir: str, alphabet=None, load_history=False) -> None:
        """Reset the base oracle, then clear this class's bookkeeping.

        All arguments are passed unchanged to ``BaseOracle.reset``.
        """
        super().reset(language, exp_dir, alphabet, load_history)
        self._reset_data()

    def _reset_data(self) -> None:
        """Put every attribute owned by this class back to its empty state."""
        self.ce2queries = defaultdict(list)
        self.query_buffer = []
        self.last_ce = None
        self.mistake_cache = {}
        # Created here so the attribute always exists and a stale snapshot
        # cannot survive a reset.
        self.query_history_backup = None
        self.need_recovery = False

    def __call__(self, input_str: str, *args: Any, **kwargs: Any) -> Any:
        """Answer a membership query, reusing a remembered answer if present.

        Args:
            input_str: the queried word; used as the key of
                ``self.query_history``.
            *args: forwarded to ``_get_membership_query_result``.
            **kwargs: forwarded to ``_get_membership_query_result``.

        Returns:
            The remembered answer, or the freshly computed one.
        """
        if input_str in self.query_history:
            res = self.query_history[input_str]
        else:
            res = self._get_membership_query_result(input_str, *args, **kwargs)
            self.query_history[input_str] = res
            if res != self.language.in_language(input_str):
                # Keep the negation of the wrong answer so the mistake can be
                # identified later.
                self.mistake_cache[input_str] = not res
        # Cached and fresh answers are both buffered for the current round.
        self.query_buffer.append((input_str, res))
        return res

    def _get_membership_query_result(self, input_str: str, *args: Any, **kwargs: Any) -> Any:
        """Compute the answer for an input not yet in ``query_history``.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def check_conjecture(self, aut, _type=str):
        """Flush the query buffer, then ask the language for a counterexample.

        Args:
            aut: the conjectured automaton, passed to
                ``self.language.counterexample``.
            _type: passed through to ``self.language.counterexample``.

        Returns:
            Whatever ``self.language.counterexample(aut, _type)`` returns; the
            same value is remembered in ``self.last_ce``.
        """
        # Flush first: the buffered queries belong to the previous
        # counterexample, not to the one about to be computed.
        self.add_buffer()
        ce = self.language.counterexample(aut, _type)
        self.last_ce = ce
        return ce

    def update_query_history(self, new_query_history: Dict) -> None:
        """Install ``new_query_history`` in place of the current history.

        The history in force before the first substitution is snapshotted so
        ``recover_query_history`` can restore it. Calling this again while a
        substitution is still pending keeps that first snapshot instead of
        overwriting it with the already-substituted history.

        Args:
            new_query_history: mapping installed as ``self.query_history``.
                It is not copied, so later queries write into it.
        """
        if not self.need_recovery:
            self.query_history_backup = dict(self.query_history)
        self.query_history = new_query_history
        self.need_recovery = True

    def recover_query_history(self) -> None:
        """Restore the snapshot taken by ``update_query_history``, if any.

        Does nothing when no substitution is pending.
        """
        if not self.need_recovery:
            return
        self.query_history = self.query_history_backup
        self.need_recovery = False

    def add_buffer(self) -> None:
        """File the buffered queries under the last counterexample.

        When ``self.last_ce`` is ``None`` nothing is stored and the buffer is
        left untouched, so those queries are carried into the next flush.
        """
        if self.last_ce is not None:
            # NOTE(review): assignment replaces any list already stored for
            # the same counterexample; confirm whether extending was intended.
            self.ce2queries[word2tuple(self.last_ce.getInput())] = self.query_buffer
            self.query_buffer = []
