from llm_agent import *
from utils import *
from fc_engine import *
import time

# Base directories referenced when building prompt-template and output paths.
prompt_temp_root = './benchmarks/'
prompt_output_root = './benchmarks/'
output_root = './outputs/'

# The three truth values, each in title, upper and lower case spelling.
Three_value_set = [
    spelling
    for word in ('True', 'False', 'Unknown')
    for spelling in (word, word.upper(), word.lower())
]


class ReLLM:
    def __init__(self, api_config, args, context, esl_path, CS):
        """Create both LLM agents, then load the ESL specification and the prompt templates.

        :param api_config: API configuration handed to the target and perception agents
        :param args: run arguments; ``domain``, ``sub_domain``, ``target_memory`` and
            ``api_timeout`` are read here
        :param context: accepted but not used inside the constructor
        :param esl_path: path of the ESL JSON file read by ``load_spec``
        :param CS: case-study identifier, forwarded to ``load_prompt_temp``
        """
        # Plain configuration handed in by the caller.
        self.args = args
        self.CS = CS
        self.api_config = api_config
        self.esl_path = esl_path
        self.domain = args.domain
        self.sub_domain = args.sub_domain
        self.target_memory = args.target_memory

        # The two LLM agents (target first, then perception).
        self.target_model = TargetAgent(api_config, args.target_memory, args.api_timeout)
        self.perception_model = PerceptionAgent(api_config, args.api_timeout)

        # Placeholders that load_spec() fills in below.
        self.esl_json = None
        self.esl_config = None
        self.esl_single_model = None
        self.config = None
        self.task_id = 0

        self.load_spec()  # the only domain_spec file
        self.prompt_temp_config = self.load_prompt_temp(CS)

        # Variables and predicates are taken from the first entry of the ESL file.
        first_spec = self.esl_json[0]
        self.variables = first_spec["variables"]
        self.predicates = first_spec["predicates"]

    def set_output_log(self, output):
        """Store ``output`` as ``self.log_output``, the log file path that ``run_test`` opens for writing.

        :param output: path of the log file
        """
        self.log_output = output

    def set_task_id(self, id):
        """Store ``id`` as ``self.task_id``; ``run_test`` includes it in its log messages.

        :param id: identifier of the task about to be run
        """
        self.task_id = id

    # Todo: remove spare variables (not important)
    def load_spec(self):
        """Parse the ESL JSON file and normalise it via ``normalize_json``.

        Fills ``self.esl_json`` with the parsed file content and stores the two values
        returned by ``normalize_json`` in ``self.esl_single_model`` and ``self.config``.
        The merged-ESL output path passed to ``normalize_json`` is derived from the run
        arguments and the file name of ``self.esl_path``.
        """
        with open(self.esl_path, 'r') as spec_file:
            self.esl_json = json.load(spec_file)

        print(f"\nLength of esl: {len(self.esl_json)}\n")

        run_args = self.args
        case_dir = f"{run_args.output}/Case_study_{run_args.CS}_{run_args.domain}/{run_args.sub_domain}"
        esl_file_name = self.esl_path.split('/')[-1]
        output_file = f"{case_dir}/merged_esl/merged_{esl_file_name}"

        normalized = normalize_json(self.esl_json, output_file)
        self.esl_single_model, self.config = normalized

    def load_prompt_temp(self, CS):
        abstr_prompt = None
        questionNewRound_prompt = None
        newInstanceGen_prompt = None
        question_init_prompt = None

        if CS == 1:  # domain=mrt
            prompt_question_init = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_question_init.txt'
            prompt_abstr_temp = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_abstract_temp.txt'
            prompt_question_newRound = f'benchmarks/prompt_temp/{self.domain}/{self.domain}_question_newRound.txt'
            with open(prompt_abstr_temp, 'r') as f:
                abstr_prompt = f.read()
                self.abstr_prompt = abstr_prompt
            f.close()

            with open(prompt_question_newRound, 'r') as f:
                questionNewRound_prompt = f.read()
                self.questionNewRound_prompt = questionNewRound_prompt
            f.close()

            self.question_init_prompt = question_init_prompt
            self.abstr_prompt = abstr_prompt
            self.questionNewRound_prompt = questionNewRound_prompt
            self.newInstanceGen_prompt = newInstanceGen_prompt

        elif CS == 2:  # domain=comp, interLevel=2
            prompt_abstr_temp = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_abstract_temp.txt'
            prompt_newInstanceGen = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_newInstanceGen.txt'
            prompt_question_newRound = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_question_newRound.txt'
            with open(prompt_abstr_temp, 'r') as f:
                abstr_prompt = f.read()
                self.abstr_prompt = abstr_prompt
            f.close()

            with open(prompt_question_newRound, 'r') as f:
                questionNewRound_prompt = f.read()
                self.questionNewRound_prompt = questionNewRound_prompt
            f.close()

            with open(prompt_newInstanceGen, 'r') as f:
                newInstanceGen_prompt = f.read()
                self.newInstanceGen_prompt = newInstanceGen_prompt
            f.close()

            self.question_init_prompt = question_init_prompt
            self.abstr_prompt = abstr_prompt
            self.questionNewRound_prompt = questionNewRound_prompt
            self.newInstanceGen_prompt = newInstanceGen_prompt

        elif CS == 3:  # domain=ineq, interLevel=1
            prompt_abstr_temp = f'./benchmarks/prompt_temp/{self.domain}/{self.domain}_abstract_temp_{self.sub_domain}.txt'
            prompt_question_newRound = f'benchmarks/prompt_temp/{self.domain}/{self.domain}_question_newRound_{self.sub_domain}.txt'
            with open(prompt_abstr_temp, 'r') as f:
                abstr_prompt = f.read()
                self.abstr_prompt = abstr_prompt
            f.close()

            with open(prompt_question_newRound, 'r') as f:
                questionNewRound_prompt = f.read()
                self.questionNewRound_prompt = questionNewRound_prompt
            f.close()

            self.question_init_prompt = question_init_prompt
            self.abstr_prompt = abstr_prompt
            self.questionNewRound_prompt = questionNewRound_prompt
            self.newInstanceGen_prompt = newInstanceGen_prompt

        else:
            exit(0)

        promt_temp_config = {'abstr_prompt': self.abstr_prompt,
                             'questionNewRound_prompt': self.questionNewRound_prompt,
                             'newInstanceGen_prompt': self.newInstanceGen_prompt,
                             'question_init_prompt': self.question_init_prompt}
        return promt_temp_config

    def generate_full_prompt_abstr(self, context, if_outputFile=False):
        """Fill the abstraction prompt template with variables, predicates and context.

        The placeholders are substituted in this order: ``[[Variables]]`` (variables joined
        by ``', '``), ``[[Predicates]]`` (predicates joined by a newline plus tab) and
        finally ``[[Context]]``.

        :param context: text inserted for ``[[Context]]``
        :param if_outputFile: accepted but not used by this method
        :return: the filled-in prompt string
        """
        substitutions = (
            ('[[Variables]]', ', '.join(self.variables)),
            ('[[Predicates]]', '\n\t'.join(self.predicates)),
            ('[[Context]]', context),
        )

        prompt = self.abstr_prompt
        for placeholder, replacement in substitutions:
            prompt = prompt.replace(placeholder, replacement)
        return prompt

    def generate_full_prompt_questionNewRound(self, context, new_prop):
        """Fill the new-round question template for one proposition.

        The placeholders are substituted in this order: ``[[Context]]``, ``[[Predicates]]``
        (predicates joined by a newline plus tab) and finally ``[[Proposition]]``.

        :param context: text inserted for ``[[Context]]``
        :param new_prop: text inserted for ``[[Proposition]]``
        :return: the filled-in prompt string
        """
        prompt = self.questionNewRound_prompt.replace('[[Context]]', context)
        predicate_block = '\n\t'.join(self.predicates)
        prompt = prompt.replace('[[Predicates]]', predicate_block)
        return prompt.replace('[[Proposition]]', new_prop)

    def generate_abstract_result(self, context, file,
                                 if_print=False):  # obtain abstraction result from Perception : object and interpretation
        full_prompt = self.generate_full_prompt_abstr(context=context, if_outputFile=True)
        response, status = self.perception_model.generate_response(full_prompt, temp=0.5)
        try:
            assert status == True
        except:
            error_info = f"\n!!! APIError: {response}"
            return None, None, error_info

        if self.CS == 1 or self.CS == 2:
            matches = re.findall(r'\{(.*?)\}', response, re.DOTALL)
        else:
            matches = re.findall(r'@(.*?)@', response, re.DOTALL)

        try:
            assert len(matches) == 2
            """
                    object: ['group of friends', 'sparklers', 'train platform', 'security officer']
                    interpretation: ['Smoking(group of friends) = Unknown', 
                            'WithLightedItem(group of friends) = True', 
                            'SmokingProhibited(train platform) = True', 
                            'InRailwayPremises(group of friends, train platform) = True', 
                            'IsPremise(train platform) = True']

            """
            objects_all = [i.strip() for i in matches[0].split(',')]
            inter_match = re.sub(r"true", "True", matches[1], flags=re.IGNORECASE)
            inter_match = re.sub(r"false", "False", inter_match, flags=re.IGNORECASE)
            inter_match = re.sub(r"unknown", "Unknown", inter_match, flags=re.IGNORECASE)
            interPre_all = re.findall(r"\w+\(.*?\)\s*=\s*(?:True|False|Unknown)", inter_match, flags=re.IGNORECASE)

            if if_print:
                file.write(f"\n\tobjects_all: {objects_all}")
                file.write(f"\n\tinterPre_all: {interPre_all}")
            return objects_all, interPre_all, ''
        except:
            error_info = f"\n!!! The response is: {response}\n\nType-2 Error in Stage-2: Abstraction result format error by the perception LLM."
            return None, None, error_info

    def run_test(self, domainVec, context, CS, if_print=False):
        """
        :param domainVec: [domain, sub_domain]
        :param context:
        :param CS: three case studies ID: 1,2,3
        :param if_print:
        :return: 0: RV success; 1: RV fail
        """
        # init context for perception agent
        check_fail = f"\n\nSome Type Error is detected.\n\nThe RV fails."
        perception_context = context
        domain = domainVec[0]
        sub_domain = domainVec[1]

        if CS == 1:
            os.makedirs(os.path.dirname(self.log_output), exist_ok=True)
            f = open(self.log_output, 'w', encoding='utf-8')
            model_info = f"{{Target LLM: {self.target_model.model_name}, Perception LLM: {self.perception_model.model_name}}}\n"
            task_info = f"{{ESL file: {self.esl_path}, Task ID: {self.task_id}}}\n\n"
            start_info = f"************* We now start to analysis a new case with domain {domain} and Task id {self.task_id} *************\nContext: {context}\n\n######## Stage 1: Target LLM's init query starts. ########\n"
            print('\n\n' + model_info)
            print(task_info)
            print(start_info)
            f.write(model_info)
            f.write(task_info)
            f.write(start_info)

            target_init_query_temp_path = f'{prompt_temp_root}prompt_temp/{self.domain}/{self.domain}_question_init.txt'
            start_tar_1 = time.time()
            with open(target_init_query_temp_path, 'r') as init_query_file:
                target_init_query_temp = init_query_file.read()
            init_query_file.close()

            target_init_full_query = target_init_query_temp.replace('[[Context]]', context)
            target_response, status = self.target_model.generate_response(target_init_full_query)
            try:
                assert status == True
            except:
                error_info = f"\n!!! APIError: {target_response}"
                print(error_info)
                f.write(error_info)
                Error_type = f"\nFailure Type = 8, Failure Stage = 1"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                return None, 1

            if target_response.lower() == "true":
                target_response = True
            elif target_response.lower() == "false":
                target_response = False

            else:
                error_info = f"\n!!! Type-1 Error in Stage-1: Initial round response format error by the target LLM. Expected output: '{'True, False'}'. Returned output: {target_response}."
                print(error_info)
                f.write(error_info)

                Error_type = f"\nFailure Type = 1, Failure Stage = 1"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()

                return None, 1

            end_tar_1 = time.time()
            response_info = f"\nThe response from the target LLM: {target_response}. (True: Has no misconduct; False: Has some misconduct)\n"
            time_info = f"\n######## Stage 1: Target LLM's init query takes {end_tar_1 - start_tar_1} seconds. ########\n\n######## Stage 2: Perception LLM's abstraction starts. ########\n"
            print(response_info)
            print(time_info)
            f.write(response_info)
            f.write(time_info)

            # Perception Agent start
            per_start = time.time()
            objects_all, interPre_all, error_info = self.generate_abstract_result(context, f, if_print=if_print)
            per_end = time.time()

            print("Objects_all:", objects_all)
            print("interPre_all:", interPre_all)
            perception_abstract_time = per_end - per_start
            time_info = f"\n\n######## Stage 2: Perception LLM's abstraction takes {perception_abstract_time} seconds. ########\n\n######## Stage 3: FC graph initiation starts. ########\n"
            if objects_all == None:
                print(error_info)
                f.write(error_info)
                if 'APIError' in error_info:
                    Error_type = f"\nFailure Type = 8, Failure Stage = 2"
                else:
                    Error_type = f"\nFailure Type = 2, Failure Stage = 2"

                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()

                return None, 1

            print(time_info)
            f.write(time_info)

            # init graph
            fc_graph = FCGraph(self.domain, objects_all, interPre_all, self.esl_single_model, self.args,
                               self.prompt_temp_config, if_print)

            fc_graph.init_perpcetion_model(self.perception_model)
            fc_graph.init_target_model(self.target_model)
            if if_print:
                fc_graph.set_outFile(f)

            init_graph_start = time.time()
            res = fc_graph.init()
            init_graph_end = time.time()
            if res == 1:
                empty_error = False
                error_info = fc_graph.error_info
                Error_type = None
                if error_info == '':
                    empty_error = True
                    error_info = "\nSome error happens during graph_init (return 1)"
                    Error_type = f"\nFailure Type = 9, Failure Stage = 3"

                print(error_info)
                f.write(error_info)
                if empty_error:
                    print(Error_type)
                    f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            time_info = f"\n\n######## Stage 3: FC graph initiation takes {init_graph_end - init_graph_start} seconds. ########\n\n######## Stage 4: Forward chaining starts. ########\n"
            print(time_info)
            f.write(time_info)

            fc_start = time.time()
            try:
                if_consistent = fc_graph.forward_chain(opt=0)  # early stop once an inconsistency is detected.
                # if_consistent = fc_graph.forward_chain(opt=1) # check all edges and all paths
            except:
                error_info = f"\n!!! Type-3 Error in Stage-4: Code logic error in forward chaining procedure."
                print(error_info)
                f.write(error_info)

                Error_type = f"\nFailure Type = 9, Failure Stage = 4"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            fc_end = time.time()
            time_info = f"\n\n######## Stage 4: Forward chaining takes {fc_end - fc_start} seconds. ########\n\n######## Stage 5: Query the target LLM about new inferred knowledge starts. ########\n"
            print(time_info)
            f.write(time_info)

            # Start to check the consistency to the inferred ground truth
            new_inferred_check_start = time.time()
            if if_consistent:
                for new_knowledge in fc_graph.new_inferred_node:
                    label = fc_graph.graph.nodes[new_knowledge]["label"]
                    bool_value = (label[0] != '~')
                    if not bool_value:
                        label_clean = label[1:].split('(')[0].split('_')[0] + '(' + label[1:].split('(')[1]
                    else:
                        label_clean = label.split('(')[0].split('_')[0] + '(' + label.split('(')[1]

                    queryGen_full_prompt = self.generate_full_prompt_questionNewRound(context, label_clean)
                    target_newRound_response, status = self.perception_model.generate_response(queryGen_full_prompt,
                                                                                               temp=0.5)

                    try:
                        assert status == True
                    except:
                        error_info = f"\n!!! APIError: {target_newRound_response}"
                        print(error_info)
                        f.write(error_info)
                        Error_type = f"\nFailure Type = 8, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    try:
                        assert target_newRound_response.lower() in Three_value_set

                    except:
                        error_info = f"\n!!! Type-1 Error in Stage-5: Only True, False, Unknown values are accepted from target LLM, but obtained {target_newRound_response}."
                        print(error_info)
                        f.write(error_info)

                        Error_type = f"\nFailure Type = 4, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    if target_newRound_response == str(bool_value):
                        continue
                    else:
                        print("Target LLM returns a wrong answer for the new query.")
                        f.write("\n\tTarget LLM returns a wrong answer for the new query.")
                        if_consistent = False
                        break

            new_inferred_check_end = time.time()
            time_info = f"\n\n######## Stage 5: Query the target LLM about new inferred knowledge takes {new_inferred_check_end - new_inferred_check_start} seconds. ########\n"
            print(time_info)
            f.write(time_info)

            res_1 = 'Correct' if if_consistent == target_response else 'Wrong'
            res_2 = 'same as' if (if_consistent == target_response) else 'different from'

            check_res = f"\nEvaluation Results: {res_1} {{ (LLM: {target_response}, ESL (inter_level={self.args.inter_level}): {if_consistent} }}.\n\tThe target LLM's analysis result is {res_2} the result analyzed by the ESL file."
            if res_1 == 'Wrong' and if_consistent:
                check_res += "\n\tHowever, no inconsistency is detected by the ESL file, suggesting that additional rules may be required or perception agent fail to abstract good enough interpretations."

            total_time = f"\n\nThe total time of ReLLM for task {self.task_id}: {time.time() - start_tar_1} seconds.\n"
            Error_type = f"\nFailure Type = 0, Failure Stage = 0\n\nLLM: {target_response}, FC: {if_consistent}\n\nThe RV successes!"
            print(check_res)
            print(total_time)
            print(Error_type)
            f.write(check_res)
            f.write(total_time)
            f.write(Error_type)
            f.close()
            fc_graph.delete()
            return if_consistent, 0

        elif CS == 2:  # comparison
            os.makedirs(os.path.dirname(self.log_output), exist_ok=True)
            f = open(self.log_output, 'w', encoding='utf-8')

            model_info = f"{{Target LLM: {self.target_model.model_name}, Perception LLM: {self.perception_model.model_name}}}\n"
            task_info = f"{{ESL file: {self.esl_path}, Task ID: {self.task_id}}}\n\n"

            start_info = f"************* We now start to analysis a new case with domain {domain} and Task id {self.task_id} *************\nContext: {context}\n\n######## Stage 1: Target LLM's init query starts. ########\n"
            print('\n\n' + model_info)
            print(task_info)
            print(start_info)
            f.write(model_info)
            f.write(task_info)
            f.write(start_info)

            start_tar_1 = time.time()
            context_require = " Please output the result directly with nothing else."
            target_response, status = self.target_model.generate_response(context + context_require)

            try:
                assert status == True
            except:
                error_info = f"\n!!! APIError: {target_response}"
                print(error_info)
                f.write(error_info)
                Error_type = f"\nFailure Type = 8, Failure Stage = 1"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                return None, 1

            full_context = f"{context} {target_response}."

            end_tar_1 = time.time()
            response_info = f"\nThe response from the target LLM: {target_response}. \n"
            time_info = f"\n######## Stage 1: Target LLM's init query takes {end_tar_1 - start_tar_1} seconds. ########\n\n######## Stage 2: Perception LLM's abstraction starts. ########\n"
            print(response_info)
            print(time_info)
            f.write(response_info)
            f.write(time_info)

            # Perception Agent start
            per_start = time.time()
            objects_all, interPre_all, error_info = self.generate_abstract_result(full_context, f, if_print=if_print)

            per_end = time.time()

            print("Objects_all:", objects_all)
            print("interPre_all:", interPre_all)
            perception_abstract_time = per_end - per_start
            time_info = f"\n\n######## Stage 2: Perception LLM's abstraction takes {perception_abstract_time} seconds. ########\n\n######## Stage 3: FC graph initiation starts. ########\n"
            if objects_all == None:
                print(error_info)
                f.write(error_info)
                if 'APIError' in error_info:
                    Error_type = f"\nFailure Type = 8, Failure Stage = 2"
                else:
                    Error_type = f"\nFailure Type = 2, Failure Stage = 2"

                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()

                return None, 1

            print(time_info)
            f.write(time_info)

            # init graph
            fc_graph = FCGraph(self.domain, objects_all, interPre_all, self.esl_single_model, self.args,
                               self.prompt_temp_config, if_print)

            fc_graph.init_perpcetion_model(self.perception_model)
            fc_graph.init_target_model(self.target_model)
            if if_print:
                fc_graph.set_outFile(f)

            init_graph_start = time.time()
            res = fc_graph.init()
            init_graph_end = time.time()
            if res == 1:
                empty_error = False
                error_info = fc_graph.error_info
                Error_type = None
                if error_info == '':
                    empty_error = True
                    error_info = "\nSome error happens during graph_init (return 1)"
                    Error_type = f"\nFailure Type = 9, Failure Stage = 3"

                print(error_info)
                f.write(error_info)
                if empty_error:
                    print(Error_type)
                    f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            time_info = f"\n\n######## Stage 3: FC graph initiation takes {init_graph_end - init_graph_start} seconds. ########\n\n######## Stage 4: Forward chaining starts. ########\n"
            print(time_info)
            f.write(time_info)

            fc_start = time.time()
            try:
                if_consistent = fc_graph.forward_chain(opt=0)  # early stop once an inconsistency is detected.
                # if_consistent = fc_graph.forward_chain(opt=1)
            except:
                error_info = f"\n!!! Type-3 Error in Stage-4: Code logic error in forward chaining procedure."
                print(error_info)
                f.write(error_info)

                Error_type = f"\nFailure Type = 9, Failure Stage = 4"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            fc_end = time.time()
            time_info = f"\n\n######## Stage 4: Forward chaining takes {fc_end - fc_start} seconds. ########\n\n######## Stage 5: Query the target LLM about new inferred knowledge starts. ########\n"
            print(time_info)
            f.write(time_info)

            # Start to check the consistency to the inferred ground truth
            new_inferred_check_start = time.time()
            if if_consistent:
                for new_knowledge in fc_graph.new_inferred_node:
                    label = fc_graph.graph.nodes[new_knowledge]["label"]
                    bool_value = (label[0] != '~')
                    if not bool_value:
                        label_clean = label[1:].split('(')[0].split('_')[0] + '(' + label[1:].split('(')[1]
                    else:
                        label_clean = label.split('(')[0].split('_')[0] + '(' + label.split('(')[1]

                    queryGen_full_prompt = self.generate_full_prompt_questionNewRound(context, label_clean)
                    target_newRound_response, status = self.perception_model.generate_response(queryGen_full_prompt,
                                                                                               temp=0.5)
                    try:
                        assert status == True
                    except:
                        error_info = f"\n!!! APIError: {target_newRound_response}"
                        print(error_info)
                        f.write(error_info)
                        Error_type = f"\nFailure Type = 8, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    try:
                        assert target_newRound_response.lower() in Three_value_set

                    except:
                        error_info = f"\n!!! Type-1 Error in Stage-5: Only True, False, Unknown values are accepted from target LLM, but obtained {target_newRound_response}."
                        print(error_info)
                        f.write(error_info)

                        Error_type = f"\nFailure Type = 4, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    if target_newRound_response == str(bool_value):
                        continue
                    else:
                        print("Target LLM returns a wrong answer for the new query.")
                        f.write(
                            f"\n\tTarget LLM returns a wrong answer for the new query: {label_clean} = {target_newRound_response}.")
                        if_consistent = False
                        break

            new_inferred_check_end = time.time()
            time_info = f"\n\n######## Stage 5: Query the target LLM about new inferred knowledge takes {new_inferred_check_end - new_inferred_check_start} seconds. ########\n"
            print(time_info)
            f.write(time_info)

            res_1 = 'Correct' if if_consistent else 'Wrong'

            check_res = f"\nEvaluation Results: {res_1} {{ (LLM: {full_context}, ESL (inter_level={self.args.inter_level}): {if_consistent} }}."

            total_time = f"\n\nThe total time of ReLLM for task {self.task_id}: {time.time() - start_tar_1} seconds.\n"
            Error_type = f"\nFailure Type = 0, Failure Stage = 0\n\nLLM: {target_response}, FC: {if_consistent}\n\nThe RV successes!"
            print(check_res)
            print(total_time)
            print(Error_type)
            f.write(check_res)
            f.write(total_time)
            f.write(Error_type)
            f.close()
            fc_graph.delete()
            return if_consistent, 0

        elif CS == 3:
            os.makedirs(os.path.dirname(self.log_output), exist_ok=True)
            f = open(self.log_output, 'w', encoding='utf-8')

            # model_info = f"{{Target LLM: {self.target_model.model_name}, Perception LLM: {self.perception_model.model_name}}}\n"
            # task_info = f"{{ESL file: {self.esl_path}, Task ID: {self.task_id}}}\n\n"
            #
            # start_info = f"************* We now start to analysis a new case with domain {domain} and Task id {self.task_id} *************\nContext: {context}\n\n######## Stage 1: Target LLM's init query starts. ########\n"

            time_info = f"\n######## Stage 1: Target LLM's init query takes 0 seconds (offline). ########\n\n######## Stage 2: Perception LLM's abstraction starts. ########\n"
            print(time_info)
            f.write(time_info)

            # Perception Agent start
            full_context = context
            per_start = time.time()
            objects_all, interPre_all, error_info = self.generate_abstract_result(full_context, f, if_print=if_print)

            per_end = time.time()

            print("Objects_all:", objects_all)
            print("interPre_all:", interPre_all)
            perception_abstract_time = per_end - per_start
            time_info = f"\n\n######## Stage 2: Perception LLM's abstraction takes {perception_abstract_time} seconds. ########\n\n######## Stage 3: FC graph initiation starts. ########\n"
            if objects_all == None:
                print(error_info)
                f.write(error_info)
                if 'APIError' in error_info:
                    Error_type = f"\nFailure Type = 8, Failure Stage = 2"
                else:
                    Error_type = f"\nFailure Type = 2, Failure Stage = 2"

                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()

                return None, 1

            print(time_info)
            f.write(time_info)

            # init graph
            fc_graph = FCGraph(self.domain, objects_all, interPre_all, self.esl_single_model, self.args,
                               self.prompt_temp_config, if_print)

            fc_graph.init_perpcetion_model(self.perception_model)
            fc_graph.init_target_model(self.target_model)
            if if_print:
                fc_graph.set_outFile(f)

            init_graph_start = time.time()
            res = fc_graph.init()
            init_graph_end = time.time()
            if res == 1:
                empty_error = False
                error_info = fc_graph.error_info
                Error_type = None
                if error_info == '':
                    empty_error = True
                    error_info = "\nSome error happens during graph_init (return 1)"
                    Error_type = f"\nFailure Type = 9, Failure Stage = 3"

                print(error_info)
                f.write(error_info)
                if empty_error:
                    print(Error_type)
                    f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            time_info = f"\n\n######## Stage 3: FC graph initiation takes {init_graph_end - init_graph_start} seconds. ########\n\n######## Stage 4: Forward chaining starts. ########\n"
            print(time_info)
            f.write(time_info)

            fc_start = time.time()
            try:
                if_consistent = fc_graph.forward_chain(opt=0)  # early stop once an inconsistency is detected.
                # if_consistent = fc_graph.forward_chain(opt=1)
            except:
                error_info = f"\n!!! Type-3 Error in Stage-4: Code logic error in forward chaining procedure."
                print(error_info)
                f.write(error_info)

                Error_type = f"\nFailure Type = 9, Failure Stage = 4"
                print(Error_type)
                f.write(Error_type)
                print(check_fail)
                f.write(check_fail)
                f.close()
                fc_graph.delete()
                return None, 1

            fc_end = time.time()
            time_info = f"\n\n######## Stage 4: Forward chaining takes {fc_end - fc_start} seconds. ########\n\n######## Stage 5: Query the target LLM about new inferred knowledge starts. ########\n"
            print(time_info)
            f.write(time_info)

            # Check the LLM's answer for each newly inferred fact against the inferred truth value; the first mismatch marks the case as inconsistent.
            new_inferred_check_start = time.time()
            if if_consistent:
                for new_knowledge in fc_graph.new_inferred_node:
                    label = fc_graph.graph.nodes[new_knowledge]["label"]
                    bool_value = (label[0] != '~')
                    if not bool_value:
                        label_clean = label[1:].split('(')[0].split('_')[0] + '(' + label[1:].split('(')[1]
                    else:
                        label_clean = label.split('(')[0].split('_')[0] + '(' + label.split('(')[1]

                    queryGen_full_prompt = self.generate_full_prompt_newRound(context, label_clean)
                    target_newRound_response, status = self.perception_model.generate_response(queryGen_full_prompt,
                                                                                               temp=0.5)
                    try:
                        assert status == True
                    except:
                        error_info = f"\n!!! APIError: {target_newRound_response}"
                        print(error_info)
                        f.write(error_info)
                        Error_type = f"\nFailure Type = 8, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    try:
                        assert target_newRound_response.lower() in Three_value_set

                    except:
                        error_info = f"\n!!! Type-1 Error in Stage-5: Only True, False, Unknown values are accepted from target LLM, but obtained {target_newRound_response}."
                        print(error_info)
                        f.write(error_info)

                        Error_type = f"\nFailure Type = 4, Failure Stage = 5"
                        print(Error_type)
                        f.write(Error_type)
                        print(check_fail)
                        f.write(check_fail)
                        f.close()
                        fc_graph.delete()
                        return None, 1

                    if target_newRound_response == str(bool_value):
                        continue
                    else:
                        print(
                            f"Target LLM returns a wrong answer for the new query: {label_clean} = {target_newRound_response}, it should be {bool_value}")
                        f.write(
                            f"\n\tTarget LLM returns a wrong answer for the new query: {label_clean} = {target_newRound_response}, it should be {bool_value}.")
                        if_consistent = False
                        break

            new_inferred_check_end = time.time()
            time_info = f"\n\n######## Stage 5: Query the target LLM about new inferred knowledge takes {new_inferred_check_end - new_inferred_check_start} seconds. ########\n"
            print(time_info)
            f.write(time_info)

            res_1 = 'Correct' if if_consistent else 'Wrong'

            check_res = f"\nEvaluation Results: {res_1} {{ (LLM: None, ESL (inter_level={self.args.inter_level}): {if_consistent} }}."

            total_time = f"\n\nThe total time of ReLLM for task {self.task_id}: {time.time() - per_start} seconds.\n"
            Error_type = f"\nFailure Type = 0, Failure Stage = 0\n\nLLM: Unknown, FC: {if_consistent}\n\nThe RV successes!"
            print(check_res)
            print(total_time)
            print(Error_type)
            f.write(check_res)
            f.write(total_time)
            f.write(Error_type)
            f.close()
            fc_graph.delete()
            return if_consistent, 0

        else:
            return True, 0
