import copy
import json
from typing import Dict, Iterator, List, Literal, Optional, Tuple, Union

import json5

from qwen_agent.agents.fncall_agent import FnCallAgent
from qwen_agent.llm import BaseChatModel
from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, Message
from qwen_agent.tools import BaseTool
from qwen_agent.utils.parallel_executor import parallel_exec
from qwen_agent.utils.utils import print_traceback

# Opening fence of an execution-result block; also passed to the LLM as a stop word.
OBS_START = '```output'
# Closing fence appended right after the execution result.
OBS_END = '\n```\n'

# Name of the tool that receives the extracted program; always prepended to the agent's function list.
DEFAULT_TOOL_NAME = 'python_executor'
# Upper bound on the number of LLM calls made within a single run.
MAX_LLM_CALL_PER_RUN = 3


def extract_program(result: str, last_only=True):
    """Collect the source found inside "```python" fenced blocks of ``result``.

    A block opens on a line that starts or ends with "```python" and closes on
    the next line that starts with "```". With ``last_only`` set, only the most
    recently opened block is kept; otherwise every block is kept, each one
    preceded by a "# ========" separator line.

    Returns an empty string when there is no block at all, or when the final
    block is never closed (the code is then considered incomplete).
    """
    fence_open = '```python'
    pieces = []
    inside_block = False
    for text_line in result.split('\n'):
        if text_line.startswith(fence_open) or text_line.endswith(fence_open):
            if last_only:
                # A newer block supersedes whatever was collected before.
                pieces.clear()
            else:
                pieces.append('\n# ========\n')
            inside_block = True
            continue
        if text_line.startswith('```'):
            inside_block = False
            continue
        if inside_block:
            pieces.append(text_line + '\n')
    if inside_block:
        # The closing fence never arrived: the code is incomplete.
        return ''
    return ''.join(pieces)


class POTAgent(FnCallAgent):
    """Agent that executes the Python code blocks emitted by the LLM.

    Every LLM turn is generated with ``OBS_START`` as a stop word. When the
    generated text contains a complete "```python" block, the code is sent to
    the ``DEFAULT_TOOL_NAME`` tool and the tool output is appended to the
    assistant response, wrapped in ``OBS_START`` / ``OBS_END``. The LLM is then
    asked to continue from that extended response. At most
    ``MAX_LLM_CALL_PER_RUN`` LLM calls are made per run.
    """

    def __init__(self,
                 function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
                 llm: Optional[Union[Dict, BaseChatModel]] = None,
                 system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
                 name: Optional[str] = None,
                 description: Optional[str] = None,
                 files: Optional[List[str]] = None,
                 **kwargs):
        """Initialise the agent, always registering ``DEFAULT_TOOL_NAME`` first.

        All arguments are forwarded unchanged to ``FnCallAgent.__init__`` except
        ``function_list``, which gets ``DEFAULT_TOOL_NAME`` prepended (``None``
        is treated as an empty list).
        """
        super().__init__(function_list=[DEFAULT_TOOL_NAME] + (function_list or []),
                         llm=llm,
                         system_message=system_message,
                         name=name,
                         description=description,
                         files=files,
                         **kwargs)

    def _run_batch(self, messages_batch: List[List[Union[Dict, Message]]], lang_batch: List[str],
                   **kwargs) -> List[List[Message]]:
        """Run the generate / execute loop over a batch of conversations.

        Args:
            messages_batch: One message list per conversation.
            lang_batch: Only forwarded to the parent implementation when batch
                inference is disabled; not used otherwise.
            **kwargs: Reads ``llm_batch_enabled`` (default ``False``) and
                ``llm_fn``. When ``llm_batch_enabled`` is falsy the call is
                delegated to ``super()._run_batch``. Otherwise ``llm_fn`` must be
                provided and is called with the batched messages,
                ``is_first_run`` and a ``generate_cfg``.

        Returns:
            One single-element list per input conversation holding the
            accumulated assistant message, or an empty list if ``llm_fn``
            returns an empty batch.
        """

        def _ask_code_exec(index: int,
                           tool_name: str,
                           tool_args: Union[str, dict] = '{}',
                           messages: Optional[List[Union[Dict, Message]]] = None) -> tuple:
            # The index travels with the result so results can be re-ordered afterwards.
            res = self._call_tool(tool_name=tool_name, tool_args=tool_args, messages=messages)
            return index, res

        llm_fn = kwargs.get('llm_fn', None)
        llm_batch_enabled = kwargs.get('llm_batch_enabled', False)
        if not llm_batch_enabled:
            return super()._run_batch(messages_batch=messages_batch, lang_batch=lang_batch, **kwargs)
        assert llm_fn, 'Please pass in llm_fn that supports batch inference externally'
        batch_size = len(messages_batch)
        # `origin_messages_batch` is what the tool sees and what gets extended turn after turn;
        # `messages_batch` is what is sent to the LLM in the current turn.
        origin_messages_batch = copy.deepcopy(messages_batch)
        messages_batch = copy.deepcopy(messages_batch)
        responses_batch = [[Message(role=ASSISTANT, content='')] for _ in range(batch_size)]
        in_processing = list(range(batch_size))  # all index of inputted messages
        num_llm_calls_available = MAX_LLM_CALL_PER_RUN
        is_first_run = True
        while len(in_processing) > 0 and num_llm_calls_available > 0:
            num_llm_calls_available -= 1
            assert len(in_processing) == len(messages_batch)
            # Flatten single-item list content into plain text before calling the LLM.
            for messages in messages_batch:
                for msg in messages:
                    if isinstance(msg.content, list):
                        assert len(msg.content) == 1
                        msg.content = msg.content[0].text
            output_batch = llm_fn(messages=[[msg.model_dump() for msg in messages] for messages in messages_batch],
                                  is_first_run=is_first_run,
                                  generate_cfg={
                                      'stop': [OBS_START],
                                      'top_k': 1
                                  })
            is_first_run = False
            output_batch = [
                [Message(**msg) if isinstance(msg, dict) else msg for msg in output] for output in output_batch
            ]

            if len(output_batch) == 0:
                return []

            # Iterate over a snapshot because `in_processing` is mutated inside the loop.
            _in_processing = list(in_processing)
            data = []
            skip_code_execution = False
            for idx, output in zip(_in_processing, output_batch):
                if not output or not output[-1].content:
                    # this data meet the end: exceed max_length
                    in_processing.remove(idx)
                    continue
                responses_batch[idx][0].content += output[-1].content
                responses_batch[idx][0].extra = output[-1].extra

                # Determine whether to call the tool
                has_action, action, action_input, thought = self._detect_tool(output[-1].content)
                if not has_action:
                    # this data meet the end
                    in_processing.remove(idx)
                    continue

                # special logic for llm_fn
                # NOTE(review): presumably `is_last_rank` marks the only rank that should
                # execute code in a multi-rank llm_fn -- confirm against the llm_fn provider.
                if output[-1].extra and not output[-1].extra.get('is_last_rank', True):
                    skip_code_execution = True
                    continue

                # record for parallel exec code
                data.append({
                    'index': idx,
                    'tool_name': action,
                    'tool_args': action_input,
                    'messages': origin_messages_batch[idx]
                })
            if not skip_code_execution:
                # Exec code and get the messages_batch for next turn
                results = parallel_exec(_ask_code_exec, data, max_workers=1, jitter=0.5)
                ordered_results = sorted(results, key=lambda x: x[0])
                new_messages_batch = []
                for idx, observation_list in ordered_results:
                    try:
                        observation_list = json5.loads(observation_list)
                    except Exception:
                        # Not parseable: decide from the raw text whether it is an error report.
                        if '/qwen_agent/' in observation_list or 'Traceback:' in observation_list:
                            observation_list = ['', observation_list]
                        else:
                            observation_list = [observation_list, 'Done']
                    if observation_list[-1] == 'Done':
                        observation = observation_list[0]
                        code_exec = 1
                    else:
                        observation = observation_list[-1]
                        code_exec = 0
                    response_msg = responses_batch[idx][0]
                    # `extra` is copied from the LLM output and may be None; make sure it is a
                    # dict before recording the execution status of this turn.
                    if response_msg.extra is None:
                        response_msg.extra = {}
                    response_msg.extra.setdefault('code_exec', []).append(code_exec)
                    observation = observation.strip()
                    observation = f'{OBS_START}\n{observation}{OBS_END}'
                    if not response_msg.content.endswith('\n'):
                        response_msg.content += '\n'
                    response_msg.content += observation

                    # Keep exactly one trailing assistant message holding the response so far.
                    if origin_messages_batch[idx][-1].role == ASSISTANT:
                        origin_messages_batch[idx][-1] = copy.deepcopy(response_msg)
                    else:
                        origin_messages_batch[idx].extend(copy.deepcopy(responses_batch[idx]))
                    new_messages_batch.append(origin_messages_batch[idx])
            else:
                # only get the messages_batch for next turn
                new_messages_batch = [origin_messages_batch[idx] for idx in in_processing]
            messages_batch = new_messages_batch

        return responses_batch

    def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwargs) -> Iterator[List[Message]]:
        """Run the generate / execute loop for a single conversation.

        Args:
            messages: The conversation so far. A deep copy is sent to the LLM;
                the original list is what is handed to the tool.
            lang: Not used by this implementation.
            **kwargs: Reads the optional ``llm_fn``; when given it replaces
                ``self.llm.continue_assistant_response``. All kwargs are also
                forwarded to ``self._call_tool``.

        Yields:
            Single-element lists holding the assistant message accumulated so
            far: the text generated so far, then the text extended with the
            wrapped tool output after each code execution.
        """
        llm_fn = kwargs.get('llm_fn', None)
        text_messages = copy.deepcopy(messages)
        num_llm_calls_available = MAX_LLM_CALL_PER_RUN
        response: str = ''
        while num_llm_calls_available > 0:
            num_llm_calls_available -= 1

            # Flatten single-item list content into plain text before calling the LLM.
            for msg in text_messages:
                if isinstance(msg.content, list):
                    assert len(msg.content) == 1
                    msg.content = msg.content[0].text

            # Display the streaming response
            output = []
            # TODO: change this hotfix
            if llm_fn:
                # example of messages: [{'role': 'user', 'content': 'xxx'}, {'role': 'assistant', 'content': 'xxx'}]
                output = llm_fn(messages=[msg.model_dump() for msg in text_messages],
                                generate_cfg={
                                    'stop': [OBS_START],
                                    'top_k': 1
                                })
                output = [Message(**msg) if isinstance(msg, dict) else msg for msg in output]
                if output:
                    yield [Message(role=ASSISTANT, content=response + output[-1].content, extra=output[-1].extra)]
            else:
                for output in self.llm.continue_assistant_response(messages=text_messages,
                                                                   generate_cfg={
                                                                       'stop': [OBS_START],
                                                                       'top_k': 1
                                                                   },
                                                                   stream=True):
                    if output:
                        yield [Message(role=ASSISTANT, content=response + output[-1].content, extra=output[-1].extra)]

            if not output:
                # Nothing was generated this turn, so there is no text to inspect for code.
                # Stop instead of indexing into an empty list.
                break

            # Accumulate the current response
            # The generated content before stop_word
            response += output[-1].content

            has_action, action, action_input, thought = self._detect_tool(output[-1].content)
            if not has_action:
                break
            # `.get` with a True default mirrors `_run_batch` and tolerates a missing key.
            if output[-1].extra and not output[-1].extra.get('is_last_rank', True):
                # special logic for llm_fn
                continue

            # Add the tool result
            observation = self._call_tool(action, action_input, messages=messages, **kwargs)
            try:
                observation_list = json5.loads(observation)
                if observation_list[-1] == 'Done':
                    observation = observation_list[0]
                else:
                    observation = observation_list[-1]
            except Exception:
                # Best effort: fall back to the raw tool output when it cannot be parsed.
                print_traceback()
            observation = observation.strip()
            observation = f'{OBS_START}\n{observation}{OBS_END}'

            # Accumulate the current exec result
            if not response.endswith('\n'):
                response += '\n'
            response += observation
            current_rsp = Message(role=ASSISTANT, content=response, extra=output[-1].extra)
            yield [current_rsp]

            # Keep exactly one trailing assistant message for the LLM to continue from.
            if text_messages[-1].role == ASSISTANT:
                text_messages[-1] = current_rsp
            else:
                text_messages.append(current_rsp)

    def _detect_tool(self, text: str) -> Tuple[bool, str, str, str]:
        """Check whether ``text`` contains a complete Python program to execute.

        Returns:
            A tuple ``(has_action, tool_name, tool_args, text)`` where
            ``has_action`` tells whether a program was extracted, ``tool_name``
            is always ``DEFAULT_TOOL_NAME``, ``tool_args`` is the JSON string
            ``{"code": <program>}`` (or ``''`` when nothing was extracted) and
            ``text`` is the input returned unchanged.
        """
        program = extract_program(text)
        if program:
            program = json.dumps({'code': program}, ensure_ascii=False)
        return (program != ''), DEFAULT_TOOL_NAME, program, text
