import json
import requests
import os
import time
from tqdm import tqdm
import json
import requests
import os
import time
from tqdm import tqdm
from agents.service_utils import *
from agents.bge_utils import *

# pip install FlagEmbedding
from reward.reward_model_rule import *
from agents.rody_calling_mock import *

# Module-level singletons, created once at import time and shared by the
# functions below.
# API dispatcher (presumably from agents.rody_calling_mock via the star import).
# Names in RODY_TOOL_LIST go through call_api(query, data); every other API name
# goes through call_api_common(data) (see verify_end and main).
api_calling_object = MockAPIHandler()

# Scorer (presumably from reward.reward_model_rule); main() calls
# reward_model.get_reward(query_type, query, pred_answer, real_answer).
reward_model = RewardModel()

# Retrieval collections ("toolbases") passed to call_tool_retrieval by verify_end.
# NOTE: these strings act as identifiers of retrieval backends; presumably they
# must match the remote collection names exactly, so the spelling ("mutli",
# "dymanic") is intentionally left untouched. MASTER_TOOLBASE is not referenced
# in this file.
MASTER_TOOLBASE = "mutli_agent_dymanic_planning_vearch_master"
EXPERT_TOOLBASE = "mutli_agent_dymanic_expert_knowledge_vearch"
RODY_TOOLBASE = "mutli_agent_dymanic_planning_vearch_subagent_rody"
TOOLBENCH_TOOLBASE = "mutli_agent_dymanic_planning_vearch_subagent_toolbench_v2"

# E-commerce ("rody") API names. Calls to these are routed to
# api_calling_object.call_api(query, data), which also receives the user query;
# any other API name is routed to api_calling_object.call_api_common(data).
RODY_TOOL_LIST = [
    'check_order_report',
    'check_deposit_refund',
    'search_order_code',
    'check_shop_expenses',
    'check_shop_qualifications',
    'agent_user_growth',
    'search_service_code',
    'search_promotion',
    'search_coupon',
    'check_audit_status',
    'search_binding_card_status',
    'search_payment_method',
    'competitive_goods_analysis',
]

# System prompt for the master (planning) agent, sent as the system message on
# every model call made by model_ans_generator. It advertises three retrieval
# tools and specifies the <think>...</think> / <tool_call>...</tool_call> reply
# formats and how tool results come back inside <tool_response>...</tool_response>.
# generate_ans passes only the text after the last </think> to verify_end.
# NOTE(review): the API-call example below writes “arguments” with curly
# quotes. A model that copies it verbatim emits a tool call that is neither
# valid JSON nor a valid Python literal. Editing the prompt changes the
# baseline being measured, so it is left as-is — confirm before changing.
# NOTE(review): "<tools></tools>for selection" is missing a space (same caveat).
MASTER_PROMPT = """Role：
You are a customer service expert of an e-commerce platform(Corporation_A), specializing in selecting appropriate tools and agents to solve user problems based on user questions. Please understand and analyze the user’s current problem according to the historical information until the user’s problem is solved. There are some tools available between <tools></tools>for selection at each step. 

Specialized Retrieval Tools:
<tools>
{"name": "tool_retrievals_knowledge", "description": "Vertical knowledge base search tool (e-commerce merchant operations context). Identifies relevant information based on user queries.", "arguments": {"intention": "user's current intention or query"}}
{"name": "tool_retrievals_API_shop", "description": "E-commerce platform API lookup. Retrieves relevant APIs from the API knowledge base using intent analysis.", "arguments": {"intention": "user's current intention or query"}}
{"name": "tool_retrievals_API_general", "description": "General API lookup. Retrieves relevant APIs from the API knowledge base using intent analysis.", "arguments": {"intention": "user's current intention or query"}}
</tools>

Problem Resolution Framework:
1. Question Types & Response Protocols:
You may encounter different types of questions. The types of questions and the required output formats are as shown below:
---
- Math problems: 
    - Provide direct solutions to numerical queries.
    - Output in the following format(Provide the numerical answer directly after ‘</think>’, without units or any irrelevant characters): <think>...</think>Final numeric answer
- API scheduling problems:
    - The APIs are divided into e-commerce platform APIs and general APIs.
    - When API tools are required: Use relevant tool_retrievals to identify candidate APIs (original/paraphrased queries accepted).
    - Output the API call results in the following format: <tool_call>{"name": "API_name", “arguments”: {"key1":["value11", "value12"], "key2":["value21", "value22"]..}}</tool_call>
    - Some solutions require sequential API calls, but you can just call only one API at each step. Use prior outputs as inputs for subsequent calls.
- Q&A problems: 
    - Engage directly in casual conversations (greetings/jokes/daily topics).
    - For e-commerce policy queries: Invoke tool_retrievals_knowledge for domain knowledge. Respond based on retrieved content.

2.Tool/API Selection Guidelines:
    - The results of the previous Tool/API call will be returned in the format <tool_response>...</tool_response>.
    - The response format for API dispatching results is: “Pass calling … Results are as follows: …”.This result should generally be output to the user as-is to indicate the content of the API call. Additionally, if multiple API calls are involved, all relevant API call results must be merged and presented together to the user.
    - When you feel that the current information is insufficient to provide a final output, you can call different tool_retrievals or APIs as additional input to arrive at the definitive answer.
    - Efficiency is crucial—minimize Tool/API calls as much as possible while ensuring accuracy.

3. Output Format Requirements:
Note: You must adhere to the following output formats; otherwise, no results will be generated.
- When you determine that additional Tool/API calls are needed (Tool call format: API/tool_retrievals): <think>Thought process</think><tool_call>{"name": "tool_name", "arguments": {"param": "value"/["value"]}}</tool_call>
- When you believe the current conclusion is sufficient to return to the user: <think>Thought process</think>Output answer(if math problems, output final numeric answer; If it is an API-related issue and does not involve multiple API calls, output the content from <tool_response> exactly as it is.)"""


def open_ai_api_question(query, prompt, context_history=[], model="gpt-4o"):
    context = list()
    # make history format
    for i in range(len(context_history)):
        q, a = context_history[i]
        context.append({'role': 'user', 'content': q})
        context.append({'role': 'assistant', 'content': a})
    messages = [{'role': 'system', 'content': prompt}, ] + context + [{'role': 'user', 'content': query}, ]
    output = call_llm_messages(model, messages)
    return output


def llm_qwen_local(query, model, prompt='', context_history=[]):
    context = list()
    # make history format
    for i in range(len(context_history)):
        q, a = context_history[i]
        context.append({'role': 'user', 'content': q})
        context.append({'role': 'assistant', 'content': a})
    messages = [{'role': 'system', 'content': prompt}, ] + context + [{'role': 'user', 'content': query}, ]
    output = call_qwen_messages(model, messages)
    return output


def verify_end(query, action):
    """
    Determine whether the current Agent return code is the final answer, and get the returned code to determine the tool for the next round of calls
    """
    response_tools = ''

    if len(re.findall(r'{"|\'name"|\':.*?}}', action, re.S)) > 0 and 'Pass calling' not in action:
        try:
            match_result = eval(action.replace('<tool_call>', '').replace('</tool_call>', ''))
            tool_name = match_result['name']
            arguments = match_result['arguments']
            intention = query
            if isinstance(arguments, dict) and "intention" in arguments:
                intention = arguments['intention']
            # print(tool_name, arguments)
            code = 1

            if tool_name == 'tool_retrievals_knowledge':
                response_tools = call_tool_retrieval(intention, "expert", EXPERT_TOOLBASE)
            elif tool_name == 'tool_retrievals_API_shop':
                response_tools = call_tool_retrieval(intention, "rody", RODY_TOOLBASE)
            elif tool_name == 'tool_retrievals_API_general':
                response_tools = call_tool_retrieval(intention, "toolbench", TOOLBENCH_TOOLBASE)
            else:
                data = {}
                data['api_name'] = tool_name
                data['parameter'] = arguments
                if tool_name in RODY_TOOL_LIST:
                    response_tools = api_calling_object.call_api(query, data)
                else:
                    response_tools = api_calling_object.call_api_common(data)
                print(response_tools)
        except Exception as e:
            print(e)
            code = -1
    else:
        code = 0
    return code, response_tools


def model_ans_generator(q, model, history):
    """Ask *model* one question with MASTER_PROMPT as the system prompt.

    "Qwen-14b" and "Qwen-32b" are served through llm_qwen_local. Every other
    model name goes through open_ai_api_question. *history* is forwarded
    unchanged as the prior ``(user, assistant)`` turns.
    """
    is_local_qwen = model in ["Qwen-14b", "Qwen-32b"]
    if not is_local_qwen:
        return open_ai_api_question(q, MASTER_PROMPT, context_history=history, model=model)
    return llm_qwen_local(q, model, MASTER_PROMPT, context_history=history)


def generate_ans(q, model, history=None, max_round=5):
    """Run the think/act loop for one query until a final answer or a limit.

    Each round sends the latest input to the model. The text after the last
    ``</think>`` is handed to verify_end. While it returns code 1 (a tool was
    called), the tool output is fed back as ``<tool_response>...</tool_response>``.

    Args:
        q: The user query.
        model: Model name (see model_ans_generator).
        history: Optional prior ``(user, assistant)`` turns from the dataset.
            They are kept in the context for every round. Previously they
            were dropped after round 1.
        max_round: Maximum number of model calls.

    Returns:
        ``(answer, step, turns)``:
        * ``answer`` is the last raw model output when the loop ended on a
          final answer (code 0). It is a timeout message when tool calls were
          still pending at ``max_round``, and a parse-failure message when
          verify_end returned -1.
        * ``step`` is the number of model calls made.
        * ``turns`` is the list of ``(input, action)`` pairs of this episode
          only (prior *history* excluded).
    """
    prior_history = list(history) if history else []

    output = model_ans_generator(q, model=model, history=prior_history)
    action = output.split('</think>')[-1]
    code, retrieved_tools = verify_end(q, action)
    turns = [(q, action)]
    step = 1
    print("round num 1:", code, ' output:\n', output)

    while step < max_round and code > 0:
        # str(): API helpers are not guaranteed to return text.
        response = "<tool_response>" + str(retrieved_tools) + "</tool_response>"
        output = model_ans_generator(response, model=model, history=prior_history + turns)
        action = output.split('</think>')[-1]
        code, retrieved_tools = verify_end(q, action)
        turns.append((response, action))
        step += 1
        print("round num ", step, ", code is:", code, ' output:\n', output)

    if code == 0:
        return output, step, turns
    if code > 0:
        return 'Agent timeout planning failed! Fail to get an answer', step, turns
    return 'Agent output parsing failed!', step, turns


def write_json(target_list, target_file_name):
    """Write *target_list* to *target_file_name* as JSON Lines (one object per line).

    Despite the name, the output is JSONL, not a single JSON document.
    Non-ASCII text is written as-is in UTF-8.

    The data goes to a sibling ``.tmp`` file first and is then moved into
    place with ``os.replace``. main() reloads these files to resume, so an
    interrupted write must not leave a truncated file behind. On failure the
    previous file is left untouched. Missing parent directories are created.
    """
    directory = os.path.dirname(target_file_name)
    if directory:
        os.makedirs(directory, exist_ok=True)

    tmp_name = target_file_name + '.tmp'
    try:
        with open(tmp_name, 'w', encoding='utf-8') as file:
            file.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in target_list)
        os.replace(tmp_name, target_file_name)
    except BaseException:
        # Drop the partial temp file; keep the last good checkpoint.
        try:
            os.remove(tmp_name)
        except OSError:
            pass
        raise


def filter_cor(test_dataset):
    """Return only the records whose agent kind is ``'cooperate'``.

    The kind is read from ``record['agent']`` when that key exists, otherwise
    from ``record['task_type']``, with every ``'_agent'`` substring removed.
    Record order is preserved, and the records themselves are not copied.
    """
    def _agent_kind(record):
        raw = record['agent'] if 'agent' in record else record['task_type']
        return raw.replace('_agent', '')

    return [record for record in test_dataset if _agent_kind(record) == 'cooperate']


def _load_records(path):
    """Read *path* as one record per non-blank line (JSON, else a Python literal).

    Raises:
        FileNotFoundError: if *path* does not exist (main() relies on this to
            fall back from the response file to the test file).
    """
    import ast  # local import: only the loaders in this block need it

    records = []
    with open(path, encoding='utf-8') as f:  # write_json writes UTF-8
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            try:
                records.append(json.loads(line))
            except ValueError:
                # Lines stored as Python reprs (single quotes, None/True/False).
                # literal_eval parses them without executing code, unlike eval.
                records.append(ast.literal_eval(line))
    return records


def _parse_final_tool_call(action):
    """Parse a ``<tool_call>`` payload (JSON or Python literal); a list yields its first item."""
    import ast

    payload = action.replace('<tool_call>', '').replace('</tool_call>', '').strip()
    try:
        call = json.loads(payload)
    except ValueError:
        call = ast.literal_eval(payload)
    if isinstance(call, list):
        call = call[0]
    return call


def _run_inference(test_dataset, response_file, model, max_round, save_step):
    """Generate a response for every record lacking one, checkpointing to *response_file*."""
    print(response_file)
    for idx, dict_item in tqdm(enumerate(test_dataset), total=len(test_dataset)):
        # Records already answered (e.g. loaded from a previous run) are skipped.
        if dict_item.get('response', '') != '':
            print('pass')
            continue
        q = dict_item['query']
        his = dict_item.get('history') or []

        print(f"\n\n\n---\n{idx}, query:{q}")
        ans, step, history = generate_ans(q, model=model, history=his, max_round=max_round)
        dict_item['round_num'] = step
        dict_item['response'] = ans
        dict_item['history_inference'] = history
        if save_step > 0 and idx % save_step == 0:
            write_json(test_dataset, response_file)
    write_json(test_dataset, response_file)


def _replay_final_tool_calls(test_dataset):
    """Set ``baseline_answer``: the final action, or the API result if it is a tool call."""
    print('calling .... ')
    for item in tqdm(test_dataset, total=len(test_dataset)):
        response = item.get('response')
        if response is None or response == '':
            item['baseline_answer'] = None
            continue

        action = str(response).split('</think>')[-1]
        item['baseline_answer'] = action
        try:
            api_info = _parse_final_tool_call(action)
            data = {'api_name': api_info['name'], 'parameter': api_info['arguments']}
            if api_info['name'] in RODY_TOOL_LIST:
                results = api_calling_object.call_api(item['query'], data)
            else:
                results = api_calling_object.call_api_common(data)
            item['baseline_answer'] = results
        except Exception as e:
            # Not a (valid) tool call: keep the plain action as the answer.
            print(e)


def _score_responses(test_dataset, result_file, save_step):
    """Attach ``reward`` to every record (-1 without a prediction), checkpointing to *result_file*."""
    print(result_file)
    for idx, data in tqdm(enumerate(test_dataset), total=len(test_dataset)):
        query_type = data['task_type'].replace('_agent', '')
        query, pred_answer, real_answer = data['query'], data['baseline_answer'], data['answer']
        if pred_answer is None:
            data['reward'] = -1
        else:
            reward_score = reward_model.get_reward(query_type, query, pred_answer, real_answer)
            data['reward'] = reward_score
            print('query_type:', query_type)
            print('pred_answer:', pred_answer)
            print('real_answer:', real_answer)
            print('\n***\nreward_score:', reward_score)
        if save_step > 0 and idx % save_step == 0:
            write_json(test_dataset, result_file)

    write_json(test_dataset, result_file)


def main(test_file, response_file, result_file, model, max_round=5, save_step=5):
    """Evaluate one model end to end on the 'cooperate' subset of the test set.

    Phases:
      1. Inference: resume from *response_file* if it exists (else start from
         *test_file*) and generate missing responses. Results are
         checkpointed to *response_file*.
      2. API replay: if a final response is itself a tool call, execute it
         and use the API result as ``baseline_answer``.
      3. Scoring: compute ``reward`` per record and write *result_file*.

    Args:
        save_step: Checkpoint every *save_step* records. A value <= 0 disables
            intermediate checkpoints (previously it raised ZeroDivisionError).
    """
    try:
        test_dataset = _load_records(response_file)
    except FileNotFoundError:
        test_dataset = _load_records(test_file)

    test_dataset = filter_cor(test_dataset)
    print('len of data: ', len(test_dataset))

    _run_inference(test_dataset, response_file, model, max_round, save_step)
    _replay_final_tool_calls(test_dataset)
    _score_responses(test_dataset, result_file, save_step)


if __name__ == '__main__':
    # Baseline sweep: run the full pipeline (inference -> API replay -> scoring)
    # once per model, writing per-model output files tagged with run date 0514.
    test_file = "test_datasets/multi_agent_RL_test_datasets.json"
    main_dir = "../pred_results/llm_baselines/"
    model_lst = ["gpt-4o", "Qwen-14b"]

    for model in model_lst:
        print(model, 'start!!!!')
        response_file = os.path.join(main_dir, f"baseline_result_{model}_0514.json")
        result_file = os.path.join(main_dir, f"baseline_reward_{model}_0514.json")
        main(
            test_file,
            response_file,
            result_file,
            model=model,
            save_step=5,
        )
        separator = '*' * 10
        print('\n' * 10, separator, '\n')