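# Pipeline: get the entities, find their corresponding chunk, then answer the question from (question + chunk). If the chunk is too long, extract a subgraph from it again, guided by the question (question + chunk -> triplets).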

# Dependencies: py2neo for the Neo4j connection, plus json and the OpenAI chat client
import py2neo
from py2neo import Graph
import json
from OpenAI import OpenAI

# Module-level chat client shared by the functions below. The base URL points
# at a server on localhost:8000 (presumably an OpenAI-compatible endpoint —
# confirm), and the API key is a placeholder value.
client = OpenAI(
    api_key="sk-xxx",
    base_url="http://localhost:8000/v1",
)

def connect_neo4j(
    uri: str = "bolt://localhost:7687",
    user: str = "neo4j",
    password: str = "password",
) -> "Graph":
    """Create a py2neo ``Graph`` handle for a Neo4j server.

    The defaults reproduce the previously hard-coded local development
    settings, so existing zero-argument callers are unaffected.

    Args:
        uri: Connection URI of the Neo4j server.
        user: User name used for authentication.
        password: Password used for authentication.

    Returns:
        The ``Graph`` object constructed from the given settings.
    """
    return Graph(uri, auth=(user, password))

def _parse_triplets(raw):
    """Decode the model's reply into a list of candidate triplet objects.

    Args:
        raw: Text returned by the model (may be empty or ``None``).

    Returns:
        A list of decoded items. A single top-level dict is wrapped in a list.
        Items are not validated here.

    Raises:
        ValueError: If no list or dict can be decoded from ``raw``.
    """
    import ast  # function-scope import: only needed for the fallback parser

    text = (raw or "").strip()
    candidates = [text]
    # Replies are often wrapped in a Markdown code fence or surrounded by
    # prose, so also try just the outermost [...] span.
    start, end = text.find("["), text.rfind("]")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        # Strict JSON first. literal_eval is the fallback for Python-style
        # single-quoted dicts; unlike a blanket quote replacement it does not
        # corrupt apostrophes inside values.
        for loader in (json.loads, ast.literal_eval):
            try:
                data = loader(candidate)
            except (ValueError, SyntaxError, TypeError):
                continue
            if isinstance(data, dict):
                data = [data]
            if isinstance(data, list):
                return data
    raise ValueError("no triplet list found in model output")


def subgraph_refine(chunk: str, question: str) -> str:
    """Extract question-guided triplets from ``chunk`` and store them in Neo4j.

    The chat model is asked for a JSON list of ``head``/``relation``/``tail``
    triplets. Each well-formed triplet is merged into Neo4j as two ``Entity``
    nodes joined by a ``RELATION`` relationship whose ``type`` property holds
    the relation text. Parsing and storage problems are printed rather than
    raised.

    Args:
        chunk: Source text to extract triplets from.
        question: Question that guides which triplets are extracted.

    Returns:
        The raw text returned by the model (empty string if it returned no
        content), whether or not it could be parsed or stored.
    """
    graph = connect_neo4j()

    # Only the first segment is an f-string. The JSON example lives in a plain
    # string literal because literal braces inside an f-string are parsed as a
    # replacement field and raise "Invalid format specifier" at runtime. The
    # example uses double quotes so the model is shown valid JSON.
    prompt = (
        f"根据问题和chunk来抽子图triplet，问题:{question}，chunk:{chunk}."
        "输出json格式的triplet，比如 "
        '[{"head": "实体1", "relation": "关系", "tail": "实体2"}]'
    )

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "你是一个能够从文本中提取知识图谱三元组的助手。"},
            {"role": "user", "content": prompt},
        ],
        max_tokens=500,
        temperature=0.2,
    )

    # The openai>=1.0 client returns a typed message object that is not
    # subscriptable; dict-style access is kept only for dict responses.
    message = response.choices[0].message
    content = message["content"] if isinstance(message, dict) else message.content
    triplets = content or ""
    print("提取的triplets:", triplets)

    try:
        triplet_list = _parse_triplets(triplets)
    except ValueError:
        print("无法解析triplets:", triplets)
        return triplets

    required_keys = ("head", "relation", "tail")
    try:
        for triplet in triplet_list:
            # Skip malformed items individually so that one bad triplet does
            # not abort the remaining writes.
            if not isinstance(triplet, dict) or any(
                triplet.get(key) in (None, "") for key in required_keys
            ):
                print("跳过不完整的triplet:", triplet)
                continue
            graph.run(
                """
                MERGE (a:Entity {name: $head})
                MERGE (b:Entity {name: $tail})
                MERGE (a)-[r:RELATION {type: $relation}]->(b)
                """,
                head=triplet["head"],
                tail=triplet["tail"],
                relation=triplet["relation"],
            )
        print("子图已存入Neo4j")
    except Exception as e:
        # Deliberate best-effort boundary: storage failures are reported and
        # the raw model output is still returned to the caller.
        print("存入Neo4j时出错:", e)
    return triplets

# Manual smoke test: runs the extraction on a sample passage about AI.
if __name__ == "__main__":
    question = "什么是人工智能？"
    # Adjacent literals are concatenated into the single sample passage.
    chunk = (
        "人工智能（Artificial Intelligence，简称AI）是计算机科学的一个分支，旨在创造能够执行通常需要人类智能的任务的系统。"
        "AI技术包括机器学习、自然语言处理、计算机视觉等。"
        "近年来，随着大数据和计算能力的提升，AI在医疗、金融、交通等领域得到了广泛应用。"
    )
    subgraph_refine(chunk=chunk, question=question)
    