"""
Preprocess dataset for qaq task
"""

import json
import os
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer
import random
# Fixed seed so the validation subsample drawn with random.sample in the
# __main__ section is reproducible across runs.
random.seed(1234)

def get_prompt_length(messages, tokenizer):
    """Return the token count of ``messages`` rendered in a ChatML-style layout.

    Each message becomes ``<|im_start|>{role}\\n{content}\\n<|im_end|>\\n``
    (note the newline before ``<|im_end|>``), the segments are concatenated,
    and the result is encoded with ``tokenizer.encode``.

    Args:
        messages: Iterable of dicts with ``'role'`` and ``'content'`` keys.
        tokenizer: Any object exposing ``encode(str)`` returning a sequence.

    Returns:
        The length of the encoded prompt.

    Raises:
        ValueError: If a message role is not system, user or assistant.
    """
    allowed_roles = ('system', 'user', 'assistant')
    segments = []
    for msg in messages:
        body = msg['content']
        role = msg['role']
        if role not in allowed_roles:
            raise ValueError(f"Invalid role: {msg['role']}")
        segments.append(f'<|im_start|>{role}\n{body}\n<|im_end|>\n')
    return len(tokenizer.encode(''.join(segments)))


# System-prompt template used by make_messages for template_type='nous'.
# Unlike QAQ_TEMPLATE_MULTITURN, its example also shows a <tool_response> block.
# "{tool_descs}" is filled with str.replace, not str.format, because the
# template contains literal JSON braces.
# NOTE(review): the example <tool_response> labels its content as
# <args-json-object>; confirm that wording is intended.
QAQ_TEMPLATE_ONETURN = """# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_descs}
</tools>

For each function call, think step by step and return a json object with function name and arguments within <tool_call></tool_call><tool_response></tool_response> XML tags:
<think>
your thoughts here
</think>
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
<tool_response>
{"name": <function-name>, "content": <args-json-object>}
</tool_response>"""


# System-prompt template used by make_messages for template_type='nous_gt'.
# Its example stops at </tool_call>, with no <tool_response>.
# "{tool_descs}" is filled with str.replace, not str.format, because the
# template contains literal JSON braces.
QAQ_TEMPLATE_MULTITURN = """# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_descs}
</tools>

For each function call, think step by step and return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<think>
your thoughts here
</think>
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""


# Final-answer instruction appended verbatim (no str.format / str.replace) to
# the 'nous_gt' system prompt. Braces must therefore NOT be doubled, or the
# model would literally see "\boxed{{}}".
TEMPLATE_GT = """Put your final answer within \\boxed{}."""
# Tool-call budget line; "{max_tool_call}" is filled with str.replace.
TEMPLATE_TOOL_CALL = """You have at most {max_tool_call} tool calls available."""

def make_prefix(question, tool_schemas, template_type):
    tools = '\n'.join([json.dumps(tool, ensure_ascii=False, indent=2) for tool in tool_schemas])
    # NOTE: also need to change reward_score/qaq_fncall.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = r"""A conversation between User and Assistant. The user asks a question, and the assistant solves it by calling one or more of the following tools.""" + f"""
<tools>
{tools}
</tools>""" + r"""

The assistant starts with one or more cycles of (thinking about which tool to use -> performing tool call -> waiting for tool response), and ends with (thinking about the answer -> answer of the question). The thinking processes, tool calls, tool responses, and answer are enclosed within their tags. There could be multiple thinking processes, tool calls, tool call parameters and tool response parameters.

Example response:
<think> thinking process here </think>
<tool_call>
{"name": "tool name here", "arguments": {"parameter name here": parameter value here, "another parameter name here": another parameter value here, ...}}
</tool_call>
<tool_response>
{"name": "tool name here", "content": {"result name here": result value here, "another result name here": another result value here, ...}}
</tool_response>
<think> thinking process here </think>
<tool_call>
{"name": "another tool name here", "arguments": {...}}
</tool_call>
<tool_response>
{"name": "another tool name here", "content": {...}}
</tool_response>
(more thinking processes, tool calls and tool responses here)
<think> thinking process here </think>
<answer> answer here </answer>""" + f"""

User: {question}

Assistant: <think>"""
    else:
        raise NotImplementedError
    return prefix

def make_messages(prompt, tool_schemas, template_type, max_llm_call):
    """Build the ``[system, user]`` chat messages for one QAQ example.

    Args:
        prompt: The user question; becomes the user message content.
        tool_schemas: Tool schema objects; each is serialized with
            ``json.dumps(..., indent=2)`` into the system prompt's ``<tools>``.
        template_type: ``'nous'`` uses QAQ_TEMPLATE_ONETURN; ``'nous_gt'``
            uses QAQ_TEMPLATE_MULTITURN plus the TEMPLATE_GT answer instruction.
        max_llm_call: Total LLM calls allowed; the prompt advertises
            ``max_llm_call - 1`` tool calls.

    Returns:
        A list of two message dicts: the system message, then the user message.

    Raises:
        NotImplementedError: If ``template_type`` is not supported.
    """
    if template_type == 'nous':
        header, extras = QAQ_TEMPLATE_ONETURN, []
    elif template_type == 'nous_gt':
        header, extras = QAQ_TEMPLATE_MULTITURN, [TEMPLATE_GT]
    else:
        raise NotImplementedError(f'Unsupported template_type: {template_type!r}')

    tool_descs = '\n'.join(json.dumps(tool, ensure_ascii=False, indent=2) for tool in tool_schemas)
    # str.replace needs a str replacement; passing the bare int raises TypeError.
    tool_budget = TEMPLATE_TOOL_CALL.replace('{max_tool_call}', str(max_llm_call - 1))
    system = '\n'.join([header.replace('{tool_descs}', tool_descs), *extras, tool_budget])
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]


def _read_jsonl(path):
    """Return the JSON objects stored one per line in ``path``; blank lines are skipped."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _write_jsonl(path, records):
    """Write each record in ``records`` to ``path`` as one UTF-8 JSON line."""
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')


def _prepare_unieval_data(dataset_list, tokenizer, template_type, max_llm_call, max_prompt_len):
    """Convert raw QAQ records into unieval ``agent/qaq`` records.

    Records whose rendered prompt is longer than ``max_prompt_len`` tokens
    (measured by get_prompt_length) are dropped. A count of kept records is
    printed.
    """
    unieval_dataset = []
    for data in tqdm(dataset_list):
        prompt = data['prompt']
        tool_schemas = data['tool_schemas']
        messages = make_messages(prompt, tool_schemas, template_type=template_type, max_llm_call=max_llm_call)
        if get_prompt_length(messages, tokenizer) > max_prompt_len:
            continue
        unieval_dataset.append({
            "task": "agent/qaq",
            "messages": messages,
            "answer": data.get('answer', None),
            "tool_schemas": tool_schemas,
            "tool_calls": data.get('tool_calls', None),
        })
    print(f'{len(unieval_dataset)}/{len(dataset_list)} are valid')
    return unieval_dataset


def main():
    """Parse CLI args, convert the train/val splits, and write train.jsonl / test.jsonl."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--template_type', type=str, default='nous_gt')  # nous nous_gt
    parser.add_argument('--version', type=int, default=1)
    # fake_fncall fake_fncall2
    # tables tables_global tables_local tables_local_exec
    # fv_tables fv_tables_global
    parser.add_argument('--source_type', type=str, default='fv_tables')
    parser.add_argument('--max_prompt_len', type=int, default=8*1024)
    parser.add_argument('--max_llm_call', type=int, default=2)
    parser.add_argument('--max_val_num', type=int, default=100)
    # Tokenizer is used only to measure prompt lengths against --max_prompt_len.
    parser.add_argument(
        '--model_directory',
        type=str,
        default="/cpfs01/shared/Group-m6/feihu.hf/rl_base_model/32b.stage2--cptv9-base100w-cpt32k-0218_Ma_Co_STE_Rea_Age_Gen_dgb-32B.qwen2.5B-bf16-mp8-pp1-lr-7e-6-minlr-7e-7-bs-128-gpus-128-seqlen-32768-step16784",
    )
    args = parser.parse_args()

    version = f'_v{args.version}' if args.version != 1 else ''
    local_dir = f'data/{args.source_type}{version}'
    train_dataset_path = f'data_raw/qaq/tables/{args.source_type}_train.jsonl'
    test_dataset_path = f'data_raw/qaq/tables/{args.source_type}_val.jsonl'

    train_dataset_list = _read_jsonl(train_dataset_path)
    test_dataset_list = _read_jsonl(test_dataset_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_directory)

    # NOTE(review): the version suffix is appended to the template type, but
    # make_messages only supports 'nous' / 'nous_gt', so any --version other
    # than 1 raises NotImplementedError. Confirm versioned templates are intended.
    template_type = f'{args.template_type}{version}'
    train_dataset = _prepare_unieval_data(
        train_dataset_list, tokenizer, template_type, args.max_llm_call, args.max_prompt_len)
    test_dataset = _prepare_unieval_data(
        test_dataset_list, tokenizer, template_type, args.max_llm_call, args.max_prompt_len)

    # The output directory may not exist yet; open(..., 'w') would fail without it.
    os.makedirs(local_dir, exist_ok=True)
    _write_jsonl(os.path.join(local_dir, 'train.jsonl'), train_dataset)

    # Cap the validation set size. random is seeded at module import, so the
    # subsample is reproducible.
    if len(test_dataset) > args.max_val_num:
        test_dataset = random.sample(test_dataset, args.max_val_num)
    _write_jsonl(os.path.join(local_dir, 'test.jsonl'), test_dataset)


if __name__ == '__main__':
    main()
