"""
orgnize the context of inference process for a data item
"""
import os
import json
import re
import copy
from abc import ABC
from typing import List, Union
from qwen_agent.llm.schema import ContentItem, Message
from dataclasses import asdict

from .node import Node

from ray.util import pdb

# Number of tool errors each Context may absorb (see Context.is_terminal)
# instead of ending the trajectory; read from the TOOL_ERROR_TOLERANCE env var.
TOOL_ERROR_TOLERANCE = int(os.getenv('TOOL_ERROR_TOLERANCE', 0))
# Comma-separated tool names eligible for the tolerance mechanism. Blank entries
# (an unset variable, a trailing comma) and surrounding whitespace are dropped,
# so the empty string never counts as a tolerated tool name.
TOLERANCE_TOOLS = [
    name.strip()
    for name in os.getenv('TOLERANCE_TOOLS', '').split(',')
    if name.strip()
]

def process_text(text):
    """Return the text following the first assistant-turn header, with trailing newlines removed.

    The header is the ``<|im_start|>assistant`` line. When it does not occur,
    the whole input is used instead.
    """
    _, found, after = text.partition("<|im_start|>assistant\n")
    body = after if found else text
    return body.rstrip('\n')

def process_messages(messages):
    """Strip ``<think>...</think>`` spans from every assistant message except the last message.

    Stripping is applied even when the final message is not an assistant turn.
    For example, when the system2 input is built after a tool call, the last role
    is ``user``, but earlier assistant turns still need their reasoning removed.

    Args:
        messages: list of chat-message dicts with ``role`` and ``content`` keys.

    Returns:
        A deep copy of *messages*; the input list and its dicts are not modified.
        Assistant messages whose ``content`` is not a ``str`` (e.g. ``None``)
        are copied unchanged.
    """
    messages = copy.deepcopy(messages)
    last_index = len(messages) - 1

    for i, message in enumerate(messages):
        if message['role'] != 'assistant' or i == last_index:
            continue
        content = message['content']
        # Non-string content (e.g. None for an over-long generation) has nothing to strip.
        if isinstance(content, str):
            message['content'] = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip("\n")
    return messages

class Context(ABC):
    """State of one data item across a multi-step inference rollout.

    ``self.chain`` holds one node per step. Each node carries the system2
    messages/output, the tool call (``tool``, ``tool_json``, ``tool_output``)
    and the system1 messages/output read by the methods below.
    """

    def __init__(self, idx: int, item: dict, **kwargs):
        """
        Args:
            idx: index of this item, stored as ``self.idx``.
            item: raw data item; see ``create_context`` for the keys read.
            **kwargs: forwarded to ``create_context``.
                ``multiturn_agent`` (default True) selects the multi-turn
                message format. ``tool_names`` pre-seeds ``tool_usage``
                counters at 0.
        """
        self.idx = idx
        self.create_context(item, **kwargs)

        self.chain = []
        self.evaluation = None

        # Multi-turn conversation format
        self.multiturn_agent = kwargs.get('multiturn_agent', True)

        # Remaining number of tool errors that may be tolerated (see is_terminal).
        self.tolerance = TOOL_ERROR_TOLERANCE

        # Per-tool usage counters
        self.tool_usage = {tool_name: 0 for tool_name in kwargs.get('tool_names', [])}

    @property
    def is_terminal(self) -> bool:
        """Return whether the last node ends the rollout.

        NOTE: this property has side effects. Suppose the last node ended while
        ``self.tolerance > 0``, its tool is in ``TOLERANCE_TOOLS``, and its
        ``stop_reason`` is a str (the tool arguments were otherwise well-formed).
        Then the stop reason becomes the tool output, the node is re-marked as
        valid and not ended, and ``self.tolerance`` is decremented.

        Raises:
            AssertionError: a tolerated tool ended with a non-str ``stop_reason``.
        """
        last_node = self.chain[-1]
        if last_node.end and self.tolerance > 0:
            if last_node.tool is not None and last_node.tool['tool_name'] in TOLERANCE_TOOLS:
                if not isinstance(last_node.stop_reason, str):
                    # Explicit raise (not `assert`) so this still fires under `python -O`.
                    raise AssertionError(f"{last_node.tool['tool_name']} is not supported, please debug here")
                if last_node.tool_json is None:
                    last_node.tool_json = last_node.tool['tool_args']
                if last_node.tool_output is None:
                    last_node.tool_output = last_node.stop_reason

                last_node.valid = True
                last_node.end = False
                self.tolerance -= 1

        return last_node.end

    def create_context(self, item, **kwargs):
        """Store the fields of *item* on the context.

        Reads ``id``, ``question``, ``answers`` (a list) and ``prompt`` from *item*.
        ``data_source`` is the part of ``id`` before the first underscore.
        """
        self.data_source = item['id'].split('_')[0]
        self.id = item['id']
        self.question = item['question']
        self.answers = item['answers']  # list
        self.prompt_config = item['prompt']
        self.item = item

    def prepare_system2_input(self, mask_pre_think=True):
        """Build the system2 chat messages for the next step.

        With an empty chain this is the system prompt plus the (optionally
        prefixed) question. Otherwise, it extends the last node's
        ``system2_message`` with that node's output (and tool message).
        A new list is returned, and the node's stored messages are left untouched.

        Args:
            mask_pre_think: strip ``<think>`` spans from earlier assistant turns.
        """
        if len(self.chain) == 0:
            user_prefix = self.prompt_config['system2']['user']
            return_messages = [
                {'role': 'system', 'content': self.prompt_config['system2']['system']},
                {'role': 'user', 'content': self.question if user_prefix is None else user_prefix + self.question},
            ]
            return process_messages(return_messages) if mask_pre_think else return_messages

        # Only steps that invoked a tool reach this point.
        cur_node = self.chain[-1]
        # Copy so the previous node's recorded input is not mutated.
        messages = list(cur_node.system2_message)
        if not self.multiturn_agent:
            # Single-turn format: the whole trajectory lives in one assistant message.
            continuation = cur_node.system2_output + cur_node.system1_output
            if messages[-1]['role'] != 'assistant':
                messages.append({'role': 'assistant', 'content': continuation})
            else:
                # Replace (rather than mutate) the last dict; it may be shared with earlier nodes.
                messages[-1] = {**messages[-1], 'content': messages[-1]['content'] + continuation}
        else:
            messages.append({'role': 'assistant', 'content': cur_node.system2_output})
            if cur_node.tool_role_message is not None:
                messages.append(cur_node.tool_role_message)

        if mask_pre_think:
            messages = process_messages(messages)

        return messages

    def _build_system1_message(self, observation):
        """Return the [system, user] system1 messages with *observation* filled into the user template."""
        user_input = self.prompt_config['system1']['user'].format(question=self.question, observation=observation)
        return [
            {"role": "system", "content": self.prompt_config['system1']['system']},
            {"role": "user", "content": user_input},
        ]

    def prepare_system1_input(self, tools: dict, tokenizer, max_prompt_length: int=8192, empty_mode:bool = False, readpage: bool=False) -> Union[str, List, dict]:
        """Build the system1 input from the observation of the last node's tool call.

        Also increments ``self.tool_usage`` for that tool. Observations are
        truncated to the tokens left after the empty-observation prompt.

        Returns:
            - ``str``: the truncated observation, for ``PythonInterpreter``, or
              when ``empty_mode`` is set and ``readpage`` is not.
            - ``list`` of [system, user] message dicts: the default case.
            - ``dict`` with ``message_list``, ``tool_observation`` and
              ``response_list``: when ``readpage`` is set. The observation is
              then iterated as a sequence of dicts with a ``content`` key.
        """
        pre_message = self._build_system1_message('')
        # Tokens used by the prompt with an empty observation (+128 reserved for special tokens).
        pre_len = len(tokenizer.encode(tokenizer.apply_chat_template(pre_message, tokenize=False, add_generation_prompt=True), truncation=False)) + 128
        obs_budget = max_prompt_length - pre_len
        assert obs_budget > 0, f"max_prompt_length is too small (pre_len: {pre_len} | max_prompt_length: {max_prompt_length})\npre_message:\n{pre_message}"

        def truncate(text):
            # Clip text to the token budget left for the observation.
            return tokenizer.decode(tokenizer.encode(text, max_length=obs_budget, truncation=True))

        node = self.chain[-1]
        tool_name = node.tool['tool_name']
        tool_observation = tools[tool_name].observation(node.tool, node.tool_json, node.tool_output, empty_mode=empty_mode, readpage=readpage, max_observation_length=obs_budget, tokenizer=tokenizer)

        self.tool_usage[tool_name] = self.tool_usage.get(tool_name, 0) + 1

        if tool_name in ['PythonInterpreter']:
            tool_observation = truncate(tool_observation)
            assert isinstance(tool_observation, str), f"tool_observation must be str, but got {type(tool_observation)}"
            return tool_observation  # str

        if not readpage:
            tool_observation = truncate(tool_observation)
            if empty_mode:
                return tool_observation  # str
            return self._build_system1_message(tool_observation)

        # readpage: one system1 request per observation item; items are updated in place.
        message_list = []
        for obs in tool_observation:
            obs['content'] = truncate(obs['content'])
            message = self._build_system1_message(obs['content'])
            message_list.append(message)
            obs['message'] = message

        return {
            'message_list': message_list,
            'tool_observation': tool_observation,
            'response_list': [None] * len(message_list),
        }

    def parse_system1_output(self, message, system1_output: str, parse_output_func):
        """Parse a system1 output and record it, with its input *message*, on the last node."""
        node = self.chain[-1]
        system1_output = parse_output_func(system1_output)
        node.system1_message = message
        # NOTE(review): the export methods read `node.system1_output`, not
        # `system1_response`; confirm against the Node definition that this is intended.
        node.system1_response = system1_output

    def system2_data(self, tokenizer):
        """Return the system2 prompt and the concatenated trajectory, plus metadata and token tensors.

        The response joins each node's system2 output and system1 output (both
        skipped when None) and ends with ``tokenizer.eos_token``.
        """
        parts = []
        for node in self.chain:
            # system2_output is None when the generation was over-long (mirrors get_rl_data).
            if node.system2_output is not None:
                parts.append(node.system2_output)
            if node.system1_output is not None:
                parts.append(node.system1_output)
        prompt = tokenizer.apply_chat_template(self.chain[0].system2_message[:2], tokenize=False, add_generation_prompt=True)
        response = "".join(parts) + tokenizer.eos_token
        return {
            "system2_prompt": prompt,
            "system2_response": response,
            # answer_part
            'answer_tag_part': self.chain[-1].answer_tag_part,
            'boxed_answer': self.chain[-1].boxed_answer,
            # token
            "system2_prompt_tensor": tokenizer.encode(prompt, return_tensors="pt"),
            "system2_response_tensor": tokenizer.encode(response, return_tensors="pt"),
            # others
            "id": self.id,
            'task': self.item['task'],
            "question": self.question,
            "answer": self.answers,
            "stop_reason": self.chain[-1].stop_reason,
            "evaluation": self.evaluation,
        }

    def system1_data(self) -> List:
        """Return one record per node that has a system1 input.

        Answer fields always come from the final node.
        """
        data = []
        for idx, node in enumerate(self.chain):
            if node.system1_message is not None:
                data.append({
                    "prompt": node.system1_message,
                    "response": node.system1_output,
                    # answer_part
                    'answer_tag_part': self.chain[-1].answer_tag_part,
                    'boxed_answer': self.chain[-1].boxed_answer,
                    "id": self.id,
                    "question": self.question,
                    "answer": self.answers,
                    "step": idx,
                    "evaluation": self.evaluation,
                })
        return data

    def exceed_max_depth(self):
        """Mark the last node as ended because the step limit was reached."""
        node = self.chain[-1]
        node.end = True
        node.stop_reason = "Exceeded Maximum Depth"

    def get_multi_turn_rl_data(self, tokenizer, readpage=False, system2_think=True, system1_think=False, mask_pre_think=True):
        """Export the rollout in multi-turn chat form for RL training.

        Args:
            tokenizer: provides ``apply_chat_template`` and ``eos_token``.
            readpage: each node's system1 fields are parallel lists
                (one entry per observation) instead of single values.
            system2_think / system1_think: ``enable_thinking`` passed to the chat template.
            mask_pre_think: strip ``<think>`` spans from earlier assistant turns.
        """
        # system2. The placeholder user message is removed again below, via
        # ``response[1:]`` and ``process_text``. process_text keeps only the
        # text after the first assistant header.
        response = [{'role': 'user', 'content': "this part should be discarded."}]
        for node in self.chain:
            if node.system2_output is not None:  # None when the generation was over-long
                response.append({'role': 'assistant', 'content': node.system2_output})
            if node.tool_role_message is not None:
                response.append(node.tool_role_message)
        prompt = tokenizer.apply_chat_template(self.chain[0].system2_message[:2], tokenize=False, add_generation_prompt=True, enable_thinking=system2_think)

        original_response = copy.deepcopy(response[1:])
        if mask_pre_think:
            response = process_messages(response)

        # system 1
        system1_data = []
        for node in self.chain:
            if node.system1_message is not None:
                if not readpage:
                    system1_prompt = tokenizer.apply_chat_template(node.system1_message, tokenize=False, add_generation_prompt=True, enable_thinking=system1_think)  # TODO: verify for thinking mode
                    system1_response = node.system1_output + tokenizer.eos_token
                    system1_data.append({
                        "system1_messages": node.system1_message,
                        "system1_prompt": system1_prompt,
                        "system1_response": system1_response,
                        "system1_output_format": node.system1_output_format,
                    })
                else:
                    for mess, o, fmt in zip(node.system1_message, node.system1_output, node.system1_output_format, strict=True):
                        system1_prompt = tokenizer.apply_chat_template(mess, tokenize=False, add_generation_prompt=True, enable_thinking=system1_think)
                        system1_response = o + tokenizer.eos_token
                        system1_data.append({
                            "system1_messages": mess,
                            "system1_prompt": system1_prompt,
                            "system1_response": system1_response,
                            "system1_output_format": fmt,
                        })

        return {
            "system2_messages": self.chain[0].system2_message[:2],
            "system2_prompt": prompt,
            "system2_response": process_text(tokenizer.apply_chat_template(response, tokenize=False, add_generation_prompt=False)),
            "system2_response_message": response[1:] if self.multiturn_agent else None,
            "system2_original_response_message": original_response,
            # answer_part
            'answer_tag_part': self.chain[-1].answer_tag_part,
            'boxed_answer': self.chain[-1].boxed_answer,
            # system1 data
            "system1": system1_data if system1_data else None,
            # others
            "id": self.id,
            'data_source': self.item['task'].split('/')[-1],
            'task': self.item['task'],
            "question": self.question,
            "answer": self.answers,
            "stop_reason": self.chain[-1].stop_reason,
            "whole_chain": self.chain_data(),
            "tolerance": f"{self.tolerance}/{TOOL_ERROR_TOLERANCE}",
            "tolerance_tools": ",".join(TOLERANCE_TOOLS),
            "tool_usage": self.tool_usage
        }

    def get_rl_data(self, tokenizer):
        """Export the rollout in single-string form for RL training.

        The response is the concatenation of every node's system2 output and
        system1 output, followed by ``tokenizer.eos_token``.
        """
        # system2
        response = ""
        for node in self.chain:
            if node.system2_output is not None:  # None when the generation was over-long
                response += node.system2_output
            if node.system1_output is not None:
                response += node.system1_output
        prompt = tokenizer.apply_chat_template(self.chain[0].system2_message[:2], tokenize=False, add_generation_prompt=True)
        response = response + tokenizer.eos_token

        # system 1
        system1_data = []
        for node in self.chain:
            if node.system1_message is not None:
                system1_prompt = tokenizer.apply_chat_template(node.system1_message, tokenize=False, add_generation_prompt=True)
                system1_response = node.system1_output + tokenizer.eos_token
                system1_data.append({
                    "system1_messages": node.system1_message,
                    "system1_prompt": system1_prompt,
                    "system1_response": system1_response,
                })

        return {
            "system2_messages": self.chain[0].system2_message[:2],
            "system2_prompt": prompt,
            "system2_response": response,
            # answer_part
            'answer_tag_part': self.chain[-1].answer_tag_part,
            'boxed_answer': self.chain[-1].boxed_answer,
            # system1 data
            "system1": system1_data if system1_data else None,
            # others
            "id": self.id,
            'data_source': self.item['task'].split('/')[-1],
            'task': self.item['task'],
            "question": self.question,
            "answer": self.answers,
            "stop_reason": self.chain[-1].stop_reason,
            "evaluation": self.evaluation,
        }

    def chain_data(self) -> List:
        """Return each node as a dict (via ``dataclasses.asdict``) tagged with item metadata and step index."""
        chain = []
        for idx, node in enumerate(self.chain):
            node_dict = asdict(node)
            node_dict['id'] = self.id
            node_dict['question'] = self.question
            node_dict['answer'] = self.answers
            node_dict['step'] = idx
            chain.append(node_dict)
        return chain
