import os
import sys

workbench_path = '..//WorkBench'
sys.path.append(workbench_path)


from sklearn.metrics.pairwise import cosine_similarity
import logging
import argparse
from sentence_transformers import SentenceTransformer

from prompts import *
from jinja2 import Template
import numpy as np
import re
from collections import defaultdict
from metaflow_workbench import *
del base_url, API_KEY
import pandas as pd
del leaf_nodes

from datetime import datetime

# Run identifier (local time, YYYYmmdd_HHMMSS) appended to output file names.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
# ============Config============
def str_to_bool(value):
    """Interpret a command-line string as a boolean (case-insensitive).

    'true', '1', 't', 'y' and 'yes' map to True; anything else maps to False.
    """
    truthy_spellings = ('true', '1', 't', 'y', 'yes')
    normalized = value.lower()
    return normalized in truthy_spellings

parser = argparse.ArgumentParser(description="Parser for model configuration.")

# Model / agent selection.
parser.add_argument('--model_name', type=str, required=True, help='Name of the model', choices=['Qwen2.5-7B-Instruct', 'Qwen2.5-32B-Instruct', 'gpt-4o', 'gpt-4o-mini'])
parser.add_argument('--agent_type', type=str, default='react', help='Type of the agent', choices=['react'])
parser.add_argument('--ICL_selection', type=str, default='random', help='How to select ICL samples', choices=['random', 'most_similar'])
parser.add_argument('--API_domains', type=str, default='domains', help='How to select APIs', choices=['domains', 'all'])
# Selects both the reward key used for pruning and the tree version loaded.
parser.add_argument('--main_reward_name', type=str, default='reward_react', help='Reward key used to prune the metaflow tree; also selects the tree version to load', choices=['reward_react', 'reward_privileged_react', 'sft_reward_react'])
# Help (zh): "whether to apply the pre-filter; setting 0 may degrade metaflow performance".
parser.add_argument('--pre_filter', type=int, default=1, help='是否采用前置过滤，置入0可能造成metaflow效果下降', choices=[0, 1])
# Help (zh): "whether to add the ExpeL experience summaries".
parser.add_argument('--add_expel', type=int, default=0, help='是否加入expel的经验总结', choices=[0, 1])
# Help (zh): "number of repeated experiment runs".
parser.add_argument('--repetitions', type=int, default=1, help='重复实验次数')
parser.add_argument('--type', type=str, default='pure', choices=['pure', 'metaflow'], help="'pure': run the agent with the generic root workflow; 'metaflow': pick a workflow by traversing the metaflow tree")

# Endpoint settings for the OpenAI-compatible client.
parser.add_argument('--base_url', type=str, help='Base URL passed to the OpenAI client')
parser.add_argument('--API_KEY', type=str, help='API key passed to the OpenAI client')

args = parser.parse_args()

# Unpack the parsed CLI arguments into module-level settings that the task
# functions below read as globals.
model_name = args.model_name
base_url = args.base_url
API_KEY = args.API_KEY
agent_type = args.agent_type
repetitions = args.repetitions
API_domains = args.API_domains
main_reward_name = args.main_reward_name
pre_filter = bool(args.pre_filter)

# NOTE: this shadows the builtin `type` for the rest of the module; the name is
# kept because the __main__ section dispatches on it.
type = args.type

add_expel = bool(args.add_expel)
# The ExpeL pool (keyed by task query; its value is appended to the ICL
# prompt) is only used with --add_expel, so only then require the file.
expel_pool = {}
if add_expel:
    expel_path = f'./WorkBench_Expel_{model_name}.pkl'
    with open(expel_path, 'rb') as f:
        # NOTE: pickle.load can execute code; only load trusted files.
        expel_pool = pickle.load(f)
    print('Load expel from ', expel_path)

print('model_name:', model_name)
print('agent_type:', agent_type)
print('repetitions:', repetitions)
print('API_domains:', API_domains)
print('main_reward_name:', main_reward_name)
print('pre_filter:', pre_filter)
print('type:', type)
print('add_expel:', add_expel)
# =======================
# ---- Fixed experiment settings ----
temperature = 0.1
experiment_name = "Metaflow_inference"
max_workers = 8  # size of the process pool used in __main__

# Each reward variant maps to a different pickled metaflow tree
# (./input/metaflow_tree_{version}.pkl).
if main_reward_name == 'reward_privileged_react':
    version = "workbench_grpo_v2"
elif main_reward_name == 'reward_react':
    version = "workbench_grpo_v3"
elif main_reward_name == 'sft_reward_react':
    version = 'workbench_sft'
    # The SFT tree is pruned on the plain 'reward_react' key.
    main_reward_name = 'reward_react'
else:
    raise ValueError(f'Unsupported main_reward_name: {main_reward_name!r}')
set_name_list = ["test"]  # evaluation splits, read from ./data/split/{set_name}.tsv

K = 1  # number of ICL examples per task; 0 disables ICL


# Shared logger: DEBUG level, mirrored to the console and to
# inference_traverse.log with the same format.
logger = logging.getLogger('my_logger')
logger.setLevel(logging.DEBUG)

# Attach handlers only once so re-executing this module in the same process
# does not duplicate every log line.
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    # Explicit utf-8: logged queries/LLM output may contain non-ASCII text and
    # the default encoding is locale-dependent.
    file_handler = logging.FileHandler('inference_traverse.log', encoding='utf-8')
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)


# Local sentence embedder (run on CPU) used to rank training queries for
# most-similar ICL example selection in __main__.
model_path = 'models/sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(model_name_or_path=model_path, device='cpu')
# =======================



def call_llm(messages: list[dict], client, model='gpt-4o-mini', temperature=0.0, max_tokens=8192) -> str:
    """
    Send a chat history to a chat-completions endpoint and return the reply text.

    Args:
        messages: Chat history as ``[{"role": ..., "content": ...}, ...]``.
        client: OpenAI-style client exposing ``chat.completions.create``.
        model: Model identifier forwarded to the endpoint.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The text of the first choice, or ``""`` when the response has no
        choices or the first choice carries no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=1.0,
        n=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        logit_bias={},
        max_tokens=max_tokens,
    )
    if not response.choices:
        return ""
    # The API schema allows `content` to be null; normalize to "" so the
    # declared `str` return type holds and callers can regex the result.
    return response.choices[0].message.content or ""


def build_tree_structure(root_node, main_reward_name):
    """
    Recursively convert a raw metaflow record into a normalized tree node.

    A record without a ``'source'`` key is a leaf: its ``toolcall_and_results``
    becomes the workflow and every reward field is fixed to 1.0. A record with
    ``'source'`` is a merged meta node whose children are built from the
    records in ``root_node['source']``.

    Args:
        root_node: Raw record (dict) as loaded from the metaflow tree pickle.
        main_reward_name: Reward key copied into every node so that
            ``prune_tree`` can later filter on it.

    Returns:
        A dict with ``id``, ``query``, ``workflow``, ``reward``,
        ``<main_reward_name>``, ``reward_code_lines_ratio``, ``children``,
        ``height`` (1 for leaves, 1 + max child height otherwise),
        ``domains`` and ``base_template_list``. Leaves keep ``domains`` as
        given and wrap their template in a one-element list; meta nodes store
        the union over their children as sets.
    """
    if 'source' not in root_node:
        return {
            'id': root_node['id'],
            'query': root_node['query'],
            'workflow': root_node['toolcall_and_results'],
            'reward': 1.0,
            main_reward_name: 1.0,
            'reward_code_lines_ratio': 1.0,
            'height': 1,
            'children': [],
            'domains': root_node['domains'],
            'base_template_list': [root_node['base_template']],
        }

    meta_node = {
        'id': root_node['id'],
        'query': root_node['query'],
        'workflow': root_node['workflow'],
        'reward': float(root_node['reward']),
        main_reward_name: float(root_node[main_reward_name]),
        'reward_code_lines_ratio': float(root_node['reward_code_lines_ratio']),
        'children': [build_tree_structure(child, main_reward_name) for child in root_node['source']],
    }
    children = meta_node['children']

    # default=0: a meta node with an empty source list gets height 1 instead
    # of max() raising on an empty sequence.
    meta_node['height'] = max((child['height'] for child in children), default=0) + 1
    meta_node['domains'] = set().union(*(child['domains'] for child in children))
    meta_node['base_template_list'] = set().union(*(child['base_template_list'] for child in children))
    return meta_node


def build_parent_and_node_dict(root):
    """
    Index a tree by node id.

    Walks the tree rooted at ``root`` in pre-order (a node before its
    children, children in list order) via each node's ``'children'`` list.
    ``None`` entries, including a ``None`` root, are skipped.

    Returns:
        ``(parent_map, node_dict)``: ``parent_map`` maps each non-root node
        id to its parent's id; ``node_dict`` maps each node id to its node,
        in visit order. If an id occurs more than once, the last visit wins.
    """
    parent_of = {}
    by_id = {}
    # Explicit stack of (node, parent) pairs. Children are pushed in reverse
    # so they are popped in list order, reproducing recursive pre-order.
    pending = [(root, None)]
    while pending:
        current, parent = pending.pop()
        if current is None:
            continue
        current_id = current['id']
        by_id[current_id] = current
        if parent is not None:
            parent_of[current_id] = parent['id']
        pending.extend((child, current) for child in list(current['children'])[::-1])
    return parent_of, by_id


def prune_tree(root, key='reward', threshold=0.8):
    """
    Drop low-scoring nodes and re-root the surviving meta-node subtrees.

    Every node whose ``key`` value is below ``threshold`` (a missing value
    counts as 0) is pruned together with all of its ancestors. Each retained
    node that still has children and whose parent was pruned (or that had no
    parent) becomes the root of a surviving subtree. Those subtrees are hung
    under a fresh synthetic ``'root'`` node with a generic "complete the
    user's query" workflow.

    Args:
        root: Tree built by ``build_tree_structure``.
        key: Node field compared against ``threshold``.
        threshold: Minimum value a node needs to be kept.

    Returns:
        The synthetic root dict, or ``[]`` when ``root`` is None.
    """
    if root is None:
        return []

    parent_map, node_dict = build_parent_and_node_dict(root)

    # A failing node also prunes its ancestor chain. The walk stops at the
    # first already-pruned id, because that id's ancestors are pruned too.
    pruned = set()
    for node_id, node in node_dict.items():
        if node.get(key, 0) < threshold:
            current_id = node_id
            while current_id is not None and current_id not in pruned:
                pruned.add(current_id)
                current_id = parent_map.get(current_id)

    # Surviving meta nodes (with children) whose parent did not survive.
    forest_roots = []
    for node_id, node in node_dict.items():
        if node_id in pruned or not node['children']:
            continue
        parent_id = parent_map.get(node_id)
        if parent_id is None or parent_id in pruned:
            forest_roots.append(node)

    added_root = {
        'id': 'root',
        'query': 'Complete some actions.',
        'workflow': [
            {
                "type": "dynamic",
                "instruction": "Complete user's query",
                "outputs": []
            }
        ],
        "children": forest_roots,
        # default=0: if everything was pruned, the synthetic root is the only
        # node left and gets height 1 instead of max() raising.
        'height': max((sample['height'] for sample in forest_roots), default=0) + 1,
        'domains': {'analytics', 'calendar', 'customer_relationship_manager', 'email', 'project_management'}
    }
    return added_root




def isNodeValid(query, domains, task_template, node, model_name, temperature) -> bool:
    """
    Decide whether a metaflow tree node may be used for the given task.

    Cheap structural checks run first. Only nodes that pass them are sent to
    the LLM verifier, which is retried up to 5 times.

    Args:
        query: The task's natural-language query.
        domains: Domains the task needs; every one must be in ``node['domains']``.
        task_template: The task's base template. When the module-level
            ``pre_filter`` is on, it must appear in ``node['base_template_list']``.
        node: A node built by ``build_tree_structure`` / ``prune_tree``.
        model_name: Model used for the verification call.
        temperature: Sampling temperature for the verification call.

    Returns:
        True for the synthetic ``'root'`` node. False for height-1 nodes and
        for nodes failing the domain/template checks. Otherwise, whether the
        LLM's first fenced ```json block contains ``true``.

    Raises:
        Exception: if no fenced json block was obtained in 5 attempts.
    """
    if node['id'] == 'root':
        return True
    # Height-1 nodes are leaves (raw tool-call traces), not merged metaflows.
    if node['height'] == 1:
        return False
    if not all(domain in node['domains'] for domain in domains):
        return False

    if pre_filter and task_template not in set(node['base_template_list']):
        return False

    verify_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user",
         "content": Template(VERIFY_PROMPT_TEMPLATE.lstrip()).render({"query": query,
                                                                      'meta_query': node['query'],
                                                                      'meta_workflow': node['workflow'],
                                                                      })}
    ]
    verify_pattern = r'```json\s*([\s\S]*?)```'

    client = None
    for _ in range(5):
        try:
            # Created lazily inside the try so construction errors are retried
            # like any other failure, but reused across attempts instead of
            # building a new client on every retry.
            if client is None:
                client = OpenAI(base_url=base_url, api_key=API_KEY)
            raw_answer = call_llm(verify_messages, client=client, model=model_name, temperature=temperature)
            # IndexError (no fenced block) falls through to a retry.
            verify_result = re.findall(verify_pattern, raw_answer, re.DOTALL)[0]
            break
        except Exception as e:
            print(f'Verify Error in {verify_messages}', e)
    else:
        raise Exception('Verify Error for 5 times.')

    # NOTE(review): substring test on the raw JSON text. Any "true" anywhere
    # (e.g. inside a reason string) counts as a pass. Parsing the JSON would be
    # stricter, but its schema comes from VERIFY_PROMPT_TEMPLATE (not shown here).
    return 'true' in verify_result.lower()


def traverse(query, domains, task_template, node, model_name, temperature):
    """
    Depth-first search for the lowest-height valid node under ``node``.

    A node rejected by ``isNodeValid`` is dropped along with its whole
    subtree. Otherwise every child subtree is searched, in order. The best
    child result replaces ``node`` only if its ``height`` is strictly smaller.
    Among children with equal height, the earliest one wins.

    Returns:
        The selected node dict, or None when ``node`` itself is invalid.
    """
    if not isNodeValid(query, domains, task_template, node, model_name, temperature):
        return None

    child_results = [
        traverse(query, domains, task_template, child, model_name, temperature)
        for child in node.get('children', [])
    ]
    valid_results = [result for result in child_results if result is not None]
    # min() keeps the first of several equal minima.
    best_child = min(valid_results, key=lambda result: result['height'], default=None)
    if best_child is not None and best_child['height'] < node['height']:
        return best_child
    return node



def _run_task_with_workflow(task, deepest_node):
    """
    Execute one task with ``deepest_node``'s workflow and score the result.

    Builds a ``MetaflowRunner`` for the task's domains (or all domains when
    ``API_domains == 'all'``). It runs the task's query through the workflow
    and compares the produced function calls against the task's ``answer``.

    Args:
        task: Task record with ``id`` and ``leaf_nodes`` (the split's
            id -> record mapping), plus optional ``ICL_examples``.
        deepest_node: Tree node whose ``workflow`` and ``query`` are used.

    Returns:
        ``[domain, accuracy, side_effects, num_llm_calls, num_tool_calls,
        result]``. This is the row layout read by ``cal_succ_rate`` and
        ``cal_succ_rate_by_apps``.

    Raises:
        Exception: if ``API_domains`` is neither ``'domains'`` nor ``'all'``.
    """
    task_id = task['id']
    ICL_examples = task.get('ICL_examples', "")
    meta_workflow = deepest_node['workflow']
    task_record = task['leaf_nodes'][task_id]

    if API_domains == 'domains':
        domains = task_record['domains']
    elif API_domains == 'all':
        # NOTE(review): this list uses 'crm', while __main__ normalizes task
        # domains to 'customer_relationship_manager'. Confirm which name
        # MetaflowRunner expects.
        domains = ["analytics", "calendar", "email", "project_management", "crm"]
    else:
        raise Exception('API_domains should be "domains" or "all"')

    runner = MetaflowRunner(
        domains=domains,
        base_url=base_url,
        openai_api_key=API_KEY,
        model_name=model_name,
        temperature=temperature,
        agent_type=agent_type,
        ICL_information=ICL_examples
    )
    query = task_record['query']
    result = runner.execute_flow(query, meta_workflow)

    ground_truth = task_record['answer']
    accuracy, side_effects = calculate_metrics_single_in_class(result.get('function_calls'), ground_truth,
                                                               result.get('message'))

    logger.debug(f"query: {query}\nmeta_query: {deepest_node['query']}\n metaflow: {meta_workflow}\n accuracy:{accuracy}\n side_effects:{side_effects}")

    return [task_record['domain'], accuracy, side_effects, result.get('num_llm_calls'), result.get('num_tool_calls'), result]


def DFS_process_task(task):
    """
    Run one task with the most specific valid metaflow found by tree search.

    ``traverse`` starts at the module-level ``retained_root``, which must
    already be set in this process. The chosen node's workflow is then
    executed and scored.

    Args:
        task: Task record with ``id``, ``query``, ``domains``,
            ``base_template`` and ``leaf_nodes`` (plus optional ``ICL_examples``).

    Returns:
        The result row described in ``_run_task_with_workflow``.
    """
    deepest_node = traverse(task['query'], task['domains'], task['base_template'], retained_root, model_name, temperature)
    return _run_task_with_workflow(task, deepest_node)


def root_process_task(task):
    """
    Run one task with the synthetic root's generic workflow (no tree search).

    This is the "pure agent" baseline. It uses the module-level
    ``retained_root``, which must already be set in this process.

    Args:
        task: Task record with ``id`` and ``leaf_nodes`` (plus optional ``ICL_examples``).

    Returns:
        The result row described in ``_run_task_with_workflow``.
    """
    return _run_task_with_workflow(task, retained_root)


def cal_succ_rate(eval_results, metric):
    """
    Average one column of the per-task result rows.

    Args:
        eval_results: Rows shaped ``[domain, accuracy, side_effect,
            num_llm_calls, num_tool_calls, ...]``.
        metric: ``'accuracy'``, ``'side_effect'``, ``'num_llm_calls'`` or
            ``'num_tool_calls'``.

    Returns:
        The mean of the selected column, or 0 for an empty ``eval_results``.

    Raises:
        ValueError: for an unknown ``metric``.
    """
    column_by_metric = {'accuracy': 1, 'side_effect': 2, 'num_llm_calls': 3, 'num_tool_calls': 4}
    if metric not in column_by_metric:
        raise ValueError(f'metric should be one of {sorted(column_by_metric)}, got {metric!r}')
    column = column_by_metric[metric]
    values = [sample[column] for sample in eval_results]
    # Empty input averages to 0 instead of raising ZeroDivisionError.
    return sum(values) / len(values) if values else 0


def cal_succ_rate_by_apps(eval_results, metric):
    """
    Average one result column separately for each domain.

    Args:
        eval_results: Rows shaped ``[domain, accuracy, side_effect,
            num_llm_calls, num_tool_calls, ...]``. Rows are grouped by the
            domain in column 0.
        metric: ``'accuracy'``, ``'side_effect'``, ``'num_llm_calls'`` or
            ``'num_tool_calls'``.

    Returns:
        dict mapping each domain to the mean of ``metric`` over its rows, in
        first-seen domain order (empty dict for empty input).

    Raises:
        ValueError: for an unknown ``metric``.
    """
    column_by_metric = {'accuracy': 1, 'side_effect': 2, 'num_llm_calls': 3, 'num_tool_calls': 4}
    if metric not in column_by_metric:
        raise ValueError(f'metric should be one of {sorted(column_by_metric)}, got {metric!r}')
    column = column_by_metric[metric]

    domain_to_values = defaultdict(list)
    for sample in eval_results:
        domain_to_values[sample[0]].append(sample[column])

    # Each group has at least one value by construction, so no zero division.
    return {domain: sum(values) / len(values) for domain, values in domain_to_values.items()}

def get_num_apps(task_id: str, data_root: str = 'metaflow_neurips/data') -> str:
    """
    Bucket a task by how many apps its ground truth touches.

    Reads ``<data_root>/tasks/<task_id>/ground_truth/metadata.json``.

    Args:
        task_id: Task directory name.
        data_root: Dataset root; defaults to the path this script always used.

    Returns:
        ``'1'`` when ``num_apps`` is 1, otherwise ``'2+'``.
    """
    import json  # `json` is not imported explicitly at the top of this file

    task_metadata_file_path = os.path.join(
        data_root, "tasks", task_id, "ground_truth", "metadata.json"
    )
    with open(task_metadata_file_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    return '1' if metadata['num_apps'] == 1 else '2+'

if __name__ == '__main__':
    # Entry point:
    #   1. build the ICL pool and the pruned metaflow tree once;
    #   2. for each split in `set_name_list`, prepare the tasks (answers,
    #      domains, ICL/ExpeL prompts);
    #   3. run the tasks through a process pool, then print and dump the metrics.
    import multiprocessing

    # ---- ICL pool: training leaf tasks (id -> record) ----
    with open('./data/split/leaf_nodes_train.pkl', 'rb') as fp:
        leaf_nodes = pickle.load(fp)
    for k, v in leaf_nodes.items():
        v['id'] = k
    samples_for_ICL = list(leaf_nodes.values())

    # Training-query embeddings are only needed for most-similar retrieval.
    use_similarity_icl = K > 0 and args.ICL_selection != 'random'
    train_set_embeddings = None
    if use_similarity_icl:
        train_set_embeddings = model.encode([sample['query'] for sample in samples_for_ICL], show_progress_bar=True)

    # ---- Metaflow tree: independent of the evaluation split, so built once ----
    with open(f'./input/metaflow_tree_{version}.pkl', 'rb') as fp:
        all_nodes = pickle.load(fp)

    # Normalize the short 'crm' domain name to the full module name.
    module_alias = {
        'crm': 'customer_relationship_manager',
    }
    for node in all_nodes.values():
        if 'domains' in node:
            node['domains'] = [module_alias.get(domain, domain) for domain in node['domains']]

    # Branch nodes are merged nodes ('|' in the id) that are not a source of
    # any other node, i.e. the tops of the individual merge trees.
    all_children = {
        source['id']
        for node in all_nodes.values()
        for source in node.get('source', [])
    }
    branch_node = [node for node in all_nodes.values() if node['id'] not in all_children and '|' in node['id']]

    root_node = {
        'id': 'root',
        'query': 'Complete some actions.',
        'workflow': [
            {
                "type": "dynamic",
                "instruction": "Complete user's query",
                "outputs": []
            }
        ],
        'reward': 1.0,
        main_reward_name: 1.0,
        'reward_code_lines_ratio': 1.0,
        'left': None,
        'right': None,
        'height': 100,
        'source': branch_node,
    }
    root = build_tree_structure(root_node, main_reward_name=main_reward_name)
    # DFS_process_task / root_process_task read `retained_root` as a module
    # global, so it must be bound before the process pool is created.
    retained_root = prune_tree(root, key=main_reward_name, threshold=1.0)

    # Workers depend on inheriting this process's globals (retained_root,
    # parsed config, logger). Only the 'fork' start method provides that, so
    # pin it rather than rely on the platform/Python-version default.
    pool_context = multiprocessing.get_context('fork')

    Expel = "_Expel" if add_expel else ""

    for set_name in set_name_list:
        leaf_nodes = pd.read_csv(f'./data/split/{set_name}.tsv', sep='\t').to_dict(orient='index')
        for k, v in leaf_nodes.items():
            v['id'] = k

        tasks = list(leaf_nodes.values())
        for task in tasks:
            # NOTE(review): eval() executes the TSV cell contents; acceptable
            # only because the split files are local and trusted.
            task['answer'] = eval(task['answer'])
            task['domains'] = [module_alias.get(domain, domain) for domain in eval(task['domains'])]
            # NOTE(review): every task references the whole split, so each
            # task pickled to a worker carries all tasks. Workers only read
            # their own entry.
            task['leaf_nodes'] = leaf_nodes

        if K > 0 and tasks:
            if args.ICL_selection == 'random':
                # NOTE(review): despite its name, 'random' always uses the
                # first K training samples, the same for every task.
                selected = [samples_for_ICL[:K]] * len(tasks)
            else:
                # Embed all task queries in one batch rather than one encode() per task.
                query_embeddings = model.encode([task['query'] for task in tasks])
                similarities = cosine_similarity(query_embeddings, train_set_embeddings)
                selected = [
                    [samples_for_ICL[i] for i in np.argsort(row)[::-1][:K]]
                    for row in similarities
                ]
            for task, samples in zip(tasks, selected):
                task['ICL_examples'] = '\n'.join(
                    f"query:\n{sample['query']}\n\ntool calls and observation:\n```json\n{json.dumps(sample['toolcall_and_results'], indent=2)}\n```"
                    for sample in samples
                )

        if add_expel:
            for task in tasks:
                expel_experience = expel_pool[task['query']]
                if 'ICL_examples' in task:
                    task['ICL_examples'] += expel_experience
                else:
                    task['ICL_examples'] = expel_experience

        tasks = tasks * repetitions

        # type -> (worker, progress label, report prefix, output file tag)
        run_configs = {
            'metaflow': (DFS_process_task, "Processing DFS tasks", f'{model_name} {agent_type}~MetaFLow', 'metaflow'),
            'pure': (root_process_task, "Processing pure agent tasks", f'{model_name} pure {agent_type}', 'pureagent'),
        }
        if type not in run_configs:
            continue
        worker, desc, label, file_tag = run_configs[type]

        with ProcessPoolExecutor(max_workers=max_workers, mp_context=pool_context) as executor:
            results = list(tqdm(
                executor.map(worker, tasks),
                total=len(tasks),
                desc=desc
            ))

        for metric in ['accuracy', 'side_effect']:
            print(f'{label} {set_name} {API_domains} {metric}:', cal_succ_rate(results, metric))
            print(f'{label} {set_name} {API_domains} {metric} by group:', cal_succ_rate_by_apps(results, metric))

        for metric in ['num_llm_calls', 'num_tool_calls']:
            print(f'{label} {set_name} {API_domains} {metric}:', cal_succ_rate(results, metric))

        output_path = f'./output/workbench_results_{file_tag}_{model_name}_{set_name}{Expel}_{API_domains}_{timestamp}.pkl'
        with open(output_path, 'wb') as fp:
            pickle.dump(results, fp)
        print(f'DUMP to {output_path} END')