import copy
import json
import os
import torch.distributed as dist
from typing import Dict, Iterator, List, Literal, Tuple, Union
from qwen_agent.agents.fncall_agent import FnCallAgent
from qwen_agent.llm.schema import ASSISTANT, Message
from qwen_agent.utils.parallel_executor import parallel_exec, serial_exec

# Tool observations are wrapped between these markers; OBS_START is also the
# generation stop string used by NousAgent._run_batch.
OBS_START = '<tool_response>'
OBS_END = '\n</tool_response>'

# Maximum number of llm_fn rounds per _run_batch call (environment-overridable).
MAX_LLM_CALL_PER_RUN = int(os.environ.get("MAX_LLM_CALL_PER_RUN", "2"))
print(f'Running with MAX_LLM_CALL_PER_RUN = {MAX_LLM_CALL_PER_RUN}')


global_rank = dist.get_rank()
world_size = dist.get_world_size()
# Ranks are laid out node-major: global_rank == dp_rank * mp_size + mp_rank.
# Validate GPUS_PER_NODE up front so a missing/invalid value fails with a readable
# message instead of int(None) or a rank layout that breaks in dist.new_group later.
_gpus_per_node = os.getenv("GPUS_PER_NODE")
if _gpus_per_node is None:
    raise RuntimeError('GPUS_PER_NODE must be set to the number of ranks per node (model parallel size)')
mp_size = int(_gpus_per_node)              # model parallel size
if mp_size <= 0 or world_size % mp_size != 0:
    raise ValueError(f'GPUS_PER_NODE={mp_size} must be a positive divisor of world_size={world_size}')
dp_size = world_size // mp_size            # data parallel size
dp_rank = global_rank // mp_size           # which machine/node
mp_rank = global_rank % mp_size            # which GPU inside that node

def create_model_parallel_group(world_size, mp_size):
    """Create one process group per block of ``mp_size`` consecutive ranks.

    Group ``k`` contains ranks ``[k * mp_size, (k + 1) * mp_size)``; groups are
    created in ascending order of their first rank. ``dist.new_group`` must be
    entered by every rank of the job for every group, so all ranks should call
    this with identical arguments.

    Args:
        world_size: Total number of ranks in the job.
        mp_size: Number of ranks per group (model parallel size).

    Returns:
        List of the groups returned by ``dist.new_group``, indexed by node.
    """
    return [
        dist.new_group(ranks=list(range(first_rank, first_rank + mp_size)))
        for first_rank in range(0, world_size, mp_size)
    ]

# Every rank creates all groups (dist.new_group is entered by all ranks), then keeps
# the one containing itself: index global_rank // mp_size, i.e. dp_rank / node id.
model_parallel_groups = create_model_parallel_group(world_size=world_size, mp_size=mp_size)
model_parallel_group = model_parallel_groups[global_rank // mp_size]


class NousAgent(FnCallAgent):
    """A special POT agent.

    Drives batched, multi-turn tool use through an externally supplied
    ``llm_fn``: generation is stopped at ``OBS_START``, every
    ``<tool_call>{...}</tool_call>`` JSON block in the new text is executed,
    the observations are appended to the assistant message wrapped in
    ``OBS_START``/``OBS_END``, and generation resumes. Only ``_run_batch`` is
    implemented; the single-conversation ``_run`` raises ``NotImplementedError``.
    """

    def _ask_code_exec(self,
                       index: int,
                       t_index: int,
                       tool_name: str,
                       tool_args: Union[str, dict] = '{}',
                       messages: List[Union[Dict, Message]] = None,
                       tool_res: str = '',
                       **kwargs) -> tuple:
        """Execute one tool call and tag the result with its position.

        Args:
            index: Position of the conversation in the original batch.
            t_index: Position of the tool call within that round's response.
            tool_name: Name of the tool to call.
            tool_args: Tool arguments as decoded from the ``<tool_call>`` JSON.
            messages: Conversation history forwarded to ``self._call_tool``.
            tool_res: Precomputed result. When non-empty the tool is NOT called
                and this value is returned as-is (used for unparsable calls).

        Returns:
            ``(index, t_index, result)`` so the caller can restore ordering
            after parallel execution.
        """
        if tool_res:
            return index, t_index, tool_res
        res = self._call_tool(tool_name=tool_name, tool_args=tool_args, messages=messages)
        return index, t_index, res

    def _run_batch(self, messages_batch: List[List[Union[Dict, Message]]], lang_batch: List[str],
                   **kwargs) -> List[List[Message]]:
        """Run the tool-calling loop over a batch of conversations.

        When ``kwargs['llm_batch_enabled']`` is falsy this defers to
        ``FnCallAgent._run_batch``. Otherwise ``kwargs['llm_fn']`` must be a
        batched generation callable; it receives ``messages`` (the active
        conversations as lists of dicts), ``is_first_run`` and ``generate_cfg``,
        and its outputs are zipped, in order, with the active conversations
        (each output being a list of dicts or ``Message`` objects).

        Each round calls ``llm_fn`` once, stopping generation at ``OBS_START``.
        A conversation becomes inactive when its output is empty or contains no
        ``<tool_call>``/``</tool_call>`` pair (looking only after the last
        ``</think>``). At most ``MAX_LLM_CALL_PER_RUN`` rounds run, and tools
        are not executed after the final round.

        Tools run only on ``mp_rank == 0``; the results are broadcast to the
        other ranks of ``model_parallel_group`` so every rank builds the same
        next-round inputs.

        Returns:
            One single-element list per input conversation holding the
            accumulated assistant ``Message`` (generated text interleaved with
            tool observations), or ``[]`` if ``llm_fn`` returns an empty batch.
        """
        llm_fn = kwargs.get('llm_fn', None)
        llm_batch_enabled = kwargs.get('llm_batch_enabled', False)
        if not llm_batch_enabled:
            return super()._run_batch(messages_batch=messages_batch, lang_batch=lang_batch, **kwargs)
        assert llm_fn, 'Please pass in llm_fn that supports batch inference externally'
        batch_size = len(messages_batch)
        # origin_messages_batch[i]: running history of conversation i (its last entry is
        # replaced by the accumulated assistant message after each tool round).
        # messages_batch: only the still-active conversations, aligned with in_processing.
        origin_messages_batch = copy.deepcopy(messages_batch)
        messages_batch = copy.deepcopy(messages_batch)
        responses_batch = [[Message(role=ASSISTANT, content='')] for i in range(batch_size)]
        in_processing = [i for i in range(batch_size)]  # original indices of still-active conversations
        num_llm_calls_available = MAX_LLM_CALL_PER_RUN
        is_first_run = True
        while len(in_processing) > 0 and num_llm_calls_available > 0:
            num_llm_calls_available -= 1
            assert len(in_processing) == len(messages_batch)
            batch_size = len(messages_batch)
            # Flatten single-item list contents to plain text before serializing for llm_fn.
            for idx in range(batch_size):
                for i, msg in enumerate(messages_batch[idx]):
                    if isinstance(msg.content, list):
                        assert len(msg.content) == 1
                        messages_batch[idx][i].content = msg.content[0].text
            output_batch = llm_fn(messages=[[msg.model_dump() for msg in messages] for messages in messages_batch],
                                  is_first_run=is_first_run,
                                  generate_cfg={
                                      'stop': [OBS_START],
                                      'top_k': 1
                                  })
            is_first_run = False
            output_batch = [
                [Message(**msg) if isinstance(msg, dict) else msg for msg in output] for output in output_batch
            ]

            if len(output_batch) == 0:
                # NOTE(review): returns [] rather than responses_batch; presumably llm_fn
                # yields nothing on ranks that should not produce results — confirm.
                return []

            # Snapshot: in_processing is mutated while iterating below.
            _in_processing = list(in_processing)
            data = []  # one kwargs dict per tool call, consumed by _ask_code_exec
            skip_code_execution = False
            for idx, output in zip(_in_processing, output_batch):
                if not output or not output[-1].content:
                    # Nothing generated (presumably max length exceeded): conversation is done.
                    in_processing.remove(idx)
                    continue
                responses_batch[idx][0].content += output[-1].content
                responses_batch[idx][0].extra = output[-1].extra

                # Determine whether to call tools; only text after the last </think> counts.
                _rsp = output[-1].content
                if '</think>' in _rsp:
                    _rsp = _rsp.split('</think>')[-1]
                if '<tool_call>' in _rsp and '</tool_call>' in _rsp:
                    _tool_list = _rsp.split('<tool_call>')
                    for _ti, _t in enumerate(_tool_list[1:]):
                        try:
                            _tt = json.loads(_t.split('</tool_call>')[0].strip())
                            # record for parallel exec code
                            data.append({
                                'index': idx,
                                't_index': _ti,
                                'tool_name': _tt['name'],
                                'tool_args': _tt['arguments'],
                                'messages': origin_messages_batch[idx]
                            })
                        except Exception:
                            # Unusable call (bad JSON or missing keys): the error text is
                            # returned to the model as this call's observation.
                            data.append({
                                'index': idx,
                                't_index': _ti,
                                'tool_name': '',
                                'tool_args': '',
                                'messages': origin_messages_batch[idx],
                                'tool_res': 'Error: Illegal JSON.'
                            })
                else:
                    # No tool call: this conversation is done.
                    in_processing.remove(idx)
                    continue

                # Special logic for llm_fn: if an output is flagged extra['is_last_rank'] == False,
                # no tools run this round and the histories are passed on unchanged
                # (presumably a pipeline-parallel non-final rank — confirm with llm_fn).
                if output[-1].extra and not output[-1].extra.get('is_last_rank', True):
                    skip_code_execution = True

            if not skip_code_execution:
                # Stop when nothing needs executing or no round is left to consume the results.
                if len(data) == 0 or num_llm_calls_available == 0:
                    break
                # Exec code and get the messages_batch for next turn.
                # Only the first rank of each model-parallel group (mp_rank == 0) calls the
                # external tool(s); the other ranks receive the results via broadcast.
                if mp_rank == 0:
                    print(f'[exec_start] (rank-{dp_rank*mp_size}) on {len(output_batch)} messages, {len(data)} tasks with {num_llm_calls_available} tries left')
                    results = parallel_exec(self._ask_code_exec, data, max_workers=1, jitter=0.5)
                    print(f'[exec_done] (rank-{dp_rank*mp_size}) on {len(output_batch)} messages, {len(data)} tasks with {num_llm_calls_available} tries left')
                else:
                    results = None
                # must be a list, even if single item
                obj_list = [results]
                # src is the global rank of mp_rank 0 in this group (dp_rank * mp_size).
                dist.broadcast_object_list(obj_list, src=dp_rank*mp_size, group=model_parallel_group)
                results = obj_list[0]
                # Sort by (conversation index, tool-call index) to undo parallel reordering.
                ordered_results = sorted(results, key=lambda x: (x[0], x[1]))
                new_messages_batch = []
                for ri in range(len(ordered_results)):
                    idx, t_idx, observation = ordered_results[ri]

                    # NOTE(review): assumes _call_tool returns a str.
                    observation = observation.strip()
                    observation = f'{OBS_START}\n{observation}{OBS_END}'
                    if not responses_batch[idx][0].content.endswith('\n'):
                        responses_batch[idx][0].content += '\n'
                    responses_batch[idx][0].content += observation

                    # Mirror the accumulated assistant message into the running history.
                    if origin_messages_batch[idx][-1].role == ASSISTANT:
                        origin_messages_batch[idx][-1] = copy.deepcopy(responses_batch[idx][0])
                    else:
                        origin_messages_batch[idx].extend(copy.deepcopy(responses_batch[idx]))

                    # Several results for the same conversation collapse into one entry.
                    if ri > 0 and ordered_results[ri][0] == ordered_results[ri - 1][0]:
                        new_messages_batch[-1] = origin_messages_batch[idx]
                    else:
                        new_messages_batch.append(origin_messages_batch[idx])
            else:
                # only get the messages_batch for next turn
                new_messages_batch = []
                for idx in in_processing:
                    new_messages_batch.append(origin_messages_batch[idx])
            messages_batch = new_messages_batch

        return responses_batch

    def _run(self, messages: List[Message], lang: Literal['en', 'zh'] = 'en', **kwargs) -> Iterator[List[Message]]:
        """Not supported: this agent only implements batched execution (``_run_batch``)."""
        raise NotImplementedError

    def _detect_tool(self, text: str) -> Tuple[bool, str, str, str]:
        """Parse the last ``<tool_call>`` JSON block in ``text``.

        Returns:
            ``(has_tool, name, args, text)``. For a well-formed call, ``name`` and
            ``args`` are the decoded ``"name"``/``"arguments"`` fields and ``text``
            is truncated to the part before the first ``<tool_call>``. If there is
            no ``<tool_call>``, or its payload is not valid JSON or is not an
            object with both keys, returns ``(False, '', '', text)`` with ``text``
            unchanged.
        """
        name = args = ''
        if '<tool_call>' in text:
            try:
                a = json.loads(text.split('<tool_call>')[-1].split('</tool_call>')[0].strip())
                # Key access is inside the try: valid JSON that is not an object with
                # these keys is treated like malformed JSON instead of crashing.
                name = a['name']
                args = a['arguments']
            except (ValueError, KeyError, TypeError):
                return False, '', '', text
            text = text.split('<tool_call>')[0]
        return (name != ''), name, args, text
