

from typing import Any, List
from pipelines.prompta.oracle.base_llm_oracle import BaseLLMOracle
from pipelines.prompta.utils import get_value_by_key


class CoTLLMOracle(BaseLLMOracle):
    """LLM membership oracle that asks for a reason before the answer.

    Each membership query is posed as a few-shot chat:
    - a system instruction requesting JSON output with the reason first;
    - the language definition plus a positive worked example;
    - a negative worked example;
    - the string to classify.

    Parsed ``answer``/``reason`` pairs are stored in ``self.llm_resp_cache``,
    keyed by the raw input string.
    """

    def _get_membership_query_result(self, input_str: str, use_cache: bool=True, seed: int=0, *args: Any, **kwds: Any) -> Any:
        """Return the LLM's membership verdict for ``input_str``.

        Args:
            input_str: The string whose membership in ``self.language`` is queried.
            use_cache: If true and ``input_str`` already has a cache entry, the
                cached answer is returned without contacting the LLM.
            seed: Forwarded to ``self._get_json_resp``. It is not part of the
                cache key, so a cache hit ignores it.
            *args: Accepted for interface compatibility; unused here.
            **kwds: Accepted for interface compatibility; unused here.

        Returns:
            The value ``get_value_by_key`` extracts from the response's
            ``"answer"`` field with ``is_boolean=True`` (presumably a bool;
            confirm against that helper). A freshly fetched result always
            overwrites any existing cache entry, even when ``use_cache`` is false.
        """
        # Membership is tested before the flag, matching the original short-circuit order.
        already_cached = input_str in self.llm_resp_cache
        if already_cached and use_cache:
            return self.llm_resp_cache[input_str]['answer']

        messages = self._construct_existence_message(input_str)
        response = self._get_json_resp(messages, seed=seed)

        # Extract "answer" before "reason" (same order as before) and remember both.
        entry = {
            'answer': get_value_by_key(response, "answer", is_boolean=True),
            'reason': get_value_by_key(response, "reason"),
        }
        self.llm_resp_cache[input_str] = entry
        return entry['answer']

    def _construct_existence_message(self, query: str):
        """Build the chat message list for a single membership question.

        Args:
            query: The candidate string; it is embedded in double quotes in the
                final user turn.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts, in this order:
            system, user (definition + positive query), assistant (positive
            answer), user (negative query), assistant (negative answer), and
            user (the question about ``query``).
        """
        def turn(role, content):
            # One chat message; key order ("role", "content") kept as before.
            return {"role": role, "content": content}

        lang = self.language
        definition = lang.definition
        shots = lang.examples
        # NOTE(review): the final prompt says "does this string belongs"; left
        # verbatim because rewording it would change what the model is sent.
        return [
            turn("system", "You are a helpful assistant designed to output JSON. Answer in a consistent style and output the reason first."),
            turn("user", f"{definition}. {shots['pos']['query']}"),
            turn("assistant", shots['pos']['answer']),
            turn("user", f"{shots['neg']['query']}"),
            turn("assistant", shots['neg']['answer']),
            turn("user", f"Given a string \"{query}\", does this string belongs to the language?"),
        ]

