import chromadb
import nltk
from tqdm import tqdm
from nltk.corpus import wordnet as wn
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import torch
from sentence_transformers import SentenceTransformer, util

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Load the pretrained sentence-embedding model, then move it onto the chosen device.
PLM_name = "BAAI/bge-m3"  # alternative: 'all-MiniLM-L6-v2'
print(f"Loading model: {PLM_name}")
model = SentenceTransformer(
    PLM_name,
    cache_folder='../pypro/huggingface_models/hub',
)
model = model.to(device)
print(f"Model loaded and moved to {device}.")

# Embedding helper.
def encode_text(text):
    """Embed ``text`` with the module-level ``model`` and return a NumPy array.

    ``text`` is forwarded unchanged to ``model.encode`` together with the
    module-level ``device``; the resulting tensor is copied to the CPU and
    converted with ``.numpy()``.
    """
    # Inference only: gradient tracking is disabled to save memory and compute.
    with torch.no_grad():
        encoded = model.encode(text, convert_to_tensor=True, device=device)
    # Bring the tensor back to host memory before converting it to NumPy.
    on_cpu = encoded.cpu()
    return on_cpu.numpy()

# Name of the collection and the on-disk Chroma client that owns it.
collection_name = "wordnet_cosine"
# Use the documented top-level factory `chromadb.PersistentClient` instead of
# reaching it through the undocumented `chromadb.chromadb` attribute.
chroma_client = chromadb.PersistentClient(path="../chroma_data")

# Drop any existing collection with the same name so the index is rebuilt from
# scratch. Depending on the installed Chroma release, list_collections() yields
# either Collection objects or plain name strings, so both shapes are accepted.
# NOTE(review): confirm which shape the pinned Chroma version returns.
existing_names = {
    entry if isinstance(entry, str) else entry.name
    for entry in chroma_client.list_collections()
}
if collection_name in existing_names:
    chroma_client.delete_collection(collection_name)
    print(f"Collection '{collection_name}' has been deleted.")

# Create a fresh collection whose HNSW index uses the cosine space.
collection = chroma_client.create_collection(
    name=collection_name,
    metadata={"hnsw:space": "cosine"},
)
print(f"New collection '{collection_name}' has been created.")

# Set intended to hold every word so duplicates can be avoided.
# NOTE(review): not referenced anywhere else in the visible part of this file.
words = set()

# Sentence template per WordNet part-of-speech tag; `{}` is the word slot.
# NOTE(review): not referenced anywhere else in the visible part of this file.
pos_templates = dict(
    n='It is a {}.',   # noun template
    v='It can {}.',    # verb template
    a='It is {}.',     # adjective template
)

# Buffers holding the current batch of records waiting to be written.
documents = []
embeddings = []
metadatas = []
ids = []
# Upper bound on the number of records sent to the collection per add() call.
MAX_BATCH_SIZE = 41666
# Lemma names already seen, so each word is considered only once.
word_set = set()
# Running total of records written. The batch buffers are cleared after each
# intermediate flush, so their length alone cannot give the final count.
stored_count = 0


def _add_buffered_batch():
    """Write the buffered records to the collection and return how many were sent."""
    collection.add(
        documents=documents,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids,
    )
    return len(documents)


# Walk every WordNet synset. The generator is materialised into a list first,
# presumably so that tqdm can display a total.
for synset in tqdm(list(wn.all_synsets())):
    # The part of speech belongs to the synset, so filter once per synset
    # instead of once per lemma. Only the 'n', 'v' and 'a' tags are kept.
    pos = synset.pos()
    if pos not in ('n', 'v', 'a'):
        continue
    definition = synset.definition()

    for lemma in synset.lemmas():
        word = lemma.name()
        # Each word is handled at most once, under the first synset it appears in.
        if word in word_set:
            continue
        word_set.add(word)

        # Skip words containing a hyphen, an underscore or an apostrophe.
        if '-' in word or '_' in word or "'" in word:
            continue

        sentence = f"It is a {word}. Definition: {definition}"

        # One encode call per sentence.
        # NOTE(review): batching the encode calls may be considerably faster.
        embedding = encode_text(sentence)

        documents.append(sentence)
        embeddings.append(embedding.tolist())  # plain list for storage
        metadatas.append({"word": word, "pos": pos, "definition": definition})
        ids.append(f"{word}_{pos}_{synset.offset()}")  # unique record id

        if len(documents) >= MAX_BATCH_SIZE:
            stored_count += _add_buffered_batch()
            # Reset the buffers for the next batch.
            documents.clear()
            embeddings.clear()
            metadatas.clear()
            ids.clear()

# Write whatever is left over after the last full batch.
if documents:
    stored_count += _add_buffered_batch()

print(f"Stored {stored_count} items into the collection '{collection_name}' in ChromaDB.")

# Sample query: embed the bare word 'cat', query the collection with that
# embedding, and print the raw result.
cat_vector = encode_text('cat').tolist()
query_result = collection.query(query_embeddings=cat_vector)
print(query_result)

pass  # no-op



