import config
import os
from graph_db import Neo4jDatabase
import asyncio
from traj_to_kg import TrajectoryToNeo4jImporter, find_all_task_folders
import json
import numpy as np
from graph_db import ActionMergeAnalyzer
from node_summary import LLMFunctionSummarizer, LLMElementName
import logging
from typing import Dict, Any, List, Optional, Union
from chain_evolve import format_chain_operations, extract_element_details, extract_reasoning_results, create_action_generation_chain, create_action_node_in_db, create_action_element_relations




def demonstrate_BPE(task_chains_input):
    """Run ActionMergeAnalyzer on the given task chains and print a summary.

    Args:
        task_chains_input: Task chains to analyse (e.g. a list or a NumPy
            array); handed unchanged to ``ActionMergeAnalyzer.analyze``.

    Returns:
        The result dictionary produced by ``analyze``, or None when any
        exception occurs (the message and traceback are printed instead).
    """
    try:
        merge_analyzer = ActionMergeAnalyzer(verbose=True)
        analysis = merge_analyzer.analyze(task_chains_input, num_merges=20, min_freq=2)

        print("\n=== High Level Actions Found ===")
        print(f"Total high-level actions identified: {analysis['total_found']}")

        print("\nTop High-Level Actions by Frequency:")
        # Only the first ten entries of the result mapping are printed.
        first_ten = list(analysis['high_level_actions'].items())[:10]
        for action_name, details in first_ten:
            print(f"\n{action_name}:")
            print(f"  Frequency: {details['frequency']}")
            print(f"  Level: {details['level']} (higher means more complex)")
            print(f"  Components: {details['num_components']} actions")
            print(f"  Preview: {details['components_preview']}")
        return analysis

    except Exception as err:
        import traceback
        print(f"Error during demonstration: {str(err)}")
        print(traceback.format_exc())
        return None





def _collect_unique_task_paths(db, start_page, task):
    """Return one single-element set of triplet keys per unique path_id of ``task``."""
    all_paths = []
    seen_path_ids = set()
    print(f"开始提取任务：{task}")

    for task_chain in db.find_task_paths_lazy(start_page_id=start_page, target_task=task.name):
        path_id = task_chain['path_id']
        if path_id in seen_path_ids:
            continue
        chain_key = tuple(
            (t["source_page"], t["element"], t["target_page"])
            for t in task_chain['triplets']
        )
        seen_path_ids.add(path_id)
        # Each path is wrapped in a one-element set; this is the exact input
        # format previously fed to ActionMergeAnalyzer, so it is preserved.
        all_paths.append({chain_key})

    print(f"Found {len(all_paths)} unique paths for task: {task.name}")
    return all_paths


def _evolve_single_action(db, chain):
    """Create and link a high-level action node for ``chain`` unless it already exists."""
    if db.is_action_duplicate(chain):
        print('pass existing action node')
        return

    print("generating new action node ... ")
    print(f"Chain length: {len(chain)}")
    print(chain[-1])
    action_chain, additional_targets = db.get_chain_by_chain_id(chain)
    print(f"Retrieved action_chain length: {len(action_chain)}")

    action_data = generate_action_node(action_chain, additional_targets)
    if action_data is None:
        # generate_action_node returns None when the LLM output is unusable;
        # skip this chain instead of handing None to the database layer.
        print("Failed to generate action node content; skipping this chain")
        return

    action_data = create_action_node_in_db(action_data, chain, db)
    if action_data is None:
        # NOTE(review): defensive guard; the return contract of
        # create_action_node_in_db is not visible from this file.
        print("Failed to create action node in database; skipping this chain")
        return

    relations_success = create_action_element_relations(action_data, db)
    if not relations_success:
        print("Some element relations creation failed")

    print(
        f"Successfully completed chain evolution, created high-level action node: {action_data['name']} (ID: {action_data['action_id']})"
    )


def Action_envolving(root_path, database, index):
    """Merge recurring task paths into high-level action nodes.

    For every task folder under ``root_path`` the unique paths starting at
    the graph's unique start page are collected. They are then run through
    ``demonstrate_BPE``. For every reported high-level action that is not
    already in the database, an action node is generated with the LLM and
    linked to its elements.

    Args:
        root_path: Directory containing the task folders (passed to
            ``find_all_task_folders``).
        database: Neo4j database name.
        index: Index name forwarded to ``TrajectoryToNeo4jImporter``.
    """
    db = TrajectoryToNeo4jImporter(
        uri=config.Neo4j_URI,
        auth=config.Neo4j_AUTH,
        database=database,
        index=index
    )

    try:
        tasks = find_all_task_folders(root_path)
        start_page = db.find_unique_start_page_id()
        print(start_page)

        task_chains = [_collect_unique_task_paths(db, start_page, task) for task in tasks]
        if not task_chains:
            # np.concatenate raises ValueError on an empty sequence.
            print("No tasks found; nothing to merge.")
            return
        task_chains = np.concatenate(task_chains, axis=0)

        results = demonstrate_BPE(task_chains)
        print(results)
        if results is None:
            # demonstrate_BPE returns None when the analysis raised.
            print("BPE analysis failed; skipping action generation.")
            return

        for result in results['high_level_actions'].values():
            _evolve_single_action(db, result['components_preview'])
    finally:
        # Always release the connection, even if any step above raises.
        db.close()


def get_common_descriptions(chain: List[Dict[str, Any]]) -> List[str]:
    """Return the task descriptions shared by the first and last page of a chain.

    The first page is ``chain[0]["source_page"]`` and the last page is
    ``chain[-1]["target_page"]``. Each page's ``other_info`` field may be a
    list or a JSON-encoded string of a list. Every list item shaped like
    ``{"task_info": {"description": ...}}`` contributes its truthy description.

    Args:
        chain: Task chain as a list of triplet dictionaries.

    Returns:
        Descriptions present on both pages, sorted for a deterministic order.
        The result is empty if the chain is empty, either page is missing or
        falsy, or ``other_info`` cannot be parsed.
    """
    if not chain:
        return []

    first_page = chain[0].get("source_page")
    last_page = chain[-1].get("target_page")

    # Both endpoints must be present to compare anything.
    if not first_page or not last_page:
        return []

    def extract_descriptions(page) -> set:
        """Collect the truthy ``task_info.description`` values stored on ``page``."""
        descriptions = set()
        if "other_info" not in page:
            return descriptions

        other_info = page["other_info"]
        if isinstance(other_info, str):
            try:
                other_info = json.loads(other_info)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return set()

        if not isinstance(other_info, list):
            return descriptions

        for item in other_info:
            if not isinstance(item, dict):
                continue
            task_info = item.get("task_info")
            # Skip entries whose task_info is null or not an object.
            if not isinstance(task_info, dict):
                continue
            desc = task_info.get("description")
            if desc:
                descriptions.add(desc)
        return descriptions

    common = extract_descriptions(first_page) & extract_descriptions(last_page)
    # key=str keeps sorting safe even if descriptions are not all strings.
    return sorted(common, key=str)



def generate_action_node(chain, additional_targets) -> Optional[Dict[str, Any]]:
    """Ask the LLM generation chain to produce a high-level action node.

    Args:
        chain: Triplet chain describing the low-level steps.
        additional_targets: Extra target data forwarded, together with
            ``chain``, to ``format_chain_operations``.

    Returns:
        The dictionary returned by the generation chain when it is a dict
        containing an ``"action_id"`` key; otherwise None. Exceptions raised
        while invoking or validating the chain are printed and yield None.
        Exceptions raised while preparing the inputs propagate.
    """
    llm_chain = create_action_generation_chain()

    # The common-description lookup (get_common_descriptions) is disabled;
    # a fixed placeholder is sent instead. Dict values are evaluated in order.
    payload = {
        "task_description": "unknown_task",
        "chain_operations": format_chain_operations(chain, additional_targets),
        "element_details": extract_element_details(chain),
        "reasoning_results": extract_reasoning_results(chain),
    }

    try:
        # The chain yields a plain dict rather than a Pydantic object.
        output = llm_chain.invoke(payload)
        if not (isinstance(output, dict) and "action_id" in output):
            print(
                f"Warning: The format of the generation result returned by LLM is incorrect: {output}"
            )
            return None
        return output
    except Exception as exc:
        print(f"Error generating high-level action node: {str(exc)}")
        return None

def main(root_path="D:/Project/LLM_project/android_world/android_world-main/exploration_output/com.flauschcode.broccoli"):
    """Build the knowledge graph for one exploration output using ``config`` settings.

    Steps, in order:
      1. DFS-import every task folder.
      2. Run chain understanding per task.
      3. Merge recurring chains into high-level actions.
      4. Summarise nodes with the LLM.
      5. Optimise paths with action groups.

    Args:
        root_path: Directory containing the exploration task folders. It
            defaults to the previously hard-coded Broccoli output path.
    """
    importer = TrajectoryToNeo4jImporter(
        uri=config.Neo4j_URI,
        auth=config.Neo4j_AUTH,
        database=config.Neo4j_DATABASE,
        index=config.PINECONE_INDEX
    )

    try:
        # Materialise once: the task list is iterated twice below, which would
        # silently do nothing the second time if a generator were returned.
        tasks = list(find_all_task_folders(root_path))

        print("*****************************************************************************************")
        print("Start building Knowledge Graph .... ")
        for task in tasks:
            print(f"开始遍历任务：{task}")
            importer.dfs_traverse_and_import(task)
        print("Building Successful ! ")

        print("*****************************************************************************************")
        print("Start chain understanding .... ")
        for task in tasks:
            print(f"开始遍历任务：{task}")
            # chain_understand is a coroutine; run it to completion per task.
            # NOTE(review): KG_Construction passes a Neo4jDatabase as a second
            # argument here; confirm whether chain_understand requires it.
            result = asyncio.run(importer.chain_understand(task.name))
            print(f"任务 {task.name} 处理结果: {len(result) if result else 0} 个chains")
        print("Understanding Successful ! ")

        print("*****************************************************************************************")
        print("Start merging high-level-action .... ")
        # Action_envolving requires the database and index explicitly; the
        # previous single-argument call always raised TypeError.
        Action_envolving(root_path, config.Neo4j_DATABASE, config.PINECONE_INDEX)
        print("merging Successful ! ")

        summarizer = LLMFunctionSummarizer(
            config.Neo4j_URI, config.Neo4j_AUTH, config.Neo4j_DATABASE, config.LLM_API_KEY, True
        )
        try:
            summarizer.process_all_nodes()
        finally:
            summarizer.close()

        importer.optimize_paths_with_action_groups(root_path)

    finally:
        importer.close()


def KG_Construction(database, index, package, base_path):
    """Build the knowledge graph for ``os.path.join(base_path, package)``.

    Steps, in order:
      1. DFS-import every task folder.
      2. Run chain understanding per task, using a dedicated Neo4jDatabase
         handle.
      3. Merge recurring chains into high-level actions.
      4. Summarise nodes with the LLM.
      5. Optimise paths with action groups.

    All opened connections are released even if a step raises.

    Args:
        database: Neo4j database name.
        index: Index name forwarded to the importer and database handle.
        package: App package directory name under ``base_path``.
        base_path: Directory containing per-package exploration outputs.
    """
    importer = TrajectoryToNeo4jImporter(
        uri=config.Neo4j_URI,
        auth=config.Neo4j_AUTH,
        database=database,
        index=index
    )
    db = None

    try:
        # Created inside the try so the importer is still closed if this fails.
        db = Neo4jDatabase(uri=config.Neo4j_URI, auth=config.Neo4j_AUTH, database=database, index=index)

        root_path = os.path.join(base_path, package)
        # Materialise once: the task list is iterated twice below.
        tasks = list(find_all_task_folders(root_path))

        print("*****************************************************************************************")
        print("Start building Knowledge Graph .... ")
        for task in tasks:
            print(f"开始遍历任务：{task}")
            importer.dfs_traverse_and_import(task)
        print("Building Successful ! ")

        print("*****************************************************************************************")
        print("Start chain understanding .... ")
        for task in tasks:
            print(f"开始遍历任务：{task}")
            # chain_understand is a coroutine; run it to completion per task.
            result = asyncio.run(importer.chain_understand(task.name, db))
            print(f"任务 {task.name} 处理结果: {len(result) if result else 0} 个chains")
        print("Understanding Successful ! ")

        print("*****************************************************************************************")
        print("Start merging high-level-action .... ")
        Action_envolving(root_path, database, index)
        print("merging Successful ! ")

        summarizer = LLMFunctionSummarizer(
            config.Neo4j_URI, config.Neo4j_AUTH, database, config.LLM_API_KEY, False
        )
        try:
            summarizer.process_all_nodes()
        finally:
            summarizer.close()

        importer.optimize_paths_with_action_groups(root_path)

    finally:
        try:
            # The Neo4jDatabase API is not visible from this file; release it
            # only if it exposes a close() method (getattr(None, ...) -> None).
            close_db = getattr(db, "close", None)
            if callable(close_db):
                close_db()
        finally:
            importer.close()

if __name__ == "__main__":
    # Script entry point: run the default pipeline only when executed
    # directly, never on import.
    main()
