import copy
import json
import os
import torch.distributed as dist
from typing import Dict, List, Union

from qwen_agent.agents.rl.nous_agent_continue import NousAgent
from qwen_agent.llm.schema import ASSISTANT, USER, Message
from qwen_agent.utils.parallel_executor import parallel_exec, serial_exec
from ray.util import pdb


# Delimiters placed around every tool observation that is fed back to the model.
OBS_START = '<tool_response>'
OBS_END = '\n</tool_response>'

# Runtime limits; each can be overridden through the environment variable of
# the same name, otherwise the default shown here applies.
MAX_LLM_CALL_PER_RUN = int(os.environ.get('MAX_LLM_CALL_PER_RUN', '8'))
MAX_TOOL_CALL_WORKERS = int(os.environ.get('MAX_TOOL_CALL_WORKERS', '20'))
MAX_TOOL_CALL_INTERVAL = float(os.environ.get('MAX_TOOL_CALL_INTERVAL', '0.01'))

# Echo the effective settings once, at import time.
print(f'Running with MAX_LLM_CALL_PER_RUN = {MAX_LLM_CALL_PER_RUN}')
print(f'Running with MAX_TOOL_CALL_WORKERS = {MAX_TOOL_CALL_WORKERS}')
print(f'Running with MAX_TOOL_CALL_INTERVAL = {MAX_TOOL_CALL_INTERVAL}')


# Distributed topology, computed once at import time.
# NOTE: dist.get_rank() / dist.get_world_size() need the default process group
# to be initialised before this module is imported (see torch.distributed docs).
global_rank = dist.get_rank()
world_size = dist.get_world_size()

# TP_SIZE is mandatory. Fail with an explicit message instead of the opaque
# "int() argument must be ... not 'NoneType'" that int(None) would produce.
_tp_size_env = os.getenv("TP_SIZE")
if _tp_size_env is None:
    raise RuntimeError('Environment variable TP_SIZE (model parallel size) must be set '
                       'before importing this module.')
mp_size = int(_tp_size_env)                # model parallel size
if mp_size <= 0 or world_size % mp_size != 0:
    # Otherwise the rank arithmetic below (and any grouping of ranks into
    # blocks of mp_size) would reference ranks outside [0, world_size).
    raise ValueError(f'TP_SIZE must be a positive divisor of the world size; '
                     f'got TP_SIZE={mp_size}, world_size={world_size}.')
dp_size = world_size // mp_size            # data parallel size (number of model-parallel groups)
# Ranks are grouped into consecutive blocks of mp_size.
# NOTE(review): the original comments called dp_rank "which machine/node" and
# mp_rank "which GPU inside that node"; that only holds when one
# model-parallel group spans exactly one node - confirm against the launcher.
dp_rank = global_rank // mp_size           # index of the model-parallel group this rank belongs to
mp_rank = global_rank % mp_size            # position of this rank inside its group

def create_model_parallel_group(world_size, mp_size):
    """Create one process group per block of ``mp_size`` consecutive ranks.

    Ranks are partitioned as ``[0, mp_size)``, ``[mp_size, 2 * mp_size)``, ...
    and ``dist.new_group`` is called once for every block.

    Args:
        world_size: Total number of ranks to partition.
        mp_size: Number of consecutive ranks placed in each group.

    Returns:
        The group handles returned by ``dist.new_group``, ordered by the
        first rank of each block.
    """
    # NOTE(review): per the torch.distributed docs new_group() has to be
    # entered by every process, which is presumably why all groups are
    # created here rather than only the caller's own - confirm.
    return [
        dist.new_group(ranks=list(range(first_rank, first_rank + mp_size)))
        for first_rank in range(0, world_size, mp_size)
    ]

# Build the complete list of model-parallel groups, then keep a handle on the
# one this rank belongs to. Groups are ordered by their first rank, so the
# group index is exactly dp_rank.
model_parallel_groups = create_model_parallel_group(world_size=world_size, mp_size=mp_size)
model_parallel_group = model_parallel_groups[dp_rank]



class NousAgentChat(NousAgent):
    """NousAgent variant whose batched rollout interleaves LLM turns and tool calls.

    In batch mode, `_run_batch` repeatedly asks an externally supplied
    ``llm_fn`` to continue every unfinished conversation, extracts
    ``<tool_call>`` blocks from the replies, executes them, and feeds the
    results back as USER messages wrapped in ``<tool_response>`` tags.
    """

    @staticmethod
    def _parse_tool_call_tasks(idx: int, response: str, messages: List[Message]) -> List[Dict]:
        """Turn the ``<tool_call>`` blocks of one assistant reply into tool tasks.

        Only the text after the last ``</think>`` is inspected. Each block is
        either plain JSON (``{"name": ..., "arguments": ...}``) or JSON followed
        by a ``<code>...</code>`` section whose text is stored under
        ``arguments['code']``.

        Args:
            idx: Batch index of the conversation the reply belongs to.
            response: Content of the assistant reply.
            messages: Transcript of that conversation; stored by reference in
                every task.

        Returns:
            One task dict per ``<tool_call>`` opening tag, in order of
            appearance. The list is empty when the reply does not contain both
            ``<tool_call>`` and ``</tool_call>``. A block that cannot be parsed
            yields a task with empty ``tool_name``/``tool_args`` and a
            pre-filled ``tool_res`` error message.
        """
        if '</think>' in response:
            response = response.split('</think>')[-1]
        if '<tool_call>' not in response or '</tool_call>' not in response:
            return []

        tasks = []
        for t_idx, chunk in enumerate(response.split('<tool_call>')[1:]):
            has_code = '<code>' in chunk
            payload = chunk.split('</tool_call>')[0].strip()
            try:
                if has_code:
                    snippets = payload.split('<code>')
                    call = json.loads(snippets[0])
                    # `payload` may hold no <code> when the tag only appears
                    # after </tool_call>; then there is nothing to attach.
                    if len(snippets) > 1:
                        # TODO: support more flexible params
                        call['arguments']['code'] = snippets[1].replace('</code>', '')
                else:
                    call = json.loads(payload)
                task = {
                    'index': idx,
                    't_index': t_idx,
                    'tool_name': call['name'],
                    'tool_args': call['arguments'],
                    'messages': messages
                }
            except Exception:
                # Deliberate catch-all: the payload is free-form model output,
                # so any parsing failure becomes an error observation for the
                # model instead of aborting the whole batch.
                task = {
                    'index': idx,
                    't_index': t_idx,
                    'tool_name': '',
                    'tool_args': '',
                    'messages': messages,
                    'tool_res': 'Error: Illegal Tool Format.' if has_code else 'Error: Illegal JSON.'
                }
            tasks.append(task)
        return tasks

    def _run_batch(self, messages_batch: List[List[Union[Dict, Message]]], lang_batch: List[str],
                   **kwargs) -> List[List[Message]]:
        """Run a batch of conversations through alternating LLM and tool turns.

        Args:
            messages_batch: One message list per conversation. The batched
                path reads ``.content`` and calls ``.model_dump()`` on every
                message, so it expects ``Message`` objects.
            lang_batch: Only forwarded to the parent implementation when batch
                mode is disabled; unused by the batched path.
            **kwargs:
                llm_batch_enabled: When falsy (default) the call is delegated
                    to ``super()._run_batch``.
                llm_fn: Required in batch mode. Called as
                    ``llm_fn(messages=..., is_first_run=..., generate_cfg=...)``
                    with the unfinished conversations as lists of dicts, and
                    expected to return one output (a list of dicts or
                    ``Message`` objects) per conversation, in the same order.

        Returns:
            One list per input conversation holding its transcript (input
            messages plus every generated ASSISTANT reply and USER tool
            response) without the first message. Returns ``[]`` if ``llm_fn``
            returns an empty batch.
            NOTE(review): the dropped first message is presumably the system
            prompt - confirm against the caller.

        Raises:
            RuntimeError: If a conversation with pending tool calls received
                no observation from the tool execution step.
        """
        llm_fn = kwargs.get('llm_fn', None)
        llm_batch_enabled = kwargs.get('llm_batch_enabled', False)
        if not llm_batch_enabled:
            return super()._run_batch(messages_batch=messages_batch, lang_batch=lang_batch, **kwargs)
        assert llm_fn, 'Please pass in llm_fn that supports batch inference externally'

        # Full transcripts, indexed exactly like the input batch.
        origin_messages_batch = copy.deepcopy(messages_batch)
        # Conversations sent to the next LLM call; kept aligned with `in_processing`.
        messages_batch = copy.deepcopy(messages_batch)
        # Batch indices of the conversations that are still running.
        in_processing = list(range(len(messages_batch)))

        num_llm_calls_available = MAX_LLM_CALL_PER_RUN
        is_first_run = True
        while in_processing and num_llm_calls_available > 0:
            num_llm_calls_available -= 1
            assert len(in_processing) == len(messages_batch)

            # Flatten single-item list contents to plain text before the LLM call.
            for messages in messages_batch:
                for msg in messages:
                    if isinstance(msg.content, list):
                        assert len(msg.content) == 1
                        msg.content = msg.content[0].text

            # Generation stops at OBS_START so the model cannot write a tool
            # response itself.
            output_batch = llm_fn(messages=[[msg.model_dump() for msg in messages] for messages in messages_batch],
                                  is_first_run=is_first_run,
                                  generate_cfg={
                                      'stop': [OBS_START],
                                      'top_k': 1
                                  })
            is_first_run = False
            output_batch = [
                [Message(**msg) if isinstance(msg, dict) else msg for msg in output] for output in output_batch
            ]

            if not output_batch:
                return []

            # Iterate over a snapshot because finished conversations are
            # removed from `in_processing` inside the loop.
            data = []
            for idx, output in zip(list(in_processing), output_batch):
                if not output or not output[-1].content:
                    # this data meet the end: exceed max_length
                    in_processing.remove(idx)
                    continue
                reply = output[-1]
                origin_messages_batch[idx].append(Message(role=ASSISTANT, content=reply.content, extra=reply.extra))

                # Determine whether to call the tool
                tasks = self._parse_tool_call_tasks(idx, reply.content, origin_messages_batch[idx])
                if tasks:
                    data.extend(tasks)
                else:
                    # this data meet the end
                    in_processing.remove(idx)

            if not data or num_llm_calls_available == 0:
                break

            # Exec code and get the messages_batch for next turn.
            # Only the rank with mp_rank == 0 in each model-parallel group calls
            # the external tool(s); the results are then broadcast to the
            # other ranks of that group.
            if mp_rank == 0:
                print(f'[exec_start] (rank-{dp_rank*mp_size}) on {len(output_batch)} messages, {len(data)} tasks with {num_llm_calls_available} tries left')
                # NOTE(review): max_workers is pinned to 1 although
                # MAX_TOOL_CALL_WORKERS is defined at module level - confirm
                # whether serialising the tool calls is intentional.
                results = parallel_exec(self._ask_code_exec, data, max_workers=1, jitter=MAX_TOOL_CALL_INTERVAL)
                print(f'[exec_done] (rank-{dp_rank*mp_size}) on {len(output_batch)} messages, {len(data)} tasks with {num_llm_calls_available} tries left')
            else:
                results = None
            # must be a list, even if single item
            obj_list = [results]
            dist.broadcast_object_list(obj_list, src=dp_rank*mp_size, group=model_parallel_group)
            results = obj_list[0]

            # Each result is unpacked as (conversation index, tool-call index,
            # observation). After sorting, observations of the same
            # conversation are adjacent and merged into a single USER message.
            ordered_results = sorted(results, key=lambda x: (x[0], x[1]))
            for pos, (idx, _t_idx, observation) in enumerate(ordered_results):
                observation = f'{OBS_START}\n{observation.strip()}{OBS_END}'
                if pos > 0 and ordered_results[pos - 1][0] == idx:
                    assert origin_messages_batch[idx][-1].role == USER, origin_messages_batch
                    origin_messages_batch[idx][-1].content += f'\n{observation}'
                else:
                    origin_messages_batch[idx].append(Message(role=USER, content=observation))

            # Next turn: every running conversation must now end with its tool response.
            messages_batch = [
                origin_messages_batch[idx] for idx in in_processing if origin_messages_batch[idx][-1].role == USER
            ]
            if len(messages_batch) != len(in_processing):
                # Fail loudly instead of dropping into an interactive debugger,
                # which would block the whole distributed job.
                missing = [idx for idx in in_processing if origin_messages_batch[idx][-1].role != USER]
                raise RuntimeError(f'Tool execution returned no observation for conversation(s) {missing}; '
                                   'cannot continue the batched rollout.')

        return [msgs[1:] for msgs in origin_messages_batch]


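# Illustrative shape of `origin_messages_batch` for a batch of four
# conversations (the leading number on each row is its batch index). Rows 0 and
# 2 went through one extra USER/ASSISTANT round; rows 1 and 3 ended after the
# first ASSISTANT reply.
#
# origin_messages_batch = [
#     0 [Message(role=USER), Message(role=ASSISTANT), Message(role=USER), Message(role=ASSISTANT)],
#     1 [Message(role=USER), Message(role=ASSISTANT)],
#     2 [Message(role=USER), Message(role=ASSISTANT), Message(role=USER), Message(role=ASSISTANT)],
#     3 [Message(role=USER), Message(role=ASSISTANT)]
#     ]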
