import json
from agent.base import BaseAgent
from agent.hypertree.tools.planner.trip_apis import TripHTPlanner

class HyperTreeAgent(BaseAgent):
    """Agent adapter that feeds benchmark queries into ``TripHTPlanner``.

    It pulls the "given information" preamble out of the last chat message of
    a query and delegates the actual planning to ``self.planner.run``.
    """

    # Markers separating the given-information preamble from the user query
    # (English and Chinese prompts). Checked in this order; first match wins.
    _QUERY_MARKERS = ('[User Query]', '[用户查询]')

    def __init__(self, **kwargs):
        super().__init__(name="HyperTree", **kwargs)
        # The same keyword arguments are forwarded to the planner.
        self.planner = TripHTPlanner(**kwargs)

    @staticmethod
    def _load_messages(raw_messages):
        """Return the message sequence, JSON-decoding it first if it is a string.

        Undecodable JSON, or anything that is not a list/tuple after decoding
        (e.g. a dict or ``None``), yields an empty list.
        """
        if isinstance(raw_messages, str):
            try:
                raw_messages = json.loads(raw_messages)
            except json.JSONDecodeError:
                return []
        return raw_messages if isinstance(raw_messages, (list, tuple)) else []

    @staticmethod
    def _last_message_content(messages):
        """Return the text content of the final message, or "" if unavailable.

        A message's ``content`` may itself be a stringified JSON object carrying
        its own ``content`` key; in that case the inner value is returned.
        Non-string values are converted with ``str`` so callers always get text.
        """
        if not messages:
            return ""
        last_message = messages[-1]
        if not isinstance(last_message, dict) or 'content' not in last_message:
            return ""

        content = last_message['content']
        if not isinstance(content, str):
            return str(content)

        # Tolerate leading whitespace before a JSON-encoded payload.
        if content.lstrip().startswith('{'):
            try:
                decoded = json.loads(content)
            except json.JSONDecodeError:
                return content
            if isinstance(decoded, dict) and 'content' in decoded:
                inner = decoded['content']
                return inner if isinstance(inner, str) else str(inner)
        return content

    @classmethod
    def _extract_given_information(cls, text):
        """Return the text preceding the first query marker found, or ""."""
        for marker in cls._QUERY_MARKERS:
            if marker in text:
                return text.partition(marker)[0]
        return ""

    def run(self, query, **kwargs):
        """Plan a trip for one benchmark query.

        Args:
            query: ``query_i`` dict from run_exp_flexible.py. Reads
                ``messages`` (a list of message dicts, or its JSON string) and
                ``case_index`` (defaults to 0).
            **kwargs: ``load_cache`` (default True) is forwarded to the planner.

        Returns:
            ``(agent_success, result)``. On success ``result`` is a shallow copy
            of the planner's output dict with ``llm_response`` set to its
            ``raw_plan`` value (or "" if absent); ``agent_success`` is True iff
            that response is truthy. On any exception the error is printed and
            ``(False, {"error": message})`` is returned.
        """
        try:
            # 1. Extract given_information from the last message.
            messages = self._load_messages(query.get("messages", []))
            given_information = self._extract_given_information(
                self._last_message_content(messages)
            )

            # 2. Delegate to the planner.
            travel_plan = self.planner.run(
                given_information=given_information,
                query=query,
                number=query.get('case_index', 0),
                cache_folder_path=self.log_dir,
                result_folder_path=None,  # Not used by the run method
                load_cache=kwargs.get('load_cache', True),
            )
            if not isinstance(travel_plan, dict):
                raise TypeError(
                    f"planner returned {type(travel_plan).__name__}, expected dict"
                )

            # 3. Format the output. Copy so the planner's own (possibly cached)
            #    dict is not mutated as a side effect.
            final_plan = dict(travel_plan)
            final_plan['llm_response'] = final_plan.get('raw_plan', '')
            return bool(final_plan['llm_response']), final_plan

        except Exception as e:
            # Top-level boundary: report and convert to a failed result.
            import traceback
            print(f"HyperTreeAgent failed: {e}")
            print(traceback.format_exc())
            return False, {"error": str(e)}
