# -*- coding:utf-8 -*-
"""
Multi-agent evaluation script
"""
import argparse
import os
import sys
from collections import OrderedDict

# Put the parent of the directory containing this script at the front of sys.path,
# presumably the project root holding the project-local packages imported below
# (model, config, agents, reward).
src = os.path.realpath(__file__)             # resolved path of this script
src = os.path.dirname(os.path.dirname(src))  # script directory -> its parent
sys.path.insert(0, src)
print(sys.path)

from model import ActorModel
# from safetensors.torch import load_file
from config.config import *
import torch
from agents.bge_utils import *
from agents.agent_call import *
from reward.reward_model_rule import RewardModel

# Per-agent retrieval memory: a JSON dump plus a FAISS index for each agent.
# NOTE(review): these are relative paths; they are presumably resolved against the
# current working directory rather than this script's location — confirm how the
# consumer (call_memory_retrieval) opens them.
MASTER_MEMORY_JSON_PATH = '../memory/master/agent_memory.json'
MASTER_MEMORY_FAISS_PATH = '../memory/master/agent_memory.index'
MATH_MEMORY_JSON_PATH = '../memory/math/agent_memory.json'
MATH_MEMORY_FAISS_PATH = '../memory/math/agent_memory.index'
RODY_MEMORY_JSON_PATH = '../memory/rody/agent_memory.json'
RODY_MEMORY_FAISS_PATH = '../memory/rody/agent_memory.index'
EXPERT_MEMORY_JSON_PATH = '../memory/expert/agent_memory.json'
EXPERT_MEMORY_FAISS_PATH = '../memory/expert/agent_memory.index'
TOOLBENCH_MEMORY_JSON_PATH = '../memory/toolbench/agent_memory.json'
TOOLBENCH_MEMORY_FAISS_PATH = '../memory/toolbench/agent_memory.index'

# Sub-agent name -> [json_path, faiss_path]. The master's paths are used directly
# and are intentionally not part of this mapping.
MODEL_MEMORY = {
    agent: [json_path, faiss_path]
    for agent, json_path, faiss_path in (
        ('math', MATH_MEMORY_JSON_PATH, MATH_MEMORY_FAISS_PATH),
        ('rody', RODY_MEMORY_JSON_PATH, RODY_MEMORY_FAISS_PATH),
        ('expert', EXPERT_MEMORY_JSON_PATH, EXPERT_MEMORY_FAISS_PATH),
        ('toolbench', TOOLBENCH_MEMORY_JSON_PATH, TOOLBENCH_MEMORY_FAISS_PATH),
    )
}

# One GPU per agent (five devices are assumed to be available).
MODEL_CUDA_DICT = dict(zip(
    ('master', 'math', 'rody', 'expert', 'toolbench'),
    ('cuda:0', 'cuda:1', 'cuda:2', 'cuda:3', 'cuda:4'),
))


def load_model(model_path, weight_path, prompt, tool_database, model_name):
    """Build the ActorModel for *model_name* on its GPU and optionally load checkpoint weights.

    Args:
        model_path: base model path handed to ActorModel.
        weight_path: path of a ``torch.save``-d state dict, or None to keep the
            base weights. Every key is prefixed with ``'model.'`` before loading.
        prompt: system prompt handed to ActorModel.
        tool_database: the agent's tool base (None for agents without one).
        model_name: key into MODEL_CUDA_DICT selecting the device.

    Returns:
        The constructed ActorModel.

    Raises:
        KeyError: if *model_name* is not in MODEL_CUDA_DICT.
    """
    print(model_name, '.model_path:', model_path, ';weight_path:', weight_path)
    cuda_device = MODEL_CUDA_DICT[model_name]
    model = ActorModel(model_path, prompt, tool_database, None, model_name, require_grad=False, device=cuda_device)

    if weight_path is not None:
        # Load the checkpoint straight onto this agent's GPU. Mapping to plain 'cuda'
        # would first materialise every agent's weights on the current default device
        # (normally cuda:0, which also hosts the master) before copying them over.
        weights = torch.load(weight_path, map_location=torch.device(cuda_device))
        # Prefix keys with 'model.' to match ActorModel's state-dict layout
        # (presumably the backbone is stored as `.model` — confirm in ActorModel).
        # `.to()` is a no-op for tensors already on cuda_device.
        new_weights = OrderedDict(('model.' + k, v.to(cuda_device)) for k, v in weights.items())
        model.load_state_dict(new_weights)
    return model


def make_conversation_format(query, history, prompt, role):
    """Build a chat-template message list.

    The result starts with the system *prompt*, then each ``(user_text, reply)``
    pair from *history* as a user message followed by an assistant message
    (the assistant message is omitted when ``reply`` is None), and ends with
    *query* under *role*.

    Returns:
        list of ``{'role': ..., 'content': ...}`` dicts.
    """
    messages = [{'role': 'system', 'content': prompt}]
    for user_text, reply in history:
        messages.append({'role': 'user', 'content': user_text})
        if reply is not None:
            messages.append({'role': 'assistant', 'content': reply})
    messages.append({'role': role, 'content': query})
    return messages


def llm_pipeline_qwen(query, history, model, prompt, tokenizer, role, args):
    """Run one chat-formatted generation step with *model*.

    Renders the conversation (system *prompt*, *history*, then *query* under
    *role*) with the tokenizer's chat template plus the generation prompt,
    generates with the settings in *args*, and decodes only the new tokens.

    Returns:
        tuple: ``(output, generate_ids, input_len, attention_mask, input_ids)``:
        the decoded continuation of the single prompt, the two values returned
        by ``model.generate``, the prompt length in tokens, and the prompt ids.
    """
    messages = make_conversation_format(query, history, prompt, role)
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([prompt_text], return_tensors="pt")
    input_ids = encoded.input_ids

    generate_ids, attention_mask = model.generate(
        input_ids,
        min_length=-1,
        max_new_tokens=args.max_new_token,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        do_sample=args.do_sample,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    input_len = input_ids.shape[1]

    # Drop each row's prompt tokens so only the generated continuation is decoded.
    continuations = [row_out[len(row_in):] for row_in, row_out in zip(input_ids, generate_ids)]
    output = tokenizer.batch_decode(continuations, skip_special_tokens=True)[0]

    return output, generate_ids, input_len, attention_mask, input_ids


def agent_generate_one_step(query, history, model, prompt, tokenizer, args, role):
    """Generate one agent turn, split it into thought/action and record it in *history*.

    The output is split on '</think>'. ``thought`` is the text before the first
    tag, with a leading literal '<think>' removed. ``action`` is the segment after
    the first tag (up to any second one), with surrounding newlines stripped.
    Without a '</think>' tag, both are the whole output.

    The turn is appended to *history* in place as ``(query, action)``. When *role*
    is 'tool', the query is first wrapped in <tool_response> tags.

    Returns:
        tuple: ``(output, thought, action, generate_ids, input_len,
        attention_mask, input_ids, history)``. ``history`` is the same list
        object that was passed in.
    """
    output, generate_ids, input_len, attention_mask, \
    input_ids = llm_pipeline_qwen(query, history, model, prompt, tokenizer, role, args)

    tmp = output.split('</think>')
    if len(tmp) < 2:
        thought, action = output, output
    else:
        thought = tmp[0]
        # Remove only a literal leading '<think>' tag. str.strip('<think>') treats its
        # argument as a character set and would also eat any of '<', 't', 'h', 'i',
        # 'n', 'k', '>' at both ends of the thought (e.g. the tail of "...think").
        if thought.startswith('<think>'):
            thought = thought[len('<think>'):]
        action = tmp[1].strip('\n')

    if role == 'tool':
        new_query = '<tool_response>' + query + '</tool_response>'
        history.append((new_query, action))
    else:
        history.append((query, action))

    return output, thought, action, generate_ids, input_len, attention_mask, input_ids, history


def verify_end(model_type, output, action):
    """
    Classify an agent step as a final answer or a tool call.

    Args:
        model_type: 'master' restricts code-3 calls to the known sub-agents
            ('rody_agent', 'expert_agent', 'math_agent', 'toolbench_agent');
            any other value accepts every other tool name as an API call.
        output: the full model output, used only for the <think> format check.
        action: the text to inspect for a tool call (typically the part after
            '</think>', see agent_generate_one_step).

    Returns:
        tuple: ``(tool_name, code, format_reward, arguments)``.
        ``code`` is:
          - 0 for a direct answer (no tool call detected),
          - 1 for 'tool_intention',
          - 2 for 'tool_retrievals',
          - 3 for a sub-agent / API call,
          - -1 when the call cannot be parsed or (master only) names a tool
            outside the sub-agent set.
        ``format_reward`` is -5 unless *output* contains exactly one non-empty
        <think>...</think> block, otherwise 0. ``tool_name`` and ``arguments``
        are '' when no call was parsed. ``tool_name`` is 'error_tool' when
        parsing failed.
    """
    import ast
    import json

    format_reward = 0
    think_match = re.findall(r'<think>(.*?)</think>', output, re.S)
    # Expect exactly one <think>...</think> pair with some content. re.findall returns a
    # list (never None), so a missing think block has to be detected by its length.
    if len(think_match) != 1 or not think_match[0].strip():
        format_reward -= 5

    tool_name = ''
    arguments = ''

    # NOTE(review): `|` binds loosest, so this pattern is the alternation
    # `{"` | `'name"` | `':.*?}}` — any action containing `{"` counts as a tool call.
    # Kept as-is to preserve the current detection; confirm the intended pattern.
    if len(re.findall(r'{"|\'name"|\':.*?}}', action, re.S)) > 0 and 'Pass calling' not in action:
        payload = action.replace('<tool_call>', '').replace('</tool_call>', '').strip()
        try:
            # The payload is model output, so parse it as data instead of eval()-ing it:
            # JSON first (accepts true/false/null), then Python literals (single quotes).
            try:
                match_result = json.loads(payload)
            except ValueError:
                match_result = ast.literal_eval(payload)
            tool_name = match_result['name']
            arguments = match_result['arguments']

            if tool_name == 'tool_intention':
                code = 1
            elif tool_name == 'tool_retrievals':
                code = 2
            elif model_type == 'master':
                # The master may only delegate to one of the known sub-agents.
                if tool_name in ('rody_agent', 'expert_agent', 'math_agent', 'toolbench_agent'):
                    code = 3
                else:
                    code = -1
            else:
                # Sub-agents may call any other tool / API.
                code = 3
        except Exception as e:
            print(e)
            code = -1
            tool_name = 'error_tool'
    else:
        # No tool call detected: this is the agent's direct (final) output.
        code = 0

    return tool_name, code, format_reward, arguments


def parse_recall_memory(recalled_memory):
    """Render recalled memories into the two prompt fragments.

    Args:
        recalled_memory: list of memory dicts. Each entry's ``"plan"`` items
            are looked up as keys of ALL_TOOL_DESC.

    Returns:
        tuple: ``(tool_append, memory_append)``. ``tool_append`` holds one JSON
        object per line (``name``/``description``/``arguments``) for every plan
        item found in ALL_TOOL_DESC; unknown items are skipped. ``memory_append``
        is the JSON dump of *recalled_memory* with non-ASCII kept as-is.
    """
    tool_lines = [
        json.dumps({"name": plan, "description": ALL_TOOL_DESC[plan], "arguments": {}})
        for memory in recalled_memory
        for plan in memory["plan"]
        if plan in ALL_TOOL_DESC
    ]
    return '\n'.join(tool_lines), json.dumps(recalled_memory, ensure_ascii=False)


def _resolve_intention(arguments, fallback):
    """Return the 'intention' argument of a parsed tool call, or *fallback*.

    *arguments* is whatever verify_end parsed as the call's arguments ('' when
    nothing was parsed). Only a dict with an 'intention' key supplies a value,
    and a list value is reduced to its first element. A non-dict *arguments*,
    a missing key or an empty list yields *fallback*. Checking the type first
    avoids indexing a string or list with 'intention', which raises TypeError.
    """
    if not isinstance(arguments, dict) or 'intention' not in arguments:
        return fallback
    intention = arguments['intention']
    if isinstance(intention, list):
        return intention[0] if intention else fallback
    return intention


def _record_subagent_action(action_list, tool_name, code, action):
    """Append one sub-agent step to *action_list*; return the (possibly substituted) action.

    Tool calls (code > 0) are recorded as ``(tool_name, '')``. Anything else is
    recorded as ``('output', action)``, with an empty action replaced by ' '.
    """
    if code > 0:
        action_list.append((tool_name, ''))
        return action
    if action == '':
        action = ' '
    action_list.append(('output', action))
    return action


def get_response_from_subagent(args, actor_model, tokenizer, query, history_ori, system_prompt, job_desc,
                               model_name, max_round_num=5):
    """Run one sub-agent on *job_desc* until it answers or exhausts its round budget.

    The sub-agent first sees *job_desc* as a 'user' turn. While its reply is a
    tool call (verify_end code > 0) and ``round_num <= max_round_num``, the tool
    is executed and its result is fed back as a 'tool' turn:
      - code 1 -> call_intention_service(query, history)
      - code 2 -> call_tool_retrieval(intention, model_name, actor_model.tool_database)
      - code 3 -> api_calling(model_name, tool_name, query, history, arguments)

    Args:
        args: generation settings forwarded to agent_generate_one_step.
        actor_model: the sub-agent model (switched to eval mode here).
        tokenizer: tokenizer for *actor_model*.
        query: the original user query, passed to the tools.
        history_ori: prior turns; copied, the caller's list is not modified.
        system_prompt: template with {tool_append} and {memory_append} fields.
        job_desc: the task handed to the sub-agent.
        model_name: sub-agent name ('math', 'rody', ...).
        max_round_num: round budget for the tool-calling loop.

    Returns:
        tuple: ``(response, action_list, recall_memory)``. ``response`` is the
        full raw output for the 'math' agent and the parsed action otherwise.
        On timeout or an unparsable final step it is a fixed failure message.
        ``action_list`` has one entry per generation step. Memory recall is
        disabled, so ``recall_memory`` is always ``[]``.
    """
    actor_model.eval()

    print('calling sub-agent', model_name)
    print('current job desc is: ', job_desc)

    round_num = 1
    # One entry per generation step: (tool_name, '') for tool calls, ('output', text) otherwise.
    action_list = list()

    history = history_ori.copy()
    print('history: ', history)
    # Memory recall is disabled here: an empty memory is rendered into the prompt.
    recall_memory = []
    print('recalled memory: ', recall_memory)
    tool_append, memory_append = parse_recall_memory(recall_memory)
    agent_memory_prompt = system_prompt.format(tool_append=tool_append, memory_append=memory_append)

    output, _, action, _, _, _, _, history = agent_generate_one_step(
        job_desc, history, actor_model, agent_memory_prompt, tokenizer, args, 'user')
    print(model_name, ' round_num', round_num, ':', output)
    tool_name, code, _, arguments = verify_end('sub-agent', output, action)
    action = _record_subagent_action(action_list, tool_name, code, action)

    while code > 0 and round_num <= max_round_num:
        round_num += 1
        if code == 1:
            tool_response = call_intention_service(query, history)
            print('call intention response: ', tool_response)
        elif code == 2:
            arg_intention = _resolve_intention(arguments, query)
            tool_response = call_tool_retrieval(arg_intention, model_name, actor_model.tool_database)
            print('call tool retrieval response: ', tool_response)
        elif code == 3:
            tool_response = api_calling(model_name, tool_name, query, history, arguments)
            print('call api response: ', tool_response)
        else:
            raise ValueError("code value is not in the range！")

        # Feed the tool result back to the sub-agent and classify its next step.
        output, _, action, _, _, _, _, history = agent_generate_one_step(
            tool_response, history, actor_model, agent_memory_prompt, tokenizer, args, 'tool')
        print(model_name, ' round_num', round_num, ':', output)
        tool_name, code, _, arguments = verify_end('sub-agent', output, action)
        action = _record_subagent_action(action_list, tool_name, code, action)

    # get final response
    print('history: ', history)
    if code == 0:
        print(model_name, 'sub-agent final output: ', output)
        response = output if model_name == 'math' else action
    elif code > 0:
        print('The maximum number of steps has been reached, but the sub-agent has not finished! Modify the final output of the sub-agent to a unified description.')
        print('The current output of the sub-agent is: ', output)
        response = 'sub-agent timeout planning failed! Fail to get an answer'
        # Replace the pending tool call with the failure message.
        action_list[-1] = ('output', response)
    else:
        print('sub-agent output parsing failed! Modify the final output of the sub-agent to a unified description.')
        print('The current output of the sub-agent is: ', output)
        response = 'sub-agent output parsing failed！'
        action_list[-1] = ('output', response)
    return response, action_list, recall_memory


def multi_agent_generate_answer(query, history, master_model, args, max_round_num=12):
    """Answer *query* with the master agent, delegating to sub-agents when it asks to.

    The master first sees *query* as a 'user' turn. While its reply is a tool
    call and ``round_num <= max_round_num``, the call is executed and the result
    is fed back as a 'tool' turn:
      - code 1 -> call_intention_service
      - code 2 -> call_tool_retrieval(..., 'master', ...)
      - code 3 -> get_response_from_subagent on the module-global
        ``<prefix>_model`` (e.g. ``math_model``). Each sub-agent step counts
        toward the round budget.

    *history* is extended in place with the master's turns.

    Returns:
        tuple: ``(response, round_num)``. ``response`` is the master's final
        action, or a fixed failure message on timeout / parse failure.
    """
    print('initial query is: ', query)
    rewrite_query = call_intention_service(query, history)
    print('current rewrite query is: ', rewrite_query)
    # Memory recall is disabled: an empty memory is rendered into the prompt. (To enable
    # it, use call_memory_retrieval(rewrite_query, MASTER_MEMORY_JSON_PATH, MASTER_MEMORY_FAISS_PATH).)
    recall_memory = []
    print('recalled memory: ', recall_memory)
    _, memory_append = parse_recall_memory(recall_memory)
    # The master's prompt already lists its tools, so nothing is appended for tools.
    master_memory_prompt = master_model.prompt.format(tool_append='', memory_append=memory_append)
    tokenizer = master_model.tokenizer

    output, _, action, _, _, _, _, history = agent_generate_one_step(
        query, history, master_model, master_memory_prompt, tokenizer, args, 'user')
    tool_name, code, _, arguments = verify_end('master', output, action)
    round_num = 1

    while code > 0 and round_num <= max_round_num:
        if code == 1:
            # Intention-recognition tool.
            round_num += 1
            tool_response = call_intention_service(query, history)
            print('call intention response: ', tool_response)
        elif code == 2:
            # Tool retrieval, used by the master to pick which agent to call.
            round_num += 1
            arg_intention = _resolve_intention(arguments, query)
            tool_response = call_tool_retrieval(arg_intention, 'master', master_model.tool_database)
            print('call tool retrieval response: ', tool_response)
        elif code == 3:
            # Sub-agent call. The models are module globals (e.g. math_model) set up by test().
            agent_name_prefix = tool_name.split('_')[0]
            agent_model = globals()[agent_name_prefix + '_model']
            arg_intention = _resolve_intention(arguments, query)
            tool_response, actions, _ = get_response_from_subagent(
                args, agent_model, agent_model.tokenizer, query,
                [], agent_model.prompt, arg_intention, agent_name_prefix)
            print('call sub-agent {} response: '.format(agent_name_prefix), tool_response)
            # Every sub-agent step counts toward the budget, plus the master's reaction below.
            round_num += len(actions) + 1
        else:
            raise ValueError("code value is not in the range！")

        # Return the tool / sub-agent response to the master to react again.
        output, _, action, _, _, _, _, history = agent_generate_one_step(
            tool_response, history, master_model, master_memory_prompt, tokenizer, args, 'tool')
        print('master round_num', round_num, ':', output)
        tool_name, code, _, arguments = verify_end('master', output, action)

    # get final response
    print('final tool_name: ', tool_name)
    if code == 0:
        print('master final output: ', output)
        response = action
    elif code > 0:
        print('The maximum number of steps has been reached, but the master has not finished! Modify the final output of the master to a unified description.')
        print('The current output of the master is: ', output)
        response = 'master timeout planning failed! Fail to get an answer'
    else:
        print('Master output parsing failed! Modify the final output of the master to a unified description.')
        print('The current output of the master is: ', output)
        response = 'master output parsing failed!'

    return response, round_num


def _str2bool(value):
    """Parse a command-line boolean such as 'True'/'False', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` calls bool() on the raw string, and every non-empty
    string, including 'False', is truthy. That is why this parser is needed.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def init_params(argv=None):
    """Parse the evaluation command line.

    Args:
        argv: argument list to parse. None (the default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the dataset paths, the per-agent checkpoint
        paths and the generation settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_file_path',
                        default='test_datasets/multi_agent_RL_test_datasets.json',
                        type=str, help='Test Dataset')
    parser.add_argument('--pred_file_path',
                        default='../pred_results/multi_agent_RL_test_datasets_pred.json',
                        type=str, help='Test prediction dataset')
    parser.add_argument('--master_weight_path',
                        default='../models/joint_learning_0512_think_3b_3b_v3/master-checkpoint-1112/weights_actor_model.pth',
                        type=str, help='master model weight model path')
    parser.add_argument('--math_weight_path',
                        default='../models/joint_learning_0512_think_3b_3b_v3/math-checkpoint-1112/weights_actor_model.pth',
                        type=str, help='math model weight model path')
    parser.add_argument('--rody_weight_path',
                        default='../models/joint_learning_0512_think_3b_3b_v3/rody-checkpoint-1112/weights_actor_model.pth',
                        type=str, help='rody model weight model path')
    parser.add_argument('--expert_weight_path',
                        default='../models/joint_learning_0512_think_3b_3b_v3/expert-checkpoint-1112/weights_actor_model.pth',
                        type=str, help='expert model weight model path')
    parser.add_argument('--toolbench_weight_path',
                        default='../models/joint_learning_0512_think_3b_3b_v3/toolbench-checkpoint-1112/weights_actor_model.pth',
                        type=str, help='toolbench model weight model path')
    parser.add_argument('--max_new_token', default=1024, type=int, help='Maximum length of model generation')
    parser.add_argument('--repetition_penalty', default=1.0, type=float, help='Repeat penalty rate')
    # _str2bool (not bool) so that `--do_sample False` really disables sampling.
    parser.add_argument('--do_sample', default=True, type=_str2bool, help='Random Decoding')
    parser.add_argument('--temperature', default=0.01, type=float, help='Generation temperature')
    parser.add_argument('--seed', default=2048, type=int, help='')
    return parser.parse_args(argv)


def filter_cor(test_dataset):
    """Keep only the samples that belong to the 'cooperate' agent.

    A sample's agent is its 'agent' field if present, otherwise its
    'task_type' field, with every '_agent' substring removed.

    Returns:
        list of the matching samples (the original objects, in input order).
    """
    def _agent_of(sample):
        raw_name = sample['agent'] if 'agent' in sample else sample['task_type']
        return raw_name.replace('_agent', '')

    return [sample for sample in test_dataset if _agent_of(sample) == 'cooperate']


def test(test_file_path, pred_file_path, reward_model, args):
    """Evaluate the multi-agent system on a JSONL test set, appending predictions.

    Loads the master and the four sub-agent models. For every non-blank line of
    *test_file_path* (a JSON object with 'query', 'history', 'task_type' and
    'answer'), it runs multi_agent_generate_answer, scores the answer with
    ``reward_model.get_reward`` and appends the sample, extended with
    'pred_answer', 'round_num' and 'reward', as one JSON line to *pred_file_path*.

    Resuming: if *pred_file_path* exists, its n non-blank lines are treated as
    finished, the first n test samples are skipped, and new results are appended.
    """
    # multi_agent_generate_answer looks sub-agents up via globals(), so they must be
    # module-level names.
    global math_model
    global rody_model
    global expert_model
    global toolbench_model
    master_model = load_model(MASTER_SFT_MODEL_PATH, args.master_weight_path, MASTER_PROMPT, MASTER_TOOLBASE,
                              'master')
    math_model = load_model(MATH_SFT_MODEL_PATH, args.math_weight_path, MATH_PROMPT, None, 'math')
    rody_model = load_model(RODY_SFT_MODEL_PATH, args.rody_weight_path, RODY_PROMPT, RODY_TOOLBASE, 'rody')
    expert_model = load_model(EXPERT_SFT_MODEL_PATH, args.expert_weight_path, EXPERT_PROMPT, EXPERT_TOOLBASE,
                              'expert')
    toolbench_model = load_model(TOOLBENCH_SFT_MODEL_PATH, args.toolbench_weight_path, TOOLBENCH_PROMPT,
                                 TOOLBENCH_TOOLBASE, 'toolbench')

    try:
        with open(pred_file_path, 'r') as f:
            # Predictions are written with json.dumps, so read them back with json.loads;
            # eval() fails on JSON literals such as true/false/null.
            finished_records = [json.loads(line) for line in f if line.strip()]
        n = len(finished_records)
        print('Completed samples: ', n)
    except FileNotFoundError:
        n = 0

    acc = 0        # samples counted as correct in this run
    tot = 0        # non-blank test lines seen so far, including skipped ones
    evaluated = 0  # samples actually evaluated in this run

    with open(test_file_path, 'r') as in_file, open(pred_file_path, 'a' if n > 0 else 'w') as out_file:
        for raw_line in in_file:
            print(f'data processing: {tot}....')
            line = raw_line.strip()
            if not line:
                # Skip blank lines rather than treating them as end of file.
                continue
            data = json.loads(line)

            tot += 1
            if tot <= n:
                continue

            query = data['query']
            history = data['history']
            task_type = data['task_type'].split('_agent')[0]
            real_answer = data['answer']

            # NOTE(review): the generation steps append turns to `history` in place, so the
            # data['history'] written below includes the master's trajectory. Pass
            # list(history) instead if the output should keep the input history unchanged.
            pred_answer, round_num = multi_agent_generate_answer(query, history, master_model, args)
            evaluated += 1

            if pred_answer in ('master timeout planning failed! Fail to get an answer',
                               'master output parsing failed!'):
                reward_score = -1
            else:
                reward_score = reward_model.get_reward(task_type, query, pred_answer, real_answer)

            # 'expert' samples need a reward above 0.7; every other task type above 0.
            if reward_score > (0.7 if task_type == 'expert' else 0):
                acc += 1

            data['pred_answer'] = pred_answer
            data['round_num'] = round_num
            data['reward'] = reward_score
            out_file.write(json.dumps(data, ensure_ascii=False) + '\n')
            out_file.flush()

    if evaluated:
        print(f'Accuracy on the {evaluated} samples evaluated in this run: {acc}/{evaluated} = {acc / evaluated:.4f}')


def main():
    """Entry point: parse the command line, build the rule-based reward model and run the evaluation."""
    cli_args = init_params()
    scorer = RewardModel()
    test(cli_args.test_file_path, cli_args.pred_file_path, scorer, cli_args)
    print('finish test!')


if __name__ == '__main__':
    # Imported here rather than at module top, so importing this file does not pull in accelerate.
    from accelerate import Accelerator

    runner = Accelerator()
    # Under a multi-process launch, only the main process runs the evaluation.
    is_primary = runner.is_main_process
    if is_primary:
        main()