from typing import Any, List, Tuple
from pipelines.prompta.oracle.cot_llm_oracle import CoTLLMOracle
from pipelines.prompta.utils import str2bool
from prompta.core.language import base_language


class VerificationLLMOracle(CoTLLMOracle):
    """CoT LLM oracle whose membership answers are double-checked by a "teacher" LLM call.

    After the parent oracle answers a membership query, the teacher is asked
    whether the recorded reason matches the language definition and the input
    sequence:

    * reason matches the input and the definition -> keep the answer;
    * reason matches the input but not the definition -> flip the answer;
    * reason does not match the input -> query the parent again with the
      teacher's analysis appended to the existence prompt.

    If the teacher never returns a well-formed verdict, the parent's answer is kept.
    """

    def _get_membership_query_result(self, input_str: Tuple[str], use_cache: bool=True, seed: int=0, *args: Any, **kwargs: Any) -> Any:
        """Answer a membership query and verify it with the teacher.

        Args:
            input_str: The query sequence; also the key into ``self.llm_resp_cache``.
            use_cache: Forwarded to the parent oracle.
            seed: Forwarded to the parent oracle.
            *args: Forwarded to the parent oracle.
            **kwargs: Forwarded to the parent oracle.

        Returns:
            The (possibly corrected) answer. It is also written back to
            ``self.llm_resp_cache[input_str]['answer']``.
        """
        # Hint text appended by _construct_existence_message; empty unless the
        # teacher asks for a re-query below.
        self.addtional_prompt = ''
        try:
            res = super()._get_membership_query_result(input_str, use_cache, seed, *args, **kwargs)
            teacher = self._get_teacher_resp(input_str)
            if teacher is not None:
                if teacher['match_input']:
                    # The reason is about the right input: keep the answer if it
                    # also matches the definition, otherwise invert it.
                    res = res if teacher['match_definition'] else not res
                else:
                    # The reason is about the wrong input: ask again, passing the
                    # teacher's analysis as an extra hint.
                    # NOTE(review): if the parent honors use_cache, this may return
                    # the cached answer without re-prompting -- confirm.
                    self.addtional_prompt = teacher['analysis']
                    res = super()._get_membership_query_result(input_str, use_cache, seed, *args, **kwargs)
            # teacher is None: no usable verdict after all retries, so the
            # unverified answer is kept rather than crashing.
            self.llm_resp_cache[input_str]['answer'] = res
            return res
        finally:
            # Don't leak this query's hint into unrelated prompts built later.
            self.addtional_prompt = ''

    def _construct_existence_message(self, query: Tuple[str]):
        """Build the parent's existence prompt and append any pending teacher hint."""
        msg = super()._construct_existence_message(query)
        # getattr: no __init__ in this class sets the attribute, so it may not
        # exist yet if no membership query has run on this instance.
        hint = getattr(self, 'addtional_prompt', '')
        if hint:
            msg[-1]['content'] += hint
        return msg

    def _get_teacher_resp(self, input_str, retry: int=10):
        """Return the teacher's verdict on the cached answer/reason for ``input_str``.

        A cached verdict is reused only if it says the reason matched the input.
        Otherwise the teacher is queried up to ``retry`` times with seeds
        0, 10, 20, .... The first well-formed verdict is cached under
        ``'teacher'`` and returned.

        Returns:
            A dict with keys ``'match_definition'``, ``'match_input'`` and
            ``'analysis'``, or ``None`` if no attempt produced a well-formed verdict.
        """
        cached = self.llm_resp_cache[input_str]
        if 'teacher' in cached and cached['teacher']['match_input']:
            result = cached['teacher']
            print("teacher", result)
            return result
        # The prompt does not depend on the attempt; only the seed changes.
        teacher_check = self._construct_teacher_message(input_str)
        for attempt in range(retry):
            raw = self._get_json_resp(teacher_check, seed=attempt * 10)
            result = self.get_teacher_check_result(raw)
            if result is not None:
                print("teacher", result)
                self.llm_resp_cache[input_str]['teacher'] = result
                return result
        return None

    def get_teacher_check_result(self, result):
        """Normalize a raw teacher JSON response.

        Returns:
            ``{'match_definition': bool, 'match_input': bool, 'analysis': ...}``,
            with the two flags mapped through ``str2bool``. Returns ``None``
            (so the caller retries) when ``result`` is not subscriptable, a
            required key is missing, or a flag value is not a ``str2bool`` key.
        """
        try:
            return {
                'match_definition': str2bool[result['match_definition']],
                'match_input': str2bool[result['match_input']],
                'analysis': result['analysis'],
            }
        except (KeyError, TypeError):
            return None

    def _construct_teacher_message(self, query: Tuple[str]):
        """Build the chat messages asking the teacher to check the cached reason for ``query``.

        Reads ``'answer'`` and ``'reason'`` from ``self.llm_resp_cache[query]``;
        these are presumably populated by the parent oracle (confirm).
        """
        cached = self.llm_resp_cache[query]
        verdict = 'in' if cached['answer'] else 'not in'
        return [
            {"role": "system", "content": "You are a helpful assistant designed to output JSON. Answer in a consistent style. "\
             "Your response should follow the format: {'analysis': '[Write your analysis here in detail]', 'match_definition': [true or false], 'match_input': [true or false]}"},
            {"role": "user", "content": f"{self.language.definition}. The sequence \"{query}\" is "
                                        f"{verdict} the language, because "
                                        f"{cached['reason']}. Does the reason match the language definition and the input sequence?"}
        ]
