from typing import Any, List
from jpype import JImplements, JOverride
from jpype.types import JArray, JString
from pipelines.prompta.oracle import BaseOracle
from pipelines.prompta.utils import word2tuple
from prompta.utils.java_libs import Alphabets, DFAMembershipOracle, Query, Collections, DefaultQuery, Word


@JImplements(DFAMembershipOracle)
class MembershipOracleWrapper:
    """Expose a Python ``BaseOracle`` to LearnLib as a ``DFAMembershipOracle``.

    Queries arriving from Java are converted with ``word2tuple`` and answered
    by calling the wrapper itself (``__call__``).  That call forwards to the
    wrapped oracle and cross-checks the answer against
    ``oracle.language.in_language``.  The wrapper also counts membership and
    equivalence queries, and mirrors the size of the oracle's
    ``llm_resp_cache`` when the oracle has one.  Attributes not defined on the
    wrapper are delegated to the wrapped oracle.
    """

    def __init__(
            self,
            oracle: BaseOracle
            ) -> None:
        # Not read inside this class; presumably consulted by callers.
        self.consistent_flag = True
        self.oracle = oracle
        # Number of membership queries answered through __call__.
        self.me_counter = 0
        # Number of equivalence checks made through check_conjecture.
        self.eq_counter = 0
        # len(oracle.llm_resp_cache) as of the last membership query, if any.
        self.total_llm_query_count = 0

    def reset(self, *args, **kwargs):
        """Zero the query counters and reset the wrapped oracle.

        All arguments are forwarded unchanged to ``oracle.reset``.
        """
        self.me_counter = 0
        self.eq_counter = 0
        self.total_llm_query_count = 0
        self.oracle.reset(*args, **kwargs)

    def __call__(self, input_str: str, *args: Any, **kwds: Any) -> Any:
        """Answer one membership query and verify it against the ground truth.

        Args:
            input_str: The word to classify.  When invoked from
                ``processQueries`` this is the value produced by
                ``word2tuple``.
            *args, **kwds: Forwarded to the wrapped oracle.

        Returns:
            The wrapped oracle's answer.

        Raises:
            RuntimeError: If that answer differs from
                ``oracle.language.in_language(input_str)``.
        """
        self.me_counter += 1
        res = self.oracle(input_str, *args, **kwds)
        if res != self.oracle.language.in_language(input_str):
            raise RuntimeError("Wrong membership query")
        if hasattr(self.oracle, 'llm_resp_cache'):
            self.total_llm_query_count = len(self.oracle.llm_resp_cache)
        return res

    @JOverride
    def answerQuery(self, prefix, suffix=None) -> Any:
        """Serve both Java overloads ``answerQuery(input)`` and
        ``answerQuery(prefix, suffix)``.

        Python keeps only the last ``def`` bound to a name, so the two Java
        overloads must be handled by this single method.  With one argument,
        the word is used as the suffix of an empty (epsilon) prefix.

        Returns:
            The output recorded on the ``DefaultQuery`` by ``processQuery``.
        """
        if suffix is None:
            prefix, suffix = Word.epsilon(), prefix
        query = DefaultQuery(prefix, suffix)
        self.processQuery(query)
        return query.getOutput()

    @JOverride
    def processQuery(self, query):
        """Answer a single query by routing it through ``processQueries``."""
        self.processQueries(Collections.singleton(query))

    @JOverride
    def processQueries(self, queries):
        """Answer each query in the Java collection ``queries`` in place."""
        for q in queries:
            word = word2tuple(q.getInput())
            q.answer(self(word))

    @JOverride
    def asOracle(self):
        """Return this wrapper itself."""
        return self

    @JOverride
    def processBatch(self, batch):
        """Answer a batch of queries; same as ``processQueries``."""
        self.processQueries(batch)

    @property
    def alphabet(self):
        """LearnLib alphabet built from ``oracle.alphabet``.

        The oracle's alphabet is converted to a Java ``String[]``, so its
        elements must be convertible to Java strings.
        """
        py_list = self.oracle.alphabet
        java_array = JArray(JString)(py_list)
        alphabet = Alphabets.fromArray(java_array)
        return alphabet

    def check_conjecture(self, aut, _type=str):
        """Count an equivalence query and delegate it to the wrapped oracle.

        Returns:
            Whatever ``oracle.check_conjecture(aut, _type)`` returns
            (presumably a counterexample, or a falsy value when none is
            found; confirm against ``BaseOracle``).
        """
        self.eq_counter += 1
        counterexample = self.oracle.check_conjecture(aut, _type)
        return counterexample

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped oracle.

        Python calls this only after normal lookup fails.  ``oracle`` itself
        is never delegated: if it is missing (e.g. an instance created without
        ``__init__`` by copy/pickle), looking up ``self.oracle`` here would
        re-enter ``__getattr__`` without end.  Dunder names are also not
        delegated, so protocol hooks such as ``__deepcopy__`` or
        ``__setstate__`` do not resolve to the oracle's bound methods.
        """
        if name == 'oracle' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.oracle, name)
    

class TemporaryMembershipOracleWrapper(MembershipOracleWrapper):
    """Membership wrapper that passes the oracle's answers through unchecked.

    Unlike the base class, the answer is not compared with
    ``oracle.language``, so an incorrect answer reaches the learner instead of
    raising.  Query counting and cache-size mirroring are kept.
    """

    def __call__(self, input_str: str, *args: Any, **kwds: Any) -> Any:
        """Count the query, forward it to the oracle, and return its answer."""
        self.me_counter += 1
        oracle = self.oracle
        answer = oracle(input_str, *args, **kwds)
        if hasattr(oracle, 'llm_resp_cache'):
            self.total_llm_query_count = len(oracle.llm_resp_cache)
        return answer


class ErrorAllowWrapper(MembershipOracleWrapper):
    """Membership wrapper that tolerates wrong oracle answers.

    The consistency check against ``oracle.language`` done by the base class
    is skipped; the oracle's answer is returned as-is.
    """

    def __call__(self, input_str: str, *args: Any, **kwds: Any) -> Any:
        """Forward one membership query and return the oracle's raw answer."""
        self.me_counter += 1
        reply = self.oracle(input_str, *args, **kwds)
        tracks_cache = hasattr(self.oracle, 'llm_resp_cache')
        if tracks_cache:
            self.total_llm_query_count = len(self.oracle.llm_resp_cache)
        return reply


class ErrorStatisticsWrapper(MembershipOracleWrapper):
    """Unchecked membership wrapper with a hard budget of 50 queries.

    Answers are not compared with ``oracle.language``.  ``me_counter`` counts
    calls since construction or the last ``reset()``.  The budget is checked
    only after the oracle has been called, so the query that exceeds it is
    still sent to the oracle before ``RuntimeError`` is raised.
    """

    def __call__(self, input_str: str, *args: Any, **kwds: Any) -> Any:
        """Forward a membership query, aborting once more than 50 were made.

        Raises:
            RuntimeError: On the 51st and every later call.
        """
        self.me_counter += 1
        verdict = self.oracle(input_str, *args, **kwds)
        if hasattr(self.oracle, 'llm_resp_cache'):
            self.total_llm_query_count = len(self.oracle.llm_resp_cache)
        if self.me_counter <= 50:
            return verdict
        raise RuntimeError("Timeout!")


class EpsilonEstimationWrapper(ErrorAllowWrapper):
    """Record the oracle's answers while feeding the learner true answers.

    Each query is still forwarded to the wrapped oracle (through
    ``ErrorAllowWrapper.__call__``), which updates the counters and, presumably,
    the oracle's ``llm_resp_cache``.  The value returned to the learner,
    however, is always ``oracle.language.in_language(input_str)``.
    ``epsilon`` then reports the fraction of cached responses that disagree
    with the language.
    """

    # Abort once oracle.llm_resp_cache holds this many entries.
    llm_query_limit: int = 50

    def __call__(self, input_str: str, *args: Any, **kwds: Any) -> Any:
        """Query the oracle for bookkeeping, then return the ground truth.

        Raises:
            RuntimeError: If the oracle's cache already holds
                ``llm_query_limit`` entries; the oracle is not called then.
        """
        if len(self.oracle.llm_resp_cache) >= self.llm_query_limit:
            raise RuntimeError("Timeout!")
        # Called only for its side effects; the oracle's own answer is
        # deliberately discarded in favour of the true language membership.
        super().__call__(input_str, *args, **kwds)
        return self.oracle.language.in_language(input_str)

    @property
    def epsilon(self) -> float:
        """Fraction of cached oracle responses that are wrong.

        Each cache entry is assumed to map a queried word to a dict with an
        ``'answer'`` key, which is compared with
        ``oracle.language.in_language(word)``.

        Returns:
            ``errors / len(cache)``, or ``float('nan')`` when nothing has been
            recorded yet (the ratio is undefined with zero samples).
        """
        cache = self.oracle.llm_resp_cache
        if not cache:
            return float('nan')
        language = self.oracle.language
        error_counter = sum(
            1
            for word, entry in cache.items()
            if entry['answer'] != language.in_language(word)
        )
        return error_counter / len(cache)

