import json
import numpy as np
import hdbscan
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import argparse

def _cluster_questions(questions: list) -> tuple:
    """Embed ``questions`` and cluster them with HDBSCAN; return ``(embeddings, labels)``."""
    print("Loading sentence transformer model...")
    # Alternatives tried previously: 'all-MiniLM-L6-v2', "BAAI/bge-small-en".
    model = SentenceTransformer("BAAI/bge-base-en-v1.5")

    print("Encoding questions into vectors... (This may take a moment)")
    # normalize_embeddings=True puts every vector on the unit sphere, so the
    # Euclidean distance HDBSCAN uses below is monotonic in cosine distance
    # no matter which embedding model is plugged in above.
    embeddings = model.encode(questions, show_progress_bar=True, normalize_embeddings=True)

    print("Clustering questions with HDBSCAN...")
    # HDBSCAN chooses the number of clusters itself. min_cluster_size=2 means
    # two mutually similar questions are enough to form a cluster.
    # (gen_min_span_tree is left at its default: the MST is never inspected.)
    clusterer = hdbscan.HDBSCAN(min_cluster_size=2, metric='euclidean')
    clusterer.fit(embeddings)
    return embeddings, np.asarray(clusterer.labels_)


def _select_representatives(labels: np.ndarray, embeddings: np.ndarray) -> list:
    """Return the sorted record indices to keep.

    Every point with a negative label (-1 = HDBSCAN noise, i.e. a question
    with no near-duplicate) is kept. From each cluster (label >= 0) only the
    member whose embedding is most cosine-similar to the cluster centroid is
    kept; ``np.argmax`` picks the lowest index on ties.
    """
    keep = [int(i) for i in np.flatnonzero(labels < 0)]

    cluster_ids = np.unique(labels[labels >= 0])
    for cluster_id in tqdm(cluster_ids, desc="Selecting Representatives"):
        member_indices = np.flatnonzero(labels == cluster_id)
        member_embeddings = embeddings[member_indices]
        centroid = member_embeddings.mean(axis=0, keepdims=True)
        similarities = cosine_similarity(member_embeddings, centroid).ravel()
        keep.append(int(member_indices[np.argmax(similarities)]))

    return sorted(keep)


def refine_knowledge_base(input_path: str, output_path: str, similarity_threshold: float = 0.95) -> None:
    """
    Refine a knowledge base by removing semantically duplicate questions.

    Each record's ``question`` field is embedded with a sentence-transformer
    model and the embeddings are clustered with HDBSCAN. Records HDBSCAN marks
    as noise are kept as-is; from each cluster only the record closest to the
    cluster centroid (by cosine similarity) survives. Surviving records are
    written to ``output_path`` in their original order.

    Args:
        input_path: Path to a UTF-8 JSON file whose top-level value is a list
            of record objects.
        output_path: Destination for the refined JSON list (UTF-8, indent=2,
            non-ASCII characters written unescaped).
        similarity_threshold: Currently NOT used. Deduplication is decided
            solely by HDBSCAN cluster membership; the parameter is retained so
            existing callers keep working.

    Raises:
        TypeError: If the top-level JSON value in ``input_path`` is not a list.
    """
    print("Loading knowledge base...")
    with open(input_path, 'r', encoding='utf-8') as f:
        kb_data = json.load(f)

    if not isinstance(kb_data, list):
        raise TypeError(
            f"Expected a JSON list of records in {input_path}, got {type(kb_data).__name__}"
        )

    if len(kb_data) < 2:
        # HDBSCAN with min_cluster_size=2 cannot cluster fewer than two points,
        # and there is nothing to deduplicate anyway: keep everything.
        print("Fewer than two records; nothing to deduplicate.")
        refined_kb = list(kb_data)
    else:
        # A missing or null question is embedded as ''.
        # NOTE(review): all such records embed identically and will collapse
        # into one cluster, keeping only one of them - confirm that is intended.
        questions = []
        for record in kb_data:
            question = record.get('question')
            questions.append('' if question is None else str(question))

        embeddings, labels = _cluster_questions(questions)

        print(f"Found {int(labels.max()) + 1} clusters. Selecting representatives...")
        keep_indices = _select_representatives(labels, embeddings)
        refined_kb = [kb_data[i] for i in keep_indices]

    print("\nRefinement complete.")
    print(f"Original records: {len(kb_data)}")
    print(f"Refined records: {len(refined_kb)}")

    print(f"Saving refined knowledge base to: {output_path}")
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(refined_kb, f, indent=2, ensure_ascii=False)

if __name__ == "__main__":
    # Command-line entry point. Both paths default to fixed locations on the
    # original author's machine; pass --input_file / --output_file to override.
    cli = argparse.ArgumentParser(
        description="Refine a knowledge base by removing semantically duplicate questions via clustering."
    )
    cli.add_argument(
        "--input_file",
        type=str,
        default='/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/agent_kb_database.json',
        help="Path to the input JSON knowledge base file (e.g., kb.json).",
    )
    cli.add_argument(
        "--output_file",
        type=str,
        default='/home/ofo/project_workflow_auto_generation/smolagents/examples/open_deep_research/data/agent_kb_database_refined_v2.json',
        help="Path to save the refined JSON file.",
    )

    opts = cli.parse_args()
    # similarity_threshold is not exposed on the CLI; the function's default applies.
    refine_knowledge_base(opts.input_file, opts.output_file)