import json
import torch
from sentence_transformers import SentenceTransformer

# Prefer the GPU, but fall back to CPU so the script can still run on machines
# without CUDA instead of failing at model load time.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedding_name = "facebook/contriever"
domain = "test"
# Input corpus and output embeddings file; both paths are relative to the
# current working directory.
corpus_path = f"./{domain}_merge.json"
embeddings_path = f"./{domain}_embeddings_retriever.json"

# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to execute; keep it only if the chosen model actually needs it.
embedder = SentenceTransformer(embedding_name, trust_remote_code=True, device=device).eval()


def encode_texts(texts, *, batch_size=1):
    """Embed ``texts`` with the module-level ``embedder``.

    Every item is coerced with ``str()`` first, so non-string entries are
    embedded as their string representation.

    Args:
        texts: Iterable of items to embed.
        batch_size: Number of texts per forward pass. Defaults to 1, the
            value that used to be hard-coded; raise it to process several
            texts per forward pass.

    Returns:
        A torch tensor of L2-normalised embeddings (``normalize_embeddings``
        is on), detached and cloned so it does not share storage with the
        encoder's output. Presumably shaped (len(texts), embedding_dim) and
        located on ``device`` — confirm.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    prepared = [str(t) for t in texts]
    vectors = embedder.encode(
        prepared,
        batch_size=batch_size,
        convert_to_tensor=True,
        device=device,
        normalize_embeddings=True,
    )
    # NOTE(review): detach/clone presumably guards against tensors produced
    # under no-grad/inference mode being reused elsewhere — confirm it is needed.
    return vectors.detach().clone()


# Load the merged corpus. Presumably a JSON list of groups, each an iterable
# of texts — TODO confirm against whatever produces this file.
with open(corpus_path, "r", encoding="utf-8") as corpus_file:
    merged_corpus = json.load(corpus_file)

# Flatten the groups into a single list of texts, preserving their order.
corpus_texts = [text for group in merged_corpus for text in group]
print(len(corpus_texts))

# Embed everything, then move the result to host memory as nested Python
# lists so it can be serialised with json.
embeddings = encode_texts(corpus_texts).cpu().numpy().tolist()
print(len(embeddings))

with open(embeddings_path, "w", encoding="utf-8") as out_file:
    json.dump(embeddings, out_file)
