from utils import *
from sklearn.metrics.pairwise import cosine_similarity
import logging
import argparse

# ============Config============
def str_to_bool(value):
    """Interpret a command-line string as a boolean flag.

    The comparison is case-insensitive. Only 'true', '1', 't', 'y' and 'yes'
    map to True; every other string (including typos) maps to False.
    """
    truthy_tokens = ('true', '1', 't', 'y', 'yes')
    normalized = value.lower()
    return normalized in truthy_tokens

parser = argparse.ArgumentParser(description="Parser for model configuration.")

# Required options: the LLM backend, the agent implementation, and whether
# summarized experience ("Expel") is appended to each task's examples.
parser.add_argument(
    '--model_name',
    type=str,
    required=True,
    choices=['Qwen2.5-7B-Instruct', 'Qwen2.5-32B-Instruct', 'gpt-4o', 'gpt-4o-mini'],
    help='Name of the model',
)
parser.add_argument(
    '--agent_type',
    type=str,
    required=True,
    choices=['FullCodeRefl', 'react'],
    help='Type of the agent',
)
parser.add_argument(
    '--add_Expel',
    type=str_to_bool,
    required=True,
    help='Whether to add expel (True or False)',
)

# Optional options with defaults.
parser.add_argument(
    '--ICL_selection',
    type=str,
    default='random',
    choices=['random', 'most_similar'],
    help='How to select ICL samples',
)
parser.add_argument('--repetitions', type=int, default=1)
parser.add_argument('--type', type=str, default='pure', choices=['pure', 'metaflow'])

# Parsing happens at import time, so this module is only usable as a script.
args = parser.parse_args()

# Mirror the parsed options into the module-level names the rest of the file reads.
model_name, agent_type = args.model_name, args.agent_type
add_Expel, repetitions = args.add_Expel, args.repetitions

# NOTE: this shadows the builtin `type` for the whole module; the name is kept
# because the main block branches on it.
type = args.type
# =======================


# Sampling temperature passed to every call_llm / run_meta_workflow call in this file.
temperature = 0.1
# Experiment label forwarded to run_meta_workflow.
experiment_name = "Metaflow_inference"
# Size of the ProcessPoolExecutor used to evaluate tasks in parallel.
max_workers = 3
# Version tag that selects which ./input/metaflow_tree_<version>_<set>.pkl file is loaded.
meta_tree_version = 'v13'
# Data splits processed by the main block (root ids exist only for 'train' and 'dev').
set_name_list = ["train"]
# Number of in-context examples attached to each task; 0 disables ICL examples.
K = 1

# Experience summarization runs only when --add_Expel is true and this value is > 0.
Expel_K = 1
# ----- logging -----
# File logger writing to inference_traverse.log. The level is INFO, so the
# logger.debug(...) calls in this file are filtered out unless the level is lowered.
logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)
logger_handler = logging.FileHandler('inference_traverse.log')
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger_handler.setFormatter(formatter)
logger.addHandler(logger_handler)


# Sentence-embedding model, placed on the CPU; model.encode(...) is used to embed
# queries for similarity search. Instantiated at import time.
# NOTE(review): model_path is presumably a local checkpoint directory relative
# to the working directory -- confirm.
model_path = 'models/sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(model_path, device='cpu')



def build_tree_structure(all_nodes, root_id):
    """Expand the flat node table into a nested binary tree rooted at ``root_id``.

    A node without a 'source' entry is a leaf: its three reward fields are
    fixed to 1.0, it has no children and its height is 1. A node with a
    'source' entry is a meta node: 'source' holds exactly two child records
    (each with an 'id'), the reward fields are read from the node and cast
    to float, and its height is one more than its taller child.

    Returns a dict with keys id, query, workflow, reward,
    reward_privileged_FullCodeRefl, reward_code_lines_ratio, left, right,
    height and children (children is [] for leaves, [left, right] otherwise).
    """
    record = all_nodes[root_id]
    reward_fields = ('reward', 'reward_privileged_FullCodeRefl', 'reward_code_lines_ratio')

    if 'source' in record:
        first_ref, second_ref = record['source']
        rewards = [float(record[field]) for field in reward_fields]
        left_subtree = build_tree_structure(all_nodes, first_ref['id'])
        right_subtree = build_tree_structure(all_nodes, second_ref['id'])
        level = 1 + max(left_subtree['height'], right_subtree['height'])
        offspring = [left_subtree, right_subtree]
    else:
        # Leaves are treated as perfectly scored so they never fall below a pruning threshold.
        rewards = [1.0, 1.0, 1.0]
        left_subtree = right_subtree = None
        level = 1
        offspring = []

    built = {
        'id': record['id'],
        'query': record['query'],
        'workflow': record['workflow'],
    }
    built.update(zip(reward_fields, rewards))
    built['left'] = left_subtree
    built['right'] = right_subtree
    built['height'] = level
    built['children'] = offspring
    return built


def build_parent_and_node_dict(root):
    """Index a left/right binary tree by node id.

    Walks the tree in pre-order (node, left subtree, right subtree) and
    returns ``(parent_map, node_dict)``: ``parent_map`` maps each non-root
    node id to its parent's id, and ``node_dict`` maps every node id to the
    node dict itself. Missing or None children are skipped.
    """
    child_to_parent = {}
    nodes_by_id = {}

    # Explicit stack instead of recursion; the right child is pushed first so
    # the left child is popped (visited) first, giving pre-order.
    pending = [(root, None)]
    while pending:
        current, owner = pending.pop()
        if current is None:
            continue
        current_id = current['id']
        nodes_by_id[current_id] = current
        if owner is not None:
            child_to_parent[current_id] = owner['id']
        pending.append((current.get('right', None), current))
        pending.append((current.get('left', None), current))

    return child_to_parent, nodes_by_id


def prune_tree(root, key='reward', threshold=0.8):
    """Drop low-scoring nodes (and their ancestors) and regroup what survives.

    Every node whose ``key`` value (0 when the key is missing) is below
    ``threshold`` is pruned together with all of its ancestors. Among the
    surviving nodes, only meta nodes (nodes with a left or right child) are
    considered; each one whose parent is missing or pruned becomes a root of
    the retained forest. The forest is hung under a synthetic node with id
    'root' and a generic query/workflow.

    Args:
        root: Tree built from dicts with 'id', 'left', 'right' and 'height'.
        key: Name of the score field compared against ``threshold``.
        threshold: Minimum score a node needs in order to survive.

    Returns:
        ``[]`` when ``root`` is None (kept for backward compatibility);
        otherwise the synthetic root dict with keys id, query, workflow,
        children and height. When nothing survives, children is empty and
        height is 1 instead of raising ValueError.
    """
    if root is None:
        return []

    parent_map, node_dict = build_parent_and_node_dict(root)

    pruned = set()
    for node_id, node in node_dict.items():
        if node.get(key, 0) < threshold:
            # Climb towards the root. A node already in `pruned` has had all of
            # its ancestors pruned too, so the climb can stop there.
            current_id = node_id
            while current_id is not None and current_id not in pruned:
                pruned.add(current_id)
                current_id = parent_map.get(current_id)

    forest_roots = []
    for node_id, node in node_dict.items():
        if node_id in pruned:
            continue
        if node['left'] is None and node['right'] is None:
            # Leaves are never exposed as forest roots.
            continue
        parent_id = parent_map.get(node_id)
        if parent_id is None or parent_id in pruned:
            forest_roots.append(node)

    added_root = {
        'id': 'root',
        'query': 'Perform actions.',
        'workflow': """```json
{
"task_description": "Perform {action}.",
"expected_final_state": "Complete the action."
}
```""",
        "children": forest_roots,
        # default=0 keeps an empty forest from crashing max().
        'height': max((sample['height'] for sample in forest_roots), default=0) + 1,
    }
    return added_root




def isNodeValid(query, node, model_name, temperature) -> bool:
    """Decide whether ``node``'s meta workflow should be used for ``query``.

    The synthetic 'root' node is always valid and height-1 nodes (leaves) are
    never valid; neither case calls the LLM. For any other node,
    VERIFY_PROMPT_TEMPLATE is rendered with the query and the node's query
    and workflow, sent to the LLM, and the first ```json fenced block of the
    reply is inspected: the node is valid when that block contains 'true'
    (case-insensitive).

    Args:
        query: The task query being routed.
        node: Tree node dict with 'id', 'height', 'query' and 'workflow'.
        model_name: Model identifier forwarded to ``call_llm``.
        temperature: Sampling temperature forwarded to ``call_llm``.

    Raises:
        Exception: if five consecutive attempts fail, either because the call
            raised or because the reply had no ```json block.
    """
    if node['id'] == 'root':
        return True
    if node['height'] == 1:
        return False

    verify_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user",
         "content": Template(VERIFY_PROMPT_TEMPLATE.lstrip()).render({"query": query,
                                                                      'meta_query': node['query'],
                                                                      'meta_workflow': node['workflow'],
                                                                      })}
    ]
    verify_pattern = r'```json\s*([\s\S]*?)```'

    for _ in range(5):
        try:
            verify_result = call_llm(verify_messages, model=model_name, temperature=temperature)
            # Lazy %-style arguments: the original f-string reused single quotes
            # inside a single-quoted f-string, a SyntaxError before Python 3.12.
            logger.debug(
                '\n\nquery: %s\nmeta_query : %s\n meta_workflow: %s\n verify_result: %s\n\n',
                query, node['query'], node['workflow'], verify_result,
            )
            # An IndexError here (no fenced block) counts as a failed attempt.
            verdict = re.findall(verify_pattern, verify_result, re.DOTALL)[0]
            break
        except Exception as e:
            print(f'Verify Error in {verify_messages}', e)
    else:
        raise Exception('Verify Error for 5 times.')

    return 'true' in verdict.lower()


def traverse(query, node, model_name, temperature):
    """Find the most specific valid node for ``query`` in the subtree at ``node``.

    Returns None when ``node`` itself is rejected by ``isNodeValid``; in that
    case its children are not visited. Otherwise every child is explored
    recursively and the result with the smallest 'height' wins. Height shrinks
    towards the leaves, so the smallest height is the deepest match. Ties keep
    the earlier candidate, and ``node`` is returned when no child yields a
    strictly lower result.
    """
    if not isNodeValid(query, node, model_name, temperature):
        return None

    deepest_node = node

    for child in node.get('children', []):
        child_result = traverse(query, child, model_name, temperature)
        if child_result is not None and child_result['height'] < deepest_node['height']:
            deepest_node = child_result

    # Lazy %-style arguments: the original f-string reused single quotes inside
    # a single-quoted f-string, a SyntaxError before Python 3.12.
    logger.debug(
        '\n\nquery: %s\nmeta_query : %s\n meta_workflow: %s\n\n',
        query, deepest_node['query'], deepest_node['workflow'],
    )

    return deepest_node



def DFS_process_task(task):
    """Route one task through the pruned meta tree and execute the match.

    ``task`` is a dict with 'id' and 'query' and an optional 'ICL_examples'
    string (empty string when absent). The query is matched against the
    module-level ``retained_root`` via ``traverse`` and the selected node's
    query/workflow are handed to ``run_meta_workflow``.

    Returns ``[task_id, eval_result]``.

    NOTE(review): ``retained_root`` is only assigned inside the
    ``if __name__ == '__main__'`` block; worker processes presumably get it
    through fork inheritance -- confirm this also holds for the start method
    in use.
    """
    task_id = task['id']
    query = task['query']
    icl_text = task.get('ICL_examples', "")

    matched_node = traverse(query, retained_root, model_name, temperature)

    outcome = run_meta_workflow(
        0,
        meta_query=matched_node['query'],
        meta_workflow_raw=matched_node['workflow'],
        experiment_name=experiment_name,
        task_id=task_id,
        model_name=model_name,
        agent_type=agent_type,
        repe_index=0,
        exec_temperature=temperature,
        ICL_examples=icl_text,
    )
    return [task_id, outcome]

def root_process_task(task):
    """Execute one task with the synthetic root workflow, with no tree search.

    Baseline counterpart of ``DFS_process_task``: the module-level
    ``retained_root`` is used directly as the meta node, so its generic
    query/workflow are passed to ``run_meta_workflow``. ``task`` is a dict
    with 'id' and 'query' and an optional 'ICL_examples' string (empty string
    when absent).

    Returns ``[task_id, eval_result]``.
    """
    # 'query' is read (so a malformed task still fails here) but not used further.
    task_id, _query = task['id'], task['query']
    icl_text = task.get('ICL_examples', "")

    base_node = retained_root

    outcome = run_meta_workflow(
        0,
        meta_query=base_node['query'],
        meta_workflow_raw=base_node['workflow'],
        experiment_name=experiment_name,
        task_id=task_id,
        model_name=model_name,
        agent_type=agent_type,
        repe_index=0,
        exec_temperature=temperature,
        ICL_examples=icl_text,
    )
    return [task_id, outcome]


def cal_succ_rate(eval_results):
    """Return the fraction of successful tasks.

    Args:
        eval_results: Iterable of ``[task_id, eval_result]`` entries, where the
            last element is a dict whose ``['eval_result']['success']`` value
            is truthy for a solved task.

    Returns:
        Successes divided by the number of entries, or 0.0 for empty input
        (previously a ZeroDivisionError).
    """
    outcomes = [sample[-1] for sample in eval_results]
    if not outcomes:
        return 0.0
    succ_num = sum(1 for outcome in outcomes if outcome['eval_result']['success'])
    return succ_num / len(outcomes)


def cal_succ_rate_by_appnums(eval_results):
    """Compute the success rate separately for each app-count bucket.

    ``eval_results`` holds ``(task_id, eval_result)`` pairs. Each task is
    bucketed by ``get_num_apps(task_id)``, and the rate of a bucket is the sum
    of its ``eval_result['eval_result']['success']`` flags divided by the
    bucket size.

    Returns a plain dict mapping bucket label to success rate.
    """
    flags_by_bucket = defaultdict(list)

    for task_id, payload in eval_results:
        verdict = payload['eval_result']
        bucket = get_num_apps(task_id)
        flags_by_bucket[bucket].append(verdict['success'])

    rates = {}
    for bucket, flags in flags_by_bucket.items():
        if len(flags) > 0:
            rates[bucket] = sum(flags) / len(flags)
        else:
            rates[bucket] = 0
    return rates


def get_num_apps(task_id: str) -> str:
    """Return the app-count bucket label for a task.

    Reads ``metaflow_neurips/data/tasks/<task_id>/ground_truth/metadata.json``
    (relative to the working directory) and buckets its 'num_apps' field.

    Returns:
        '1' when the task uses exactly one app, otherwise '2+'. The result is
        a string label, not the raw count.

    Raises:
        OSError: if the metadata file cannot be opened.
        KeyError: if the metadata has no 'num_apps' field.
    """
    task_metadata_file_path = os.path.join(
        'metaflow_neurips/data', "tasks", task_id, "ground_truth", "metadata.json"
    )

    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(task_metadata_file_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    return '1' if metadata['num_apps'] == 1 else '2+'

def _format_icl_examples(samples):
    """Render leaf-node samples as one newline-joined block of query/solution pairs."""
    return '\n'.join(
        f"query:\n{sample['query']}\n\nsolution:\n```python\n{sample['workflow']}\n```"
        for sample in samples
    )


def _most_similar_samples(query, leaf_nodes, leaf_embeddings, top_k):
    """Return the ``top_k`` leaf nodes whose query embedding is most cosine-similar to ``query``."""
    query_embedding = model.encode([query])
    cosine_similarities = cosine_similarity(query_embedding, leaf_embeddings)
    # argsort is ascending, so reverse it before taking the best top_k.
    most_similar_indices = np.argsort(cosine_similarities[0])[::-1][:top_k]
    return [leaf_nodes[i] for i in most_similar_indices]


if __name__ == '__main__':
    import os

    # Driver: for each data split, attach ICL examples (and optionally
    # summarized experience) to every task, build and prune the meta tree,
    # evaluate all tasks in a process pool, save the raw results and print
    # the success rates.
    #
    # NOTE: the inputs are loaded with pickle, which can execute arbitrary code;
    # only run this on trusted, locally produced artifacts.

    # One experience cache per processed split (previously only 'train' existed,
    # so any other split raised KeyError when experience was enabled).
    predefined_expel_input = {name: {} for name in set_name_list}

    for set_name in set_name_list:
        with open(f'./input/leaf_nodes_{set_name}.pkl', 'rb') as fp:
            leaf_nodes = pickle.load(fp)

        train_set_embeddings = model.encode([sample['query'] for sample in leaf_nodes], show_progress_bar=True)

        with open(f'./input/metaflow_tree_{meta_tree_version}_{set_name}.pkl', 'rb') as fp:
            all_nodes = pickle.load(fp)

        with open('./input/NEW_ALL_LEAF_NODES.pkl', 'rb') as f:
            ALL_LEAF_NODES = pickle.load(f)

        tasks = ALL_LEAF_NODES[set_name]

        # ---------- in-context examples ----------
        if K > 0:
            for task in tasks:
                query = task['query']
                if args.ICL_selection == 'random':
                    # NOTE(review): despite the option name, this takes the first K
                    # leaf nodes deterministically. Left unchanged so existing
                    # results stay reproducible -- confirm this is intended.
                    samples = [leaf_nodes[i] for i in range(K)]
                else:
                    samples = _most_similar_samples(query, leaf_nodes, train_set_embeddings, K)

                task['ICL_examples'] = _format_icl_examples(samples)

        # ---------- summarized experience ("Expel") ----------
        if add_Expel and Expel_K > 0:
            for task in tqdm(tasks, 'calculating experience'):
                query = task['query']

                # Fixed: retrieval size follows Expel_K (was K, which only worked
                # because both constants happened to be equal).
                similar_samples = _most_similar_samples(query, leaf_nodes, train_set_embeddings, Expel_K)
                ICL_examples = _format_icl_examples(similar_samples)

                expel_messages = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user",
                     "content": "Please summarize the experience in natural language from the following examples so that you can use this experience in subsequent similar tasks. Please don't rehash the code, don't discuss the details of the code, just the empirical guidelines in natural language form, and don't include anything else.\n\n" + ICL_examples}
                ]
                for _ in range(5):
                    try:
                        expel_result = call_llm(expel_messages, model=model_name, temperature=temperature)
                        logger.debug(
                            f'\n\nquery: {query}\n expel_result: {expel_result}\n\n')
                        break
                    except Exception as e:
                        print(f'Expel Error in {expel_messages}', e)
                else:
                    raise Exception('Expel Error for 5 times.')

                expel_result = f"\n\nHere are some relevant historical experiences:\n{expel_result}"
                predefined_expel_input[set_name][query] = expel_result

            # Second pass on purpose: tasks sharing a query all get the last
            # summary stored for that query, as before.
            for task in tqdm(tasks, 'adding experience'):
                expel_result = predefined_expel_input[set_name][task['query']]

                if 'ICL_examples' not in task:
                    task['ICL_examples'] = expel_result
                else:
                    task['ICL_examples'] += expel_result

        # ---------- repetitions ----------
        tasks = tasks * repetitions

        # ---------- meta tree ----------
        if set_name == 'train':
            final_root_id = "07b42fd_1|07b42fd_2|07b42fd_3|229360a_1|229360a_2|229360a_3|22cc237_1|22cc237_2|22cc237_3|27e1026_1|27e1026_2|27e1026_3|287e338_1|287e338_2|287e338_3|29caf6f_1|29caf6f_2|29caf6f_3|2a163ab_1|2a163ab_2|2a163ab_3|302c169_1|302c169_2|302c169_3|34d9492_1|34d9492_2|34d9492_3|3c13f5a_1|3c13f5a_2|3c13f5a_3|60d0b5b_1|60d0b5b_2|60d0b5b_3|6104387_1|6104387_2|6104387_3|692c77d_1|692c77d_2|692c77d_3|6ea6792_1|6ea6792_2|6ea6792_3|76f2c72_1|76f2c72_2|76f2c72_3|771d8fc_1|771d8fc_2|771d8fc_3|7d7fbf6_1|7d7fbf6_2|7d7fbf6_3|82e2fac_1|82e2fac_2|82e2fac_3|aa8502b_1|aa8502b_2|aa8502b_3|afc0fce_1|afc0fce_2|afc0fce_3|b0a8eae_1|b0a8eae_2|b0a8eae_3|b7a9ee9_1|b7a9ee9_2|b7a9ee9_3|c901732_1|c901732_2|c901732_3|ccb4494_1|ccb4494_2|ccb4494_3|ce359b5_1|ce359b5_2|ce359b5_3|cf6abd2_1|cf6abd2_2|cf6abd2_3|d0b1f43_1|d0b1f43_2|d0b1f43_3|e3d6c94_1|e3d6c94_2|e3d6c94_3|e7a10f8_1|e7a10f8_2|e7a10f8_3|e85d92a_1|e85d92a_2|e85d92a_3"
        elif set_name == 'dev':
            final_root_id = "0d8a4ee_1|0d8a4ee_2|0d8a4ee_3|23cf851_1|23cf851_2|23cf851_3|37a8675_1|37a8675_2|37a8675_3|383cbac_1|383cbac_2|383cbac_3|396c5a2_1|396c5a2_2|396c5a2_3|3ab5b8b_1|3ab5b8b_2|3ab5b8b_3|4ec8de5_1|4ec8de5_2|4ec8de5_3|4fab96f_1|4fab96f_2|4fab96f_3|50e1ac9_1|50e1ac9_2|50e1ac9_3|530b157_1|530b157_2|530b157_3|57c3486_1|57c3486_2|57c3486_3|6171bbc_1|6171bbc_2|6171bbc_3|68ee2c9_1|68ee2c9_2|68ee2c9_3|6bdbc26_1|6bdbc26_2|6bdbc26_3|6c2c621_1|6c2c621_2|6c2c621_3|b119b1f_1|b119b1f_2|b119b1f_3|d4e9306_1|d4e9306_2|d4e9306_3|df61dc5_1|df61dc5_2|df61dc5_3|fac291d_1|fac291d_2|fac291d_3"
        else:
            raise NotImplementedError

        root = build_tree_structure(all_nodes, final_root_id)

        # Must stay a module-level name: the worker functions read `retained_root`.
        retained_root = prune_tree(root, key='reward_privileged_FullCodeRefl', threshold=1.0)

        Expel = '_Expel' if add_Expel else ""

        # Create the output directory up front so a missing folder cannot discard
        # a finished run at dump time.
        os.makedirs('./last_output', exist_ok=True)

        if type == 'metaflow':
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(DFS_process_task, tasks),
                    total=len(tasks),
                    desc="Processing DFS tasks"
                ))

            # Save before computing metrics (as the 'pure' branch does) so a
            # failure in the metric code cannot lose the raw results.
            with open(f'./last_output/MetaFlow_{model_name}_{agent_type}{Expel}_{args.ICL_selection}_{set_name}.pkl', 'wb') as f:
                pickle.dump(results, f)

            print(f'{model_name} {agent_type}~MetaFLow {Expel} {set_name} success ratio:', cal_succ_rate(results))

            print(f'{model_name} {agent_type}~MetaFLow {Expel} {set_name} success ratio by group:', cal_succ_rate_by_appnums(results))
        elif type == 'pure':
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(root_process_task, tasks),
                    total=len(tasks),
                    desc="Processing pure agent tasks"
                ))

            with open(f'./last_output/PureAgent_{model_name}_{agent_type}{Expel}_{args.ICL_selection}_{set_name}.pkl', 'wb') as f:
                pickle.dump(results, f)

            print(f'{model_name} pure {agent_type} {Expel} {set_name} success ratio:', cal_succ_rate(results))

            print(f'{model_name} pure {agent_type} {Expel} {set_name} success ratio by group:', cal_succ_rate_by_appnums(results))