import sys
import os
# Make the directory the process was launched from importable (the
# `agent.hypertree...` imports below presumably rely on it -- TODO confirm).
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "")))
# Import-time side effect: switch the process working directory to the
# directory containing this file, so later relative paths resolve from here
# rather than from the launch directory. Must stay after the sys.path line
# above, which captures the original working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
from agent.hypertree.tools.planner.apis import HTPlanner, HyperTree, format_step
from langchain.schema import HumanMessage
from agent.hypertree.agents.hypertree_prompts import trip_execute_prompt, trip_expand_prompt, trip_select_prompt, trip_plan_generation_prompt, trip_select_prompt_zh, trip_expand_prompt_zh, trip_execute_prompt_zh, trip_plan_generation_prompt_zh
from agent.hypertree.agents.prompts import TRIP_PLAN_CONVERT
from openai import OpenAI
from agent.hypertree.utils.json_format import TravelItinerary
import pickle
import json
import json_repair
import time
import traceback

# Hard cap on expand/select rounds when TripHTPlanner grows its hypertree;
# the growth loop stops after this many rounds even if leaves remain.
MAX_ITERATIONS: int = 30

class TripHTPlanner(HTPlanner):
    """Hypertree planner specialised for trip planning.

    ``run`` grows a tree of planning sub-goals rooted at ``[Plan]``
    (transportation, accommodation, attractions) by repeatedly expanding the
    selected node from ``self.rules`` and picking the next leaf with the
    inherited ``select`` method. The finished tree is rendered to text,
    cached on disk with pickle, and handed to the LLM to produce a reasoning
    trace and a JSON travel plan.

    NOTE(review): ``select``, ``leaves``, ``leaves_dict``,
    ``transportation_pattern``, ``accommodation_pattern`` and
    ``_build_convert_prompt`` are not defined in this class; presumably they
    come from ``HTPlanner`` -- confirm there.
    """

    def __init__(self, backbone_llm, **kwargs):
        """Create the planner.

        Args:
            backbone_llm: Callable taking a list of chat messages
                (``[{"role": "user", "content": ...}]``) and returning the
                reply as a string; used for every LLM call in this class.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(backbone_llm=backbone_llm)
        self.llm = backbone_llm

    @staticmethod
    def _split_city_list(text):
        """Split a comma-separated city string into trimmed, non-empty names.

        Both the ASCII comma and the full-width comma "，" are accepted as
        separators (the zh-CN prompt may yield either).
        """
        normalized = text.replace('，', ',')
        return [city.strip() for city in normalized.split(',') if city.strip()]

    def extract_cities_from_query(self, query):
        """Ask the LLM which cities *query* mentions.

        Returns:
            list[str]: Trimmed, non-empty city names in the order the LLM
            returned them. The prompt requests "City1, City2, City3", so a raw
            split would otherwise keep the leading spaces.
        """
        prompt = """
        You are a travel planner. You need to extract the cities from the query.
        Query: {query}

        The cities should be extracted from the query, and the cities should be in the format of "City1, City2, City3".
        """
        reply = self.llm([{"role": "user", "content": prompt.format(query=query)}])
        return self._split_city_list(reply)

    def _extract_destination_from_query(self, user_query, locale):
        """Ask the LLM for the destination city/cities of *user_query*.

        Uses a Chinese prompt when *locale* is 'zh-CN' and an English one
        otherwise. Returns the raw reply with surrounding whitespace and
        quote characters stripped; multiple cities stay comma-separated.
        """
        if locale == 'zh-CN':
            prompt = """
            你是一位专业的旅行助手。你的任务是从用户的旅行查询中提取目的地城市。
            - 只返回城市名称。
            - 如果有多个城市，请按顺序返回，并用英文逗号分隔。
            - 不要添加任何解释或引导性文字。

            用户查询: "{user_query}"

            目的地:
            """
        else:
            prompt = """
            You are an expert travel assistant. Your task is to extract the destination city or cities from the user's travel query.
            - Only return the city names.
            - If there are multiple cities, please return the cities in the order of the query, and separate them with commas.
            - Do not add any explanation or introductory text.

            User Query: "{user_query}"

            Destination(s):
            """

        formatted_prompt = prompt.format(user_query=user_query)
        destination = self.llm([{"role": "user", "content": formatted_prompt}])
        return destination.strip().strip('"').strip("'")

    def _parse_destinations(self, arrive):
        """Normalise a destination value into a non-empty list of city names.

        *arrive* may be a comma-separated string (ASCII or full-width commas)
        or an already-split list/tuple. An empty result becomes ``['']`` so
        callers can always rely on at least one entry.
        """
        if isinstance(arrive, (list, tuple)):
            cities = [str(city).strip() for city in arrive if str(city).strip()]
        else:
            cities = self._split_city_list(str(arrive))
        return cities or ['']

    def _configure_locale_prompts(self):
        """Select the zh-CN or default (English) prompt templates for self.locale."""
        if self.locale == 'zh-CN':
            self.select_prompt = trip_select_prompt_zh
            self.expand_prompt = trip_expand_prompt_zh
            self.execute_prompt = trip_execute_prompt_zh
            self.plan_generation_prompt = trip_plan_generation_prompt_zh
        else:
            self.select_prompt = trip_select_prompt
            self.expand_prompt = trip_expand_prompt
            self.execute_prompt = trip_execute_prompt
            self.plan_generation_prompt = trip_plan_generation_prompt

    def generate_responses(self, category):
        """Return the child-node string for a top-level plan category.

        Args:
            category: One of "[Transportation]", "[Accommodation]",
                "[Attraction]".

        Returns:
            str: Concatenated bracketed sub-goals. For a single-city trip the
            real city name is used (e.g. "[Accommodation for Tokyo]"). For
            multi-city trips the cities are named positionally ("City 1",
            "City 2", ...) and transportation forms a round trip
            origin -> City 1 -> ... -> City N -> origin.

        Raises:
            KeyError: If *category* is not one of the three categories.
        """
        if self.visiting_city_number == 1:
            # run() stores self.dest as a list; format the city itself rather
            # than the list repr ("['Tokyo']").
            dest = self.dest[0] if isinstance(self.dest, (list, tuple)) else self.dest
            responses = {
                "[Transportation]": f"[Transportation from {self.org} to {dest}][Transportation from {dest} to {self.org}]",
                "[Accommodation]": f"[Accommodation for {dest}]",
                "[Attraction]": f"[Attraction for {dest}]",
            }
            return responses[category]

        cities = [f"City {i}" for i in range(1, self.visiting_city_number + 1)]
        # Consecutive stops of the round trip, starting and ending at the origin.
        stops = [self.org] + cities + [self.org]
        responses = {
            "[Transportation]": "".join(
                f"[Transportation from {src} to {dst}]" for src, dst in zip(stops, stops[1:])
            ),
            "[Accommodation]": "".join(f"[Accommodation for {city}]" for city in cities),
            "[Attraction]": "".join(f"[Attraction for {city}]" for city in cities),
        }
        return responses[category]

    def expand(self, node):
        """Attach rule-derived children to *node*.

        The expansion string comes from, in order: the per-trip category
        templates (``generate_responses``), the generic
        "[Transportation from A to B]" / "[Accommodation for A]" rules
        (matched with ``transportation_pattern`` / ``accommodation_pattern``),
        or an exact key in ``self.rules``. A string "[X][Y]..." is split into
        ["[X]", "[Y]", ...], and every value differing from ``node.value``
        becomes a child ``HyperTree``.

        Raises:
            ValueError: If no rule applies to ``node.value``.
        """
        value = node.value
        if value in ("[Transportation]", "[Accommodation]", "[Attraction]"):
            request = self.generate_responses(value)
        elif self.transportation_pattern.fullmatch(value):
            request = self.rules['[Transportation from A to B]']
        elif self.accommodation_pattern.fullmatch(value):
            request = self.rules['[Accommodation for A]']
        elif value in self.rules:
            request = self.rules[value]
        else:
            # Fail with the offending value instead of an UnboundLocalError.
            raise ValueError(f"No expansion rule for node {value!r}")

        try:
            children = ['[' + item + ']' for item in request.strip('[]').split('][')]
        except (AttributeError, TypeError):
            # Non-string rule output: have the LLM reformat it into "[X][Y]...".
            self.str_to_convert = request
            converted = format_step(self.llm([{"role": "user", "content": self._build_convert_prompt()}]))
            children = ['[' + item + ']' for item in converted.strip('[]').split('][')]

        node.all = [children]
        node.branch = 0
        for child_value in node.all[node.branch]:
            # Never make a node its own child.
            if child_value != value:
                node.children.append(HyperTree(child_value))

    def run(self, given_information, query, number, cache_folder_path, result_folder_path, load_cache=True):
        """Plan a trip for *query* and return the generated plan record.

        Args:
            given_information: Reference data (string) inserted into the
                execute and plan-generation prompts.
            query: Dict with required keys 'userQuery', 'departure', 'day'
                and 'preference', and optional 'locale' (default 'en-US') and
                'arrive' (when missing or empty, the destination is extracted
                from 'userQuery' by the LLM).
            number: Case index, echoed in logs and in the returned record.
            cache_folder_path: Directory holding the pickled final-tree cache;
                created if missing.
            result_folder_path: Unused here; kept for interface compatibility.
            load_cache: When True and a cache file exists, reuse the cached
                tree instead of growing a new one.

        Returns:
            dict: The record produced by ``plan_generate``.
        """
        self.given_information = given_information
        self.current_tree = HyperTree('[Plan]')
        self.selected_node = self.current_tree
        self.number = number
        self.query = query['userQuery']
        self.locale = query.get('locale', 'en-US')
        self._configure_locale_prompts()

        arrive = query.get('arrive', '')
        if not arrive:
            arrive = self._extract_destination_from_query(self.query, self.locale)
        self.dest = self._parse_destinations(arrive)
        self.visiting_city_number = max(len(self.dest), 1)
        self.org = query['departure']
        self.day = query['day']
        self.preference = query['preference']

        transport_attrs = '[transportation availability][transportation preference][transportation cost]'
        self.rules = {
            '[Plan]': '[Transportation][Accommodation][Attraction]',
            '[Transportation from A to B]': '[Self-driving][Taxi][Train][Flight]',
            '[Self-driving]': transport_attrs,
            '[Taxi]': transport_attrs,
            '[Train]': transport_attrs,
            '[Flight]': transport_attrs,
            '[Accommodation for A]': '[minimum stay][house rule][room type][accommodation cost]',
        }

        cache_path = f"{cache_folder_path}/org_{self.org}_dest_{self.dest}_day_{self.day}_final_tree.pkl"

        if load_cache and os.path.exists(cache_path):
            print(f"{cache_path} exist, load cache.")
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point cache_folder_path at trusted, locally written caches.
            with open(cache_path, "rb") as f:
                self.final_tree = pickle.load(f)
        else:
            self._grow_tree()
            self.final_tree = self.current_tree.show().rstrip('\n')
            if cache_folder_path:
                os.makedirs(cache_folder_path, exist_ok=True)
            with open(cache_path, "wb") as f:
                pickle.dump(self.final_tree, f)
        self.thinking_process = self.llm([{"role": "user", "content": self._build_execute_prompt()}])
        return self.plan_generate()

    def _grow_tree(self):
        """Alternate expand and select on self.current_tree for at most MAX_ITERATIONS rounds."""
        iteration_count = 0
        while iteration_count < MAX_ITERATIONS:
            if self.selected_node:
                self.expand(self.selected_node)
            elif not self.current_tree.postorder_traversal():
                break
            self._select_next_node()
            iteration_count += 1
            print("--------------------------------")
            print(f"Iteration {iteration_count}")
            print(self.current_tree.show().rstrip('\n'))
            print("--------------------------------")

    def _select_next_node(self, retry_times=3):
        """Pick the next leaf via ``select`` and store it in self.selected_node.

        Tries up to *retry_times* times, logging diagnostics after each
        failure. Leaves self.selected_node unchanged when there are no leaves
        or when every attempt fails.
        """
        for _ in range(retry_times):
            # Bound before the call so the failure log works even when
            # select() itself raises.
            selected_index = None
            try:
                selected_index = self.select(self.current_tree)
                if not self.leaves:
                    return
                self.selected_node = self.leaves[int(selected_index)]
                return
            except Exception:
                print(f"case: {self.number} select node failed: {selected_index}, current leaves: {getattr(self, 'leaves', None)}")
                tmp_prompt = self._build_select_prompt()
                print(f"select prompt:{tmp_prompt}")
                traceback.print_exc()

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (```json ... ``` or ``` ... ```).

        str.lstrip/rstrip treat their argument as a *set of characters*, not
        a prefix/suffix, so explicit prefix/suffix checks are used instead.
        """
        body = text.strip()
        if body[:7].lower() == "```json":
            body = body[7:]
        elif body.startswith("```"):
            body = body[3:]
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    def plan_generate(self):
        """Generate the final travel plan with the LLM and wrap it in a record.

        Returns:
            dict with keys "idx", "query", "plan" (the reply parsed by
            json_repair after removing any code fence), "final_tree",
            "thinking_process" and "raw_plan" (the unparsed LLM reply).
        """
        raw_plan = self.llm([{"role": "user", "content": self._build_plan_generation_prompt()}])
        plan = json_repair.loads(self._strip_code_fence(raw_plan))
        return {
            "idx": self.number,
            "query": self.query,
            "plan": plan,
            "final_tree": self.final_tree,
            "thinking_process": self.thinking_process,
            "raw_plan": raw_plan,
        }

    def _build_select_prompt(self) -> str:
        """Fill the locale's select prompt with the query, preferences, tree and leaves."""
        return self.select_prompt.format(
            query = self.query,
            user_preferences = self.preference,
            current_tree = self.current_tree.show().rstrip('\n'),
            leaves = self.leaves_dict)

    def _build_expand_prompt(self) -> str:
        """Fill the locale's expand prompt with the query, preferences, tree and leaves."""
        return self.expand_prompt.format(
            query = self.query,
            user_preferences = self.preference,
            current_tree = self.current_tree.show().rstrip('\n'),
            leaves = self.leaves_dict)

    def _build_execute_prompt(self) -> str:
        """Fill the locale's execute prompt with reference data, query, preferences and final tree."""
        return self.execute_prompt.format(
            given_information = self.given_information,
            query = self.query,
            user_preferences = self.preference,
            solution_strategy = self.final_tree)

    def _build_plan_generation_prompt(self) -> str:
        """Concatenate the plan-generation prompt with reference data, query and thinking process."""
        return self.plan_generation_prompt + "[Reference Data]\n" + self.given_information + "\n[Query]\n" + self.query + "\n[Thinking Process]\n" + self.thinking_process + "\n[Your Travel Plan]\n"

