from json_graph import json_graph
import networkx as nx
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
import requests
import re
from datetime import datetime
import torch
import os
import openai
from dashscope import Generation
import dashscope
import os
from http import HTTPStatus

# Override the base URL used by the `openai` client library.
# (call_with_stream below posts to this host directly via `requests`.)
openai.api_base = "https://api2.aigcbest.top/v1"

# Point Hugging Face downloads at the hf-mirror.com mirror.
# NOTE(review): `transformers` is imported before this assignment runs; if the
# Hugging Face libraries read HF_ENDPOINT at import time this has no effect —
# confirm, and if so move the assignment above the `transformers` import.
os.environ["HF_ENDPOINT"] ="https://hf-mirror.com"
# Load the BERT tokenizer and encoder once, at import time.
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-uncased')
model = BertModel.from_pretrained('google-bert/bert-base-uncased')


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def call_with_stream(messages):
    """Send one user prompt to the chat-completions endpoint and return the reply.

    Despite the name, the request is a plain blocking POST; nothing is
    streamed. The raw JSON response and the extracted reply are both printed.

    Args:
        messages: Prompt text sent as the single user message.

    Returns:
        The ``content`` of the first choice in the response.

    Raises:
        RuntimeError: If the ``OPENAI_API_KEY`` environment variable is unset.
        requests.HTTPError: If the endpoint answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    # Read the credential from the environment so no secret lives in source.
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError(
            "OPENAI_API_KEY environment variable must be set to call the chat API"
        )

    # Build the user message.
    user_content = [{
        "type": "text",
        "text": messages
    }]

    # Build the request parameters; temperature 0 and a fixed seed are sent
    # with every request.
    payload = {
        "model": "gpt-4-1106-preview",
        "messages": [{"role": "user", "content": user_content}],
        "max_tokens": 4096,
        "temperature": 0,
        "seed": 2024,
    }

    # Set the request headers.
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    # Send the POST request. Without a timeout `requests` waits forever on a
    # stalled connection; the read timeout is generous because long
    # completions can take a while.
    response = requests.post(
        "https://api2.aigcbest.top/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=(10, 300),
    )
    # Surface HTTP errors directly instead of failing later on a missing key.
    response.raise_for_status()

    # Parse the body once, then print the raw response and the reply text.
    data = response.json()
    print(data)
    result = data["choices"][0]["message"]["content"]
    print(result)
    return result



def semantic_similarity(prompt, texts):
    """Compute BERT-based cosine similarity between a prompt and each text.

    The prompt and all texts are encoded in one batch; the hidden state at
    position 0 (the [CLS] token) of the last layer is used as the sentence
    vector.

    Args:
        prompt: Reference string.
        texts: List of strings to compare against ``prompt``.

    Returns:
        1-D NumPy array with one cosine similarity per entry of ``texts``.
    """
    # Encode the prompt followed by the candidate texts as one padded batch.
    batch = [prompt] + texts
    encoded = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
    encoded.to(device)

    # Run the encoder without tracking gradients and keep the [CLS] vectors.
    with torch.no_grad():
        model_out = model(**encoded)
        cls_vectors = model_out.last_hidden_state[:, 0, :].cpu().numpy()

    # Row 0 is the prompt; the remaining rows are the texts.
    prompt_vector = cls_vectors[0].reshape(1, -1)
    text_vectors = cls_vectors[1:]
    return cosine_similarity(prompt_vector, text_vectors).flatten()


def llm_generate_keywords(query):
    """Ask the LLM to expand a search requirement into field/problem/method text.

    The query is wrapped in ``****`` markers and appended directly to a fixed
    instruction prompt, which is then sent through ``call_with_stream``.

    Args:
        query: The raw search requirement.

    Returns:
        A single-element list containing the LLM's reply.
    """
    instruction_parts = (
        "Please follow the search requirements below. **Note: If you cannot derive the answers to the following questions from the search requirements, ",
        "please directly return the search requirements without adding any additional information!!!!!!**\n",
        "Consider the following questions:\n",
        "1. What is the research field?\n",
        "2. What is the research problem?\n",
        "3. What are the commonly used methods?\n",
        "Then return the results of your consideration, and concatenate the answers to these three questions.\n",
        "**Note: If you cannot derive the answers to the above questions from the search requirements, please directly return the search requirements without adding any additional information!!!!!!**",
    )
    instructions = "".join(instruction_parts)

    # The wrapped query follows the instructions with no separator in between.
    full_prompt = f"{instructions}****{query}****"

    reply = call_with_stream(full_prompt)
    return [reply]


def llm_match_description(description, keywords, threshold=0.5):
    """Decide whether a description is semantically close to the keywords.

    Despite the ``llm_`` prefix, matching is done with the BERT-based
    ``semantic_similarity`` helper, not with an LLM call.

    Args:
        description: Text to test.
        keywords: List of keyword strings to compare against.
        threshold: Minimum mean cosine similarity required for a match.
            Defaults to 0.5, the value that was previously hard-coded.

    Returns:
        Truthy when the mean similarity between ``description`` and the
        ``keywords`` is at least ``threshold``.
    """
    similarities = semantic_similarity(description, keywords)
    return np.mean(similarities) >= threshold



def filter_by_time(nodes, graph, max_per_timestamp=None):
    """Group nodes by their ``timestamp`` attribute and return them bucket by bucket.

    Buckets are emitted in the order their timestamp was first seen, and nodes
    keep their input order within a bucket.

    Args:
        nodes: Iterable of node identifiers present in ``graph``.
        graph: Object whose ``graph.nodes[node]`` is a mapping of attributes.
        max_per_timestamp: Optional cap on how many nodes to keep for each
            distinct timestamp. ``None`` (the default) keeps every node.

    Returns:
        List of the selected nodes.

    Raises:
        ValueError: If ``max_per_timestamp`` is negative.
    """
    if max_per_timestamp is not None and max_per_timestamp < 0:
        raise ValueError("max_per_timestamp must be non-negative or None")

    # Nodes without a timestamp are grouped together under None instead of
    # raising KeyError, matching how sort_nodes_by_timestamp reads the field.
    buckets = {}
    for node in nodes:
        stamp = graph.nodes[node].get('timestamp')
        buckets.setdefault(stamp, []).append(node)

    # Slicing with None keeps the whole bucket.
    selected = []
    for members in buckets.values():
        selected.extend(members[:max_per_timestamp])
    return selected

def summarize_nodes(nodes, graph):
    """Build one summary line per node from its ``description`` attribute.

    Args:
        nodes: Iterable of node identifiers present in ``graph``.
        graph: Object whose ``graph.nodes[node]`` is a mapping of attributes.

    Returns:
        List of strings shaped ``"<node>: <description>..."``, or
        ``"<node>: No description available."`` when the description is
        missing or empty.
    """
    summaries = []
    for node in nodes:
        # .get() so a node lacking the attribute takes the fallback branch
        # instead of raising KeyError.
        description = graph.nodes[node].get('description')
        if description:
            summaries.append(f"{node}: {description}...")
        else:
            summaries.append(f"{node}: No description available.")
    return summaries

def global_search(graph, query):
    """Find paper nodes whose text matches keywords derived from the query.

    Keywords are generated once with ``llm_generate_keywords``. Every node
    whose ``entity_type`` is ``'paper'`` is then matched via
    ``llm_match_description`` using its description, or the node identifier
    when the description is missing or empty.

    Args:
        graph: Graph exposing ``graph.nodes(data=True)``.
        query: The search requirement text.

    Returns:
        Tuple ``(filtered_nodes, summary_info)``: the matching nodes after
        ``filter_by_time`` and their summaries from ``summarize_nodes``.
    """
    keywords = llm_generate_keywords(query)
    relevant_nodes = []
    for node, data in graph.nodes(data=True):
        # .get() so nodes that carry attributes but no 'entity_type' are
        # skipped instead of raising KeyError.
        if data.get('entity_type') != 'paper':
            continue

        # Fall back to the node identifier when there is no description.
        text = data.get('description') or node
        if llm_match_description(text, keywords):
            relevant_nodes.append(node)

    filtered_nodes = filter_by_time(relevant_nodes, graph)
    summary_info = summarize_nodes(filtered_nodes, graph)

    return filtered_nodes, summary_info

def extract_detailed_info(nodes, graph):
    """Collect, for each node, its neighbours and the relation on each edge.

    Args:
        nodes: Iterable of node identifiers present in ``graph``.
        graph: Graph supporting ``graph.neighbors(node)`` and
            ``graph[node][neighbor]['relation']``.

    Returns:
        Dict mapping each node to a list of
        ``{'neighbor': ..., 'relation': ...}`` dicts, one per neighbour.
    """
    return {
        source: [
            {'neighbor': target, 'relation': graph[source][target]['relation']}
            for target in list(graph.neighbors(source))
        ]
        for source in nodes
    }

def prioritize_and_filter(info, query, top_k=None):
    """Rank nodes by neighbour count, optionally keeping only the top entries.

    Args:
        info: Dict mapping a node to its list of neighbour records.
        query: Currently unused; kept so existing callers keep working.
        top_k: Optional number of highest-ranked entries to keep. ``None``
            (the default) keeps every entry.

    Returns:
        List of ``(node, records)`` tuples sorted by ``len(records)`` in
        descending order. Ties keep the iteration order of ``info``.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k is not None and top_k < 0:
        raise ValueError("top_k must be non-negative or None")

    ranked = sorted(info.items(), key=lambda item: len(item[1]), reverse=True)

    # Slicing with None returns the full ranked list.
    return ranked[:top_k]

def local_search(graph, nodes, query):
    """Expand the seed nodes by one hop and rank the result by neighbour count.

    Args:
        graph: Graph supporting ``graph.neighbors(node)``.
        nodes: Iterable of seed node identifiers.
        query: Search text, forwarded to ``prioritize_and_filter``.

    Returns:
        The output of ``prioritize_and_filter`` for the seeds plus their
        direct neighbours.
    """
    # De-duplicate with an insertion-ordered dict, not a set, so the order
    # handed to the ranking step (and therefore the order of ties) does not
    # depend on hash randomisation. Materialising the seeds first also lets a
    # one-shot iterable be passed as `nodes`.
    seeds = list(dict.fromkeys(nodes))
    extended_nodes = dict.fromkeys(seeds)
    for node in seeds:
        extended_nodes.update(dict.fromkeys(graph.neighbors(node)))

    detailed_info = extract_detailed_info(list(extended_nodes), graph)
    sorted_info = prioritize_and_filter(detailed_info, query)

    return sorted_info

def normalize_timestamp(timestamp):
    """Convert a loosely formatted timestamp string into a ``datetime``.

    Recognised forms, checked against the start of the string:
      * ``YYYY-MM-DD`` (anything after the date is ignored)
      * ``YYYY`` optionally followed by other characters, e.g. ``2021a``

    Args:
        timestamp: The raw timestamp value; may be ``None`` or a non-string.

    Returns:
        The parsed ``datetime``. ``datetime.min`` is returned for non-string
        input, unrecognised text, or values that cannot be parsed. A string
        that looks like a full date but is not a valid one (e.g. month 13)
        falls back to its year.
    """
    if not isinstance(timestamp, str):
        return datetime.min

    # The patterns only anchor the start of the string, so parse just the
    # matched prefix; otherwise trailing text makes strptime raise.
    candidates = []
    if re.match(r"\d{4}-\d{2}-\d{2}", timestamp):
        candidates.append((timestamp[:10], "%Y-%m-%d"))
    if re.match(r"\d{4}", timestamp):
        candidates.append((timestamp[:4], "%Y"))

    for text, fmt in candidates:
        try:
            return datetime.strptime(text, fmt)
        except ValueError:
            # Digits matched but do not form a valid date; try the next form.
            continue
    return datetime.min


def sort_nodes_by_timestamp(graph, nodes):
    """Return the nodes ordered by their normalized ``timestamp`` attribute.

    Args:
        graph: Object whose ``graph.nodes[node]`` is a mapping of attributes.
        nodes: Iterable of node identifiers present in ``graph``.

    Returns:
        New list of nodes in ascending timestamp order.
    """
    def _timestamp_key(node):
        raw = graph.nodes[node].get('timestamp')
        return normalize_timestamp(raw)

    ordered = list(nodes)
    ordered.sort(key=_timestamp_key)
    return ordered




if __name__ == '__main__':
    # Script entry point: build the graph, run the global and local searches
    # for a sample query, and print the results.

    # Build the graph G from the CSV/JSON folders.
    # NOTE(review): both paths are empty placeholders — fill them in before running.
    csv_path_folder = ""
    json_path_folder = ""
    G = json_graph(csv_path_folder, json_path_folder)
    query = "LLMs in Medical Domains,LLM-based Multi-agent Collaboration,In the medical field, external tools can be used to improve the performance and reliability of domain-specific language models. "

    # Perform global search
    filtered_nodes, summary_info = global_search(G, query)
    # This prints the global-search summaries, so label it as such (it was
    # previously mislabelled as local-search output).
    print("\nGlobal Search Summary Info:")
    print(summary_info)

    # Perform local search
    detailed_results = local_search(G, filtered_nodes, query)

    # Generate time-based CoT narrative using LLM
    # NOTE(review): create_cot_narrative_using_llm is neither defined nor
    # imported in the visible part of this file — confirm where it comes
    # from, otherwise this line raises NameError.
    cot_narrative, narrative = create_cot_narrative_using_llm(G, [node for node, _ in detailed_results])
    print("CoT Narrative:")
    print(cot_narrative)

    # Print detailed information
    print("\nLocal Search Detailed Info:")
    for node, details in detailed_results:
        print(f"\nNode: {node}")
        for detail in details:
            print(f"  Neighbor: {detail['neighbor']}")
            print(f"  Relation: {detail['relation']}")