from csv import excel_tab
import os
import sys
import json
from json_repair import repair_json
from tqdm import tqdm
from copy import deepcopy
from typing import Dict, List


from agent.nesy_agent.nesy_utils.base_func import func_dict
from agent.nesy_agent.prompts.PROMPTS import *
from agent.nesy_agent.nesy_utils.ast_checker import HardLogicPyChecker

from agent.nesy_agent.nesy_utils.tag_enum import Attraction_Tags, Hotel_Tags

class NL2SL_INSTRUCTION:
    """
    Build the LLM prompts for the natural-language -> hard-logic pipeline.

    Each ``stepN_format`` method takes a prompt template imported from PROMPTS.
    It fills the template's ``{attraction_tags}`` and ``{hotel_tags}``
    placeholders and then appends the query-specific text.
    """

    def __init__(self, locale):
        """
        Args:
            locale: key into ``Attraction_Tags`` selecting the attraction tag
                list (the demo below uses "en-US").
        """
        self.locale = locale
        self.attraction_tags = Attraction_Tags[self.locale]
        # NOTE(review): unlike attraction tags, hotel tags are not selected by
        # locale -- confirm this is intended.
        self.hotel_tags = Hotel_Tags

    def _fill_tags(self, template):
        """Return ``template`` with the tag placeholders replaced by the tag lists."""
        return template.replace("{attraction_tags}", str(self.attraction_tags)).replace(
            "{hotel_tags}", str(self.hotel_tags)
        )

    def step1_format(self, nature_language):
        """
        Build the step-1 prompt (natural language -> hard logic).

        Args:
            nature_language: the user's request text.

        Returns:
            The full prompt string: the template, the few-shot examples, and the request.
        """
        final_prompt = (
            self._fill_tags(nl2sl_prompt)
            + step1_nl2sl_example
            + step1_nl2sl_example_1
            + step1_nl2sl_example_2
            + step1_nl2sl_example_3
            + step1_nl2sl_example_4
            + step1_nl2sl_example_5
            + "\nExamples End."
            + "\nnature_language: "
            + nature_language
            + "\nlogical_constraints: "
            # NOTE(review): the request text is repeated after
            # "logical_constraints:"; kept as-is because the example format is
            # not visible here -- confirm this is intended.
            + nature_language
            + "\n"
        )
        return final_prompt

    def step2_format(self, hard_logic, user_query):
        """
        Build the step-2 prompt (hard logic -> Python constraint code).

        Args:
            hard_logic: newline-joined hard-logic statements from step 1.
            user_query: the original user request text.
        """
        return (
            self._fill_tags(sl_trans_prompt)
            + hard_logic
            + "\n The query is: \n"
            + user_query
            + "\nanswer:\n"
        )

    def step3_format(self, query, run_error_list, value_error_list):
        """
        Build the step-3 reflection prompt that asks the LLM to fix failing code.

        Args:
            query: dict providing ``hard_logic_py`` and ``userQuery``.
            run_error_list: execution error messages (strings).
            value_error_list: checker error messages (strings).
        """
        # Join both lists together so that a newline always separates the last
        # run error from the first value error (separate joins glued them together).
        errors = "\n".join(list(run_error_list) + list(value_error_list))
        return (
            self._fill_tags(reflect_prompt)
            + str(query["hard_logic_py"])
            + "The error is: "
            + errors
            + "\nThe query is: \n"
            + query["userQuery"]
            + "\nanswer:\n"
        )



class NL2SLTranslator:
    """
    natural language to symbolic language translator class
    convert natural language query to symbolic logic constraints

    Pipeline (see ``translate_nl2sl``):
    1. natural language -> hard-logic statements (``nl2sl_step1``),
    2. hard logic -> Python constraint snippets (``nl2sl_step2``),
    3. execute and check the snippets against an example plan, ask the LLM to
       repair failures, and finally drop the ones still failing (``nl2sl_step3``).
    """

    def __init__(self, backbone_llm):
        """
        initialize the translator

        Args:
            backbone_llm: large language model callable; it is called with a
                prompt string and its result is handled as text.
        """
        self.backbone_llm = backbone_llm
        # example plans, looked up by str(query["day"]) in _check
        self.example_plans = self._load_example_plans()

    def _load_example_plans(self, example_plans_dir="temp_plan.json"):
        """
        Load the example plans JSON file.

        Args:
            example_plans_dir: path to the JSON file (despite the name, a file,
                not a directory). A relative path is resolved against this
                module's directory.

        Returns:
            The parsed JSON content.
        """
        if not os.path.isabs(example_plans_dir):
            example_plans_dir = os.path.join(os.path.dirname(__file__), example_plans_dir)
        with open(example_plans_dir, "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _extract_first_list(text):
        """
        Return the text of the first list literal in ``text``, or ``"[]"`` if none.

        A JSON-aware parse is tried first so that brackets inside string
        literals (e.g. ``"split(']')"``) do not end the list early. If the text
        from the first ``[`` on is not valid JSON, the method falls back to
        plain bracket counting.
        """
        start = text.find("[")
        if start == -1:
            return "[]"
        try:
            _, end = json.JSONDecoder().raw_decode(text, start)
            return text[start:end]
        except ValueError:
            pass  # not valid JSON; fall back to bracket counting below
        depth = 0
        for offset, ch in enumerate(text[start:]):
            if ch == "[":
                depth += 1
            elif ch == "]":
                depth -= 1
                if depth == 0:
                    return text[start : start + offset + 1]
        return "[]"

    @staticmethod
    def _dedupe(items):
        """Remove duplicates while keeping first-occurrence order (deterministic, unlike set())."""
        return list(dict.fromkeys(items))

    def _get_first_list_in_str(self, json_str):
        """extract the first valid list from the (repair_json-normalized) string"""
        return self._extract_first_list(repair_json(json_str, ensure_ascii=False))

    def nl2sl_step1(self, query):
        """
        first step: convert natural language to hard logic

        Sends ``query['userQuery']`` to the LLM. It takes the outermost ``{...}``
        span of the reply, parses it as JSON, and copies every key of the
        resulting object into ``query``. On any failure it sets
        ``query["hard_logic"] = []``.
        """
        nature_language = query["userQuery"]
        messages = NL2SL_INSTRUCTION(query["locale"]).step1_format(nature_language)
        response = self.backbone_llm(messages)
        # Deliberately best-effort: LLM output can be malformed in many ways.
        try:
            l_ptr = response.find("{")
            r_ptr = response.rfind("}")
            if l_ptr != -1 and r_ptr != -1:
                response = response[l_ptr : r_ptr + 1]
            parsed = json.loads(response)
            if not isinstance(parsed, dict):
                raise TypeError("step1 output is not a JSON object")
            query.update(parsed)
        except Exception:
            query["hard_logic"] = []
        return query

    def nl2sl_step2(self, query):
        """
        second step: convert hard logic to Python code

        Sets ``query["hard_logic_py"]`` to a de-duplicated list of code
        strings. The expected LLM output is a JSON list such as
        ``["result=(day_count(plan)==3)", ...]``. When that list cannot be
        parsed, the raw text is stored in ``query["error_hard_logic_py"]``.
        """
        try:
            raw_hard_logic = query["hard_logic"]
            if isinstance(raw_hard_logic, str):
                # a bare string would otherwise be split into characters below
                raw_hard_logic = [raw_hard_logic]
            query["hard_logic"] = [str(hl) for hl in raw_hard_logic]
            hard_logic = "\n".join(query["hard_logic"])
        except (KeyError, TypeError):
            query["hard_logic"] = []
            query["hard_logic_py"] = []
            return query

        messages = NL2SL_INSTRUCTION(query["locale"]).step2_format(hard_logic, query["userQuery"])
        list_text = self._get_first_list_in_str(self.backbone_llm(messages))
        try:
            parsed = json.loads(list_text)
        except ValueError:
            query["error_hard_logic_py"] = list_text
            parsed = []
        query["hard_logic_py"] = self._dedupe([str(item) for item in parsed])
        return query

    def _check(self, query):
        """
        check if the code block can be executed correctly

        Each snippet in ``query["hard_logic_py"]`` is executed against an
        example plan. The plan is chosen by ``str(query["day"])``, falling back
        to the "3" plan.

        Returns:
            (run_error_list, run_error_idx): error messages and the indices of
            the snippets that raised.
        """
        run_error_list = []
        run_error_idx = []
        try:
            example_plan = self.example_plans[str(query["day"])]
        except (KeyError, TypeError):
            example_plan = self.example_plans["3"]

        for idx, constraint in enumerate(query["hard_logic_py"]):
            vars_dict = deepcopy(func_dict)
            vars_dict["plan"] = example_plan["itinerary"]
            vars_dict["query_i"] = example_plan["poi_dict"]
            # SECURITY NOTE(review): ``constraint`` is LLM-generated code.
            # Restricting __builtins__ to {"set"} limits it but is NOT a sandbox.
            # Because globals and locals are separate dicts, names used inside
            # generator expressions/lambdas resolve only against the globals
            # dict, so func_dict helpers are not visible there.
            try:
                exec(constraint, {"__builtins__": {"set": set}}, vars_dict)
            except Exception as e:
                if str(e) not in ("Failed to create Point instance from string: unknown format.",):
                    run_error_list.append(str(e))
                    run_error_idx.append(idx)
        return run_error_list, run_error_idx

    def _reflect_info(self, query, checker):
        """
        get the reflection information

        Value checks run only when no snippet raised during execution.

        Returns:
            (run_error_list, run_error_idx, value_error_list, value_error_idx)
        """
        hard_logic_py = query["hard_logic_py"]

        run_error_list, run_error_idx = self._check(query)
        if run_error_list:
            return run_error_list, run_error_idx, [], []
        per_constraint = [checker.check(constraint)[0] for constraint in hard_logic_py]
        value_error_idx = [idx for idx, errs in enumerate(per_constraint) if len(errs)]
        value_error_list = [err for errs in per_constraint for err in errs]
        return run_error_list, run_error_idx, value_error_list, value_error_idx

    def _reflect(self, query, run_error_list, value_error_list):
        """
        reflect and fix the code block

        Asks the LLM to repair ``query["hard_logic_py"]`` given the errors,
        retrying the call up to 3 times.

        Returns:
            (query, had_no_errors)

        Raises:
            Exception: if every LLM attempt fails.
        """
        content = NL2SL_INSTRUCTION(query["locale"]).step3_format(query, run_error_list, value_error_list)

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                res = self._get_first_list_in_str(self.backbone_llm(content))
                break
            except Exception as e:
                if attempt >= max_retries:
                    raise Exception(f"LLM call failed after {max_retries} retries: {e}") from e

        try:
            parsed = json.loads(res)
        except ValueError:
            query["error_hard_logic_py"] = res
            parsed = []
        query["hard_logic_py"] = [str(item) for item in parsed]
        return query, len(run_error_list + value_error_list) == 0

    @staticmethod
    def _append_reflect_info(query, cnt, run_error_list, value_error_list):
        """Record one round of check results in ``query["reflect_info"]``."""
        query["reflect_info"].append(
            {
                "cnt": cnt,
                "run_error_list": run_error_list,
                "value_error_list": value_error_list,
                "hard_logic_py": query["hard_logic_py"],
            }
        )

    def nl2sl_step3(self, query, max_trails=5):
        """
        third step: reflect and fix the code block

        Runs up to ``max_trails`` check/repair rounds and logs each round in
        ``query["reflect_info"]``. It then drops every snippet that still fails
        execution or value checking.
        """
        checker = HardLogicPyChecker(query["locale"])

        cnt = 0
        query["reflect_info"] = []
        query["hard_logic_py_ood"] = []

        while cnt < max_trails:
            run_error_list, _, value_error_list, _ = self._reflect_info(query, checker)
            self._append_reflect_info(query, cnt, run_error_list, value_error_list)
            if not (run_error_list or value_error_list):
                break
            query, _ = self._reflect(query, run_error_list, value_error_list)
            query["hard_logic_py"] = self._dedupe(query["hard_logic_py"])
            cnt += 1

        query["reflect_cnt"] = cnt
        run_error_list, run_error_idx, value_error_list, value_error_idx = self._reflect_info(
            query, checker
        )
        self._append_reflect_info(query, cnt, run_error_list, value_error_list)

        error_indices = set(run_error_idx) | set(value_error_idx)
        kept = [val for idx, val in enumerate(query["hard_logic_py"]) if idx not in error_indices]
        if run_error_idx:
            # _reflect_info skips value checks when anything raised; apply them
            # to the survivors so that value-invalid snippets are dropped too.
            kept = [val for val in kept if len(checker.check(val)[0]) == 0]
        query["hard_logic_py"] = kept
        return query

    def translate_nl2sl(self, query):
        """
        convert natural language query to symbolic language

        Returns the same ``query`` dict, enriched with ``hard_logic``,
        ``hard_logic_py`` (a list of code strings) and reflection logs.
        """
        query = self.nl2sl_step1(query)  # convert natural language to hard logic
        query = self.nl2sl_step2(query)  # convert hard logic to py code
        print("temp_plan verifying...")
        query = self.nl2sl_step3(query)  # check the logical correctness of the py code
        query["hard_logic_py_iter_3"] = query["hard_logic_py"]

        # "error" can only come from the step-1 LLM JSON (its keys are copied
        # into query). Use an empty list to keep hard_logic_py's type consistent.
        if "error" in query:
            query["hard_logic_py"] = []

        return query
    




if __name__ == "__main__":
    # Demo: print the step-1 prompt generated for a sample English request.
    demo_instruction = NL2SL_INSTRUCTION("en-US")
    demo_prompt = demo_instruction.step1_format(
        "I want to visit Singapore and Cameron Highlands"
    )
    print(demo_prompt)