import sys, time, uuid, logging, json
from colorama import Fore, Style
from threading import Lock
from collections import defaultdict
from function_call_agent.agent_core import PathStateManager, SystemPromptBuilder
from function_call_agent.api_manage import APISemanticRecaller
from function_call_agent.tools import APISimulator, LLMClient, Graph_Search
from function_call_agent.logging.log_config import LogConfig, ProcessPhase

# Initialize the project's log configuration once, at import time.
# NOTE(review): this is an import-time side effect; what it configures
# (handlers, levels, formats) lives in LogConfig.configure and is not visible here.
LogConfig.configure()

# ---------------------------
# Helpers and main agent class
# ---------------------------


def get_after_second_last_occurrence(s, sub):
    """Return the tail of ``s`` that starts at the second-to-last ``sub``.

    The returned slice includes ``sub`` itself. The earlier occurrence is
    searched only inside ``s[:last]`` (``last`` being the start of the final
    occurrence), so it must end before the final one begins; overlapping
    matches are therefore not counted.

    Args:
        s: String to search.
        sub: Substring to look for.

    Returns:
        ``s[second_last:]`` when ``sub`` occurs at least twice, otherwise ``s``
        unchanged. Following ``str.rfind`` semantics, an empty ``sub`` yields ``''``.
    """
    final_idx = s.rfind(sub)
    # Only look for an earlier match if there is a final one at all.
    prior_idx = -1 if final_idx < 0 else s.rfind(sub, 0, final_idx)
    return s if prior_idx < 0 else s[prior_idx:]


class FunctionCallAgent:
    """Function Calling Agent (Integrated Modules).

    Runs a multi-step tool-use loop. The main LLM proposes actions
    (``clarify_intent``, ``call_api``, ``retrieve_api``, ``direct_answer``).
    The agent executes them, records per-path state in a ``PathStateManager``
    and feeds an execution report back to the LLM. This repeats until each
    intent is answered or ``max_retry`` rounds are exhausted.
    """

    # Placeholder candidate entry used before any API has been retrieved; the
    # first retrieve_api replaces it instead of appending after it.
    _NO_CANDIDATE = 'No API candidate!!!  No API candidate!!!  No API candidate!!!'
    # Upper bound on len(str(messages)) accepted by get_parse_response.
    max_prompt_chars = 35000

    def __init__(self, model_name, graph_degree=None, search_type='alpha_beta',
                 prompt_txt="/Users/your/prompt.txt", clarify_llm=False,
                 api_llm=True, data_type=None, api_simulator=None,
                 clarify_model_name='your_model_name'):
        """Build the LLM clients, prompt builder, API graph and semantic recaller.

        Args:
            model_name: Model used for the main loop (``self.llm_client``).
            graph_degree, search_type, data_type: Forwarded to ``Graph_Search``.
            prompt_txt: Prompt file passed to ``SystemPromptBuilder``.
            clarify_llm: When True, ``clarify_intent`` actions are answered by the
                secondary LLM; when False they end the path with that text.
            api_llm: When True, ``call_api`` results are produced by the
                secondary LLM; otherwise ``api_simulator`` is used.
            api_simulator: Object exposing ``run(api_name, params)``. It is only
                needed when ``api_llm`` is False and may also be assigned after
                construction, e.g. ``APISimulator(apis=agent.apis, file_=...)``.
            clarify_model_name: Model name for the secondary LLM client.
        """
        # Initialize modules
        self.model_name = model_name
        self.llm_client = LLMClient(model_name)
        self.clarify_llm_client = LLMClient(clarify_model_name)
        self.prompt_engine = SystemPromptBuilder(prompt_txt)
        self.graph_search = Graph_Search(search_type=search_type, graph_degree=graph_degree, data_type=data_type)
        self.apis = self.graph_search.apis
        self.api_recall = APISemanticRecaller(top_k=3, api_library=self.apis)

        # Preserve core logic
        self.api_llm = api_llm
        self.api_simulator = api_simulator
        self.max_retry = 6
        self.clarify_llm = clarify_llm
        self.path_lock = Lock()
        self.logger = logging.getLogger('FunctionCallAgent')   # Get logger instance
        self.logger.info(f"Initializing agent with model: {model_name}, search_type: {search_type}, data_type:{data_type}")  # Log initialization
        self.session_counter = 0  # Session counter
        self.active_sessions = {}  # Sessions currently inside process_query

    def _get_session_id(self):
        """Generate a timestamped, per-instance sequential session ID."""
        self.session_counter += 1
        return f"SID-{time.strftime('%Y%m%d%H%M%S')}-{self.session_counter:04d}"

    # Modified logging methods
    def _log_session_start(self, session_id, user_input):
        """Log a boxed session-start marker with the user input."""
        self.logger.info(
            f"\n{Fore.YELLOW}╔{'═'*80}╗{Style.RESET_ALL}\n"
            f"{Fore.YELLOW}║{Style.RESET_ALL} "
            f"{Fore.CYAN}🏁 Session Start [{session_id}]{Style.RESET_ALL} | "
            f"{Fore.MAGENTA}User Input:{Style.RESET_ALL} {user_input}\n"
            f"{Fore.YELLOW}╚{'═'*80}╝{Style.RESET_ALL}"
        )

    def _log_model_response(self, session_id, llm_output):
        """Log the final response, pretty-printed as JSON.

        ``default=str`` keeps non-JSON-serializable outputs from raising here,
        after all work for the query has already succeeded.
        """
        self.logger.info(
            f"{Fore.GREEN}✨ Model Final Response [{session_id}]{Style.RESET_ALL}\n"
            f"{Fore.WHITE}{Style.BRIGHT}{json.dumps(llm_output, indent=2, ensure_ascii=False, default=str)}{Style.RESET_ALL}"
        )

    def get_parse_response(self, messages, other_llm=False):
        """Send ``messages`` to an LLM and parse its reply.

        Args:
            messages: Chat messages (list of role/content dicts).
            other_llm: Use ``self.clarify_llm_client`` instead of ``self.llm_client``.

        Returns:
            ``(response, llm_output)``: the raw client response and the client's
            ``parse_response`` result.

        Raises:
            ValueError: on any failure (prompt longer than ``max_prompt_chars``,
                client error, parse error). The original exception is chained
                as ``__cause__``.
        """
        response = "No response, LLM call error"
        try:
            # Debug logging is guarded and uses default=str so that enabling
            # DEBUG can never turn a successful call into a failure.
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    f"{Fore.CYAN}LLM Request Parameters:{Style.RESET_ALL}\n"
                    f"{json.dumps(messages[-1], indent=2, ensure_ascii=False, default=str)}"
                )
            if len(str(messages)) > self.max_prompt_chars:
                raise ValueError(f'String length exceeds {self.max_prompt_chars}')
            client = self.clarify_llm_client if other_llm else self.llm_client
            response = client.get_response(messages)
            llm_output = client.parse_response(response)

            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug(
                    f"{Fore.BLUE}Full Response Body:{Style.RESET_ALL}\n"
                    f"{json.dumps(response, indent=2, ensure_ascii=False, default=str)}"
                )

            return response, llm_output
        except Exception as e:
            # Error logging with color
            self.logger.error(
                f"{Fore.RED}LLM Call Exception:{Style.RESET_ALL} {str(e)} | "
                f"LLM Response: {response}",
                exc_info=True
            )
            # f-string: the raw response need not be a str.
            raise ValueError(f'Content: {response}') from e

    def get_react_response(self, user_input, history, path_id,
                           state_mgr, candidate_info, first_history=None):
        """Run one reasoning step of the main LLM.

        Builds the system prompt from ``candidate_info``, prepends
        ``first_history`` and sends the current execution report plus
        ``user_input``. ``history`` is extended in place with the exchange.

        Returns:
            ``(response, llm_output, history)``.
        """
        system_prompt = self.prompt_engine.build('\n'.join(candidate_info))
        # NOTE(review): path_id is accepted but the all-paths report is used in
        # either case; _construct_path_observation offers a per-path view.
        observation = self._construct_observation(state_mgr)
        full_prompt = f"""Observation:\n{observation}\n++++++++++\nUser input：{user_input}"""
        if first_history is None:
            first_history = []
        messages = ([{"role": "system", "content": system_prompt}]
                    + first_history + [{"role": "user", "content": full_prompt}])

        response, llm_output = self.get_parse_response(messages)
        history.extend([{"role": "system", "content": '\n'.join(candidate_info)},
                        {"role": "user", "content": full_prompt},
                        {"role": "assistant", "content": llm_output}
                        ])
        return response, llm_output, history

    def _construct_observation(self, state_mgr) -> str:
        """Generate a multi-path execution status report.

        Summarizes every path in ``state_mgr.state`` (status, succeeded/failed
        APIs, number of collected params, latest result). It then appends global
        totals: distinct collected-param keys across paths and failure counts
        per API.
        """
        path_summaries = []
        global_params = {}
        failed_apis = defaultdict(int)

        with self.path_lock:
            for path_id, path in state_mgr.state.items():
                summary = [
                    f"Path {path_id} ({path.get('status', 'pending')})",
                    f"APIs: {dict(path['called_apis'])} success | {dict(path['failed_apis'])} failed",
                    f"Params collected: {len(path['collected_params'])} items",
                    f"Latest result: {str(path['result'])}" if path['result'] else "Result pending"
                ]
                path_summaries.append("\n".join(summary))

                # Aggregate global data
                global_params.update(path['collected_params'])
                for api, messages in path['failed_apis'].items():
                    failed_apis[api] += len(messages)

        report_header = "Execution Status Report"
        separator = "\n———————————————————————————————\n"

        report_body = separator.join([
            report_header,
            '\n\n'.join(path_summaries),
            f"Global Statistics:\n• Collected parameters: {len(global_params)} items\n• Failed API counts: {dict(failed_apis) or 'None'}"
        ])

        return separator + report_body + separator

    def _construct_path_observation(self, path_id, state_mgr) -> str:
        """Generate an execution status report for a single path.

        Raises:
            KeyError: if ``path_id`` is not in ``state_mgr.state``.
        """
        with self.path_lock:
            path = state_mgr.state[path_id]
            summary = "\n".join([
                f"Successful API calls: {list(dict(path['called_apis']).keys())} ",
                f"Failed APIs: {dict(path['failed_apis'])} failures",
                f"Parameters: {path['collected_params']}",
                f"Result: {str(path['result'])}" if path['result'] else "None"
            ])

        separator = "\n———————————————————————————————\n"
        report_body = separator.join(["Execution Status Report", summary])
        return separator + report_body + separator

    def _parse_multi_intent_response(self, llm_output, state_mgr):
        """Create one path per action group in ``llm_output``.

        A single dict is treated as a one-element list. Each group gets a new
        ``path_<n>_<4 hex>`` id, initialised via ``state_mgr.init_path``.

        Returns:
            dict mapping path id -> action group.
        """
        path_actions = {}

        # Convert dict response to list if needed
        if isinstance(llm_output, dict):
            llm_output = [llm_output]

        for idx, action_group in enumerate(llm_output, 1):
            path_id = f"path_{idx}_{uuid.uuid4().hex[:4]}"
            state_mgr.init_path(path_id)
            path_actions[path_id] = action_group

        return path_actions

    def get_intent_input(self, process_result, intent, user_input, success_apis='[]'):
        """Format ``prompt_engine.intent_prompt`` for the next reasoning step.

        The clarify response is passed as ``params`` when the last step was a
        simulated clarification (status ``"mock"``); otherwise ``'None'`` is passed.
        """
        if process_result["status"] == "mock":
            add_input = process_result['clarify_response']
        else:
            add_input = 'None'
        return self.prompt_engine.intent_prompt.format(intent=intent, params=add_input,
                                                       user_input=user_input, success_apis=success_apis)

    def _path_result_text(self, state_mgr, path_id):
        """Join the path's accumulated direct answers with commas."""
        return ','.join(map(str, state_mgr.state[path_id]["result"]))

    def _execute_single_intent(self, path_id, action, intent, state_mgr, candidate_info, context,
                               user_input, target_api, record_result):
        """Execute one intent's path until it completes or retries run out.

        The initial ``action`` is processed first. The main LLM is then asked
        for further actions up to ``max_retry`` times. If the path never
        completes, a fallback answer is generated from the execution report.

        Returns:
            ``(result, history, api_history, record_result)``. ``result`` has
            keys ``path_id``, ``status`` (``"success"`` or ``"Support answer"``),
            ``result`` and ``intent``.
        """
        history = []
        api_history = {'call_api': {}, 'retrieve_api': {}}
        process_result, history, candidate_info, api_history, record_result = self._process_llm_response(
            [action], path_id, history, state_mgr, candidate_info, context, api_history, target_api,
            record_result, first=True)
        if process_result["status"] == "complete":
            # A clarify_intent without clarify LLM supplies final_answer; a
            # direct_answer only appends to the path's result list, so report
            # that text (as the retry loop below does) rather than the internal
            # process_result dict.
            if 'final_answer' in process_result:
                result = process_result['final_answer']
            else:
                result = self._path_result_text(state_mgr, path_id)
            return {
                "path_id": path_id,
                "status": "success",
                "result": result,
                "intent": intent
            }, history, api_history, record_result

        for _ in range(self.max_retry):
            path = state_mgr.state[path_id]

            _, react_llm_output, history = self.get_react_response(
                self.get_intent_input(process_result=process_result, intent=intent, user_input=user_input,
                                      success_apis=str(list(path['called_apis'].keys()))),
                history, path_id, state_mgr, candidate_info)
            process_result, history, candidate_info, api_history, record_result = self._process_llm_response(
                react_llm_output, path_id, history, state_mgr, candidate_info, context, api_history, target_api,
                record_result)
            if process_result["status"] == "complete":
                return {
                    "path_id": path_id,
                    "status": "success",
                    "result": self._path_result_text(state_mgr, path_id),
                    "intent": intent
                }, history, api_history, record_result
        # Fallback mechanism
        response, history = self._generate_final_answer(intent, history, state_mgr, path_id)
        return {
                    "path_id": path_id,
                    "status": "Support answer",
                    "result": response,
                    "intent": intent
                }, history, api_history, record_result

    def _process_llm_response(self, action_list, path_id, history, state_mgr, candidate_info, context,
                              api_history, target_api, record_result, first=False):
        """Execute a batch of LLM-proposed actions for one path.

        Args:
            action_list: List of action dicts, each with an ``"action"`` key. A
                single dict is accepted and treated as a one-element list.
            path_id: Path whose state in ``state_mgr`` is updated.
            candidate_info: Current API dependency-tree snippets. The first
                retrieval replaces the ``_NO_CANDIDATE`` placeholder, and the
                list is pruned to its first entry plus the last two.
            target_api: Known target API name(s), pinned ahead of recalled APIs
                on the first pass only.
            first: True for the initial action of a path.

        Returns:
            ``(process_result, history, candidate_info, api_history, record_result)``.
            ``process_result["status"]`` is ``"mock"`` (simulated clarification)
            or ``"complete"`` (clarification without clarify LLM, or all actions
            were ``direct_answer``); otherwise it is ``"continue"``.

        Raises:
            Re-raises any exception after logging it.
        """
        self.logger.debug(f"Processing LLM response for path {path_id}")
        if isinstance(action_list, dict):
            # Same tolerance as _parse_multi_intent_response: a bare action.
            action_list = [action_list]
        api_status = {}
        all_direct_answer = all(act["action"] == "direct_answer" for act in action_list)
        try:
            for one_action in action_list:
                # Clarification request
                if one_action["action"] == "clarify_intent":
                    if self.clarify_llm:
                        messages = [
                            {"role": "system", "content": self.prompt_engine.user_prompt.format(context=context)},
                            {"role": "user", "content": one_action.get("answer")}]
                        response, clarify_response = self.get_parse_response(messages, other_llm=True)
                        record_result.append([path_id, one_action["action"],
                                              {one_action.get("answer"): clarify_response}])
                        return {
                                'path_id': path_id,
                                "status": "mock",
                                "clarify_response": clarify_response,
                            }, history, candidate_info, api_history, record_result
                    else:
                        return {
                            'path_id': path_id,
                            "status": "complete",
                            "final_answer": one_action.get("answer", ""),
                            "data": state_mgr.state[path_id]["collected_params"]
                        }, history, candidate_info, api_history, record_result
                elif one_action["action"] == "call_api":
                    if self.api_llm:
                        if one_action["target_api"] in self.apis.keys():
                            # NOTE(review): the key 'input_param:' contains a
                            # trailing colon; kept as-is since it is part of the prompt.
                            messages = [
                                {"role": "system", "content": self.prompt_engine.api_prompt + str(context)},
                                {"role": "user", "content": str({'api_name': one_action["target_api"],
                                                                 'input_param:': one_action["params"]})}]
                            response, api_result = self.get_parse_response(messages, other_llm=True)
                        else:
                            api_result = {"status": "error", "data": "API not found, Can you retrieve_api again?", "type": "fail"}
                        data = 'data'
                    else:
                        simulator = getattr(self, 'api_simulator', None)
                        if simulator is None:
                            raise RuntimeError(
                                "api_llm=False requires an API simulator; pass "
                                "api_simulator=... to FunctionCallAgent")
                        api_result = simulator.run(one_action["target_api"], one_action["params"])
                        data = 'message'
                    api_status[one_action["target_api"]] = api_result
                    record_result.append([path_id, one_action["action"], {one_action["target_api"]: api_result}])
                    if api_result["status"] == "success":
                        state_mgr.state[path_id]["collected_params"].update({one_action["target_api"]: api_result[data]})
                        state_mgr.state[path_id]["called_apis"][one_action["target_api"]].append(api_result[data])
                        status_msg = f"{Fore.GREEN}SUCCESS{Style.RESET_ALL}"
                    else:
                        state_mgr.state[path_id]["failed_apis"][one_action["target_api"]].append(api_result[data])
                        status_msg = f"{Fore.RED}FAILED{Style.RESET_ALL}"

                    self.logger.info(
                        f"API Call | {one_action['target_api']} | Status: {status_msg} | "
                        f"Parameters: {one_action['params']} | Full Result: {api_result}"
                    )
                    api_history['call_api'][one_action["target_api"]] = [one_action["params"], api_result]
                elif one_action["action"] == "retrieve_api":
                    # Retrieve API using BGE
                    available_apis = self.api_recall.retrieve(one_action['recall_description'])
                    if first and target_api is not None:
                        # First pass only: pin the known target API(s) ahead of
                        # the top-2 recalled ones.
                        pinned = [target_api] if isinstance(target_api, str) else list(target_api)
                        available_apis = pinned + list(available_apis[:2])
                    self.logger.info(
                        f"API Recall Description | {one_action['recall_description']} | Results: {available_apis} | "
                    )
                    candidate_info_add = self.graph_search.run(list(available_apis))
                    if candidate_info and candidate_info[0] == self._NO_CANDIDATE:
                        candidate_info = [candidate_info_add]
                    elif candidate_info_add not in candidate_info:
                        candidate_info.append(candidate_info_add)
                    if len(candidate_info) > 3:
                        # Keep the first tree plus the two most recent ones.
                        candidate_info = [candidate_info[0]] + candidate_info[-2:]
                    c_info = '\n'.join(candidate_info)
                    self.logger.info(
                        f"API Dependency Tree | {c_info} | "
                    )
                    record_result.append([path_id, one_action["action"], one_action['recall_description'],
                                          available_apis, candidate_info])
                    api_history['retrieve_api'][one_action['recall_description']] = [available_apis, candidate_info]
                elif one_action["action"] == "direct_answer":
                    state_mgr.state[path_id]["result"].append(one_action.get("answer", ""))
                    record_result.append([path_id, one_action["action"], one_action.get("answer", "")])
            final_status = "complete" if all_direct_answer else "continue"
            return {
                'path_id': path_id,
                "status": final_status,
                "observation": self._construct_observation(state_mgr),
                "api_status": api_status
            }, history, candidate_info, api_history, record_result
        except Exception as e:
            self.logger.error(f"Exception occurred processing LLM response: {str(e)}", exc_info=True)
            raise

    def _generate_final_answer(self, user_input, history, state_mgr, path_id=None):
        """Ask the main LLM for a final answer from the all-paths report.

        ``history`` is extended in place with the prompt and reply.

        Returns:
            ``(llm_output, history)``.
        """
        # NOTE(review): path_id is accepted but the all-paths report is used in
        # either case, matching get_react_response.
        observation = self._construct_observation(state_mgr)
        prompt = self.prompt_engine.final_prompt.format(user_input=user_input, observation=observation)

        messages = [{"role": "user", "content": prompt}]
        response, llm_output = self.get_parse_response(messages)
        history.extend([{"role": "user", "content": prompt},
                        {"role": "assistant", "content": llm_output}
                        ])
        return llm_output, history

    def process_query(self, user_input, history=None, context='', target_api=None):
        """Answer ``user_input`` end to end.

        Flow: initial LLM step, split into intent paths, execute each path,
        then either return the single path's result or aggregate via
        ``_generate_final_answer``.

        Returns:
            ``(response, all_history, record_result)``. On any exception,
            ``response`` is ``{"status": "error", "message": str(e)}``.
        """
        if history is None:
            history = []
        state_mgr = PathStateManager()
        candidate_info = [self._NO_CANDIDATE]
        session_id = self._get_session_id()
        start_time = time.time()
        self.active_sessions[session_id] = {
            "start_time": start_time,
            "user_input": user_input
        }
        self._log_session_start(session_id, user_input)
        all_history = {'main_history': None, 'sub_history': {}, 'api_history': []}
        record_result = []
        try:
            # Step 1: Get initial response
            response, llm_output, history = self.get_react_response(
                user_input, history, path_id=None, state_mgr=state_mgr, candidate_info=candidate_info, first_history=history)
            if response == '_ERROR_CALLING':
                raise ValueError(self.model_name + ': LLM call error _ERROR_CALLING')
            # Step 2: Parse multi-intent response
            path_actions = self._parse_multi_intent_response(llm_output[-1], state_mgr)
            self.logger.info(f"Generated execution paths: {len(path_actions)}")
            # Step 3: Execute all paths sequentially
            results = []
            for path_id, action in path_actions.items():
                self.logger.debug(f"Executing path {path_id}")
                result, sub_history, api_history, record_result = self._execute_single_intent(
                    path_id, action, action['intent'], state_mgr, candidate_info, context, user_input, target_api,
                    record_result)
                all_history['sub_history'][path_id] = sub_history
                all_history['api_history'].append(api_history)
                results.append(result)
            # Step 4: Generate final response
            if len(path_actions) == 1:
                response = results[0]['result']
            else:
                # Aggregate answers
                response, history = self._generate_final_answer(user_input, history, state_mgr)
            all_history['main_history'] = history
            self._log_model_response(session_id, response)
            # Log session end
            duration = time.time() - start_time
            self.logger.info(
                f"\n{Fore.YELLOW}╔{'═' * 80}╗{Style.RESET_ALL}\n"
                f"{Fore.YELLOW}║{Style.RESET_ALL} "
                f"{Fore.RED}⏹️ Session End [{session_id}]{Style.RESET_ALL} | "
                f"Duration: {duration:.2f}s\n"
                f"{Fore.YELLOW}╚{'═' * 80}╝{Style.RESET_ALL}"
            )
            return response, all_history, record_result

        except Exception as e:
            self.logger.critical(f"Process terminated abnormally: {str(e)}", exc_info=True)
            return {"status": "error", "message": str(e)}, all_history, record_result
        finally:
            # The session is no longer active once this call returns; drop it
            # so active_sessions does not grow with every query.
            self.active_sessions.pop(session_id, None)


# Manual smoke test: builds a real agent and sends one sample query through
# process_query (this calls the configured LLM clients).
if __name__ == "__main__":
    demo_agent = FunctionCallAgent('qwen25_14b', clarify_llm=True, search_type='alpha_beta')

    # Test Case 1 (disabled)
    # query = " What is the current weather forecast for the location at latitude 40.7128 and longitude -74.0060 in metric units?"
    # print("Test Case 1:", demo_agent.process_query(query)[0])

    # Test Case 2
    demo_query = "can you help me delete my account? my username is zhou, my password is 8888?"
    print("Test Case 2:", demo_agent.process_query(demo_query))

    # Test Case 3 (disabled)
    # print("Test Case 3:", demo_agent.process_query("""Is Beijing in China? \n And Can you add one exercise to program, My "param-exercise_id_3" is "1", My 'param-patient_id_10' is 123, My 'param-reps_1' is '10', My "param-rest_time_1" is "15", My 'param-sets_completed_2' is 3"""))
