from sentence_transformers import SentenceTransformer, util
import os 
import json
from tqdm import tqdm
import numpy as np
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
# Expose only GPU 0 to this process. This is set before the model below is
# created so that its device selection only sees that GPU (the variable only
# takes effect if CUDA has not been initialized yet in this process).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Load the pretrained Sentence-Transformers model. extract_and_filter_sentences
# uses this global directly; the __main__ block passes it to get_emb.
model = SentenceTransformer('stsb-roberta-large')


def extract_and_filter_sentences(file_path, threshold=0.5, encoder=None):
    """Greedily keep only mutually dissimilar ``answer`` texts from a JSONL file.

    Every non-blank line of *file_path* is parsed as JSON and its ``'answer'``
    value is embedded. Sentences are visited in file order; a sentence joins
    the kept "library" only if its cosine similarity to every sentence already
    in the library is strictly below *threshold*. The first sentence therefore
    always seeds the library.

    Args:
        file_path: Path to a UTF-8 JSON Lines file whose records have an
            ``'answer'`` key. Blank lines are ignored.
        threshold: Similarity cut-off; a sentence whose highest similarity to
            the library is not below *threshold* is discarded.
        encoder: Object whose ``encode(text, convert_to_tensor=True)`` returns
            an embedding tensor. Defaults to the module-level ``model``.

    Returns:
        ``(ids, library)``: positions (in the parsed answer list) of the kept
        sentences and the kept sentences themselves, in file order. Both lists
        are empty when the file contains no records.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``'answer'`` key.
    """
    if encoder is None:
        encoder = model

    # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        sentences = [json.loads(line)['answer'] for line in file if line.strip()]
    print('len(sentences):', len(sentences))

    ids = []
    library = []
    library_embeddings = []
    for i, sentence in enumerate(tqdm(sentences)):
        embedding = encoder.encode(sentence, convert_to_tensor=True)
        if library_embeddings:
            # One batched comparison against the whole library instead of a
            # Python-level loop with a device sync (.item()) per library entry.
            similarities = util.pytorch_cos_sim(embedding, torch.stack(library_embeddings))
            is_novel = similarities.max().item() < threshold
        else:
            # Empty library: the first sentence is always kept.
            is_novel = True

        if is_novel:
            library.append(sentence)
            library_embeddings.append(embedding)
            ids.append(i)

    return ids, library


def get_emb(model, file_path, outpath):
    """Embed the ``'intro'`` field of every record in a JSONL file and save it.

    Args:
        model: Encoder whose ``encode(text, convert_to_tensor=True)`` returns a
            tensor (the ``__main__`` block passes the module-level
            ``SentenceTransformer``).
        file_path: Path to a UTF-8 JSON Lines file whose records have an
            ``'intro'`` key. Blank lines are ignored.
        outpath: Destination passed to ``np.savez``; the embeddings are stored
            under the key ``'emb'``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``'intro'`` key.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        sentences = [json.loads(line)['intro'] for line in file if line.strip()]
    print('len(drugs):', len(sentences))

    # Encode one record at a time, moving each result to host memory as NumPy.
    embeddings = [
        model.encode(sentence, convert_to_tensor=True).detach().cpu().numpy()
        for sentence in tqdm(sentences)
    ]

    # Presumably (num_records, embedding_dim); e.g. (1710, 1024) in a past run.
    embeddings = np.array(embeddings)
    np.savez(outpath, emb=embeddings)
    print(embeddings.shape, 'save done')


def get_name_emb(model, file_path, outpath):
    """Embed the ``'name'`` of each entry of an index-keyed JSON object and save it.

    The JSON file must hold an object keyed by the strings ``"0"`` through
    ``str(len(data) - 1)``. Entries are embedded in that numeric order, not the
    file's key order, so row ``i`` of the saved array corresponds to key
    ``str(i)``.

    Args:
        model: Encoder whose ``encode(text, convert_to_tensor=True)`` returns a
            tensor.
        file_path: Path to a UTF-8 JSON file as described above.
        outpath: Destination passed to ``np.savez``; the embeddings are stored
            under the key ``'emb'``.

    Raises:
        KeyError: if some index in ``range(len(data))`` is not a key, or an
            entry lacks ``'name'``.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Walk keys "0", "1", ... explicitly so output rows follow index order.
    sentences = [data[str(i)]['name'] for i in range(len(data))]

    embeddings = np.array([
        model.encode(sentence, convert_to_tensor=True).detach().cpu().numpy()
        for sentence in tqdm(sentences)
    ])
    np.savez(outpath, emb=embeddings)
    print(embeddings.shape, 'save done')  # e.g. DB (109, 1024), TS (223, 1024)
    


def get_emb_8B(model, tokenizer, prompt):
    """Embed *prompt* by blending pooled early and final hidden states of an LM.

    The prompt is tokenized, run through *model* with
    ``output_hidden_states=True``, and two vectors are formed by averaging over
    dim 1. One comes from ``hidden_states[0]`` and one from
    ``hidden_states[-1]``. ``hidden_states[0]`` is, for Hugging Face models,
    the embedding output, i.e. before the first decoder layer; verify this for
    other backends. The result is the element-wise mean of the two pooled
    vectors.

    Args:
        model: Model exposing ``.device`` and returning an object with a
            ``hidden_states`` sequence when called with
            ``output_hidden_states=True`` (e.g. ``LlamaForCausalLM``).
        tokenizer: Callable that, given ``return_tensors="pt"``, returns a
            mapping of name -> tensor (e.g. ``PreTrainedTokenizerFast``).
        prompt: Text to embed.

    Returns:
        Tensor on ``model.device``. It is presumably ``(batch, hidden_size)``
        (``(1, hidden_size)`` for a single string), assuming the usual
        ``(batch, seq, hidden)`` layout; confirm for your model.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    target = model.device
    batch = {name: tensor.to(target) for name, tensor in encoded.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        hidden = model(**batch, output_hidden_states=True).hidden_states

    # Pool each of the two chosen layers over dim 1, then weight them equally.
    pooled_early, pooled_final = (hidden[idx].mean(dim=1) for idx in (0, -1))
    return (pooled_early + pooled_final) / 2

if __name__ == '__main__':
    import sys

    # Resolve paths relative to this script instead of hard-coded absolute paths.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(base_dir, 'output')
    os.makedirs(output_dir, exist_ok=True)

    # Input is the JSONL file read by get_emb (records need an 'intro' field).
    # NOTE(review): the input is read from output_dir while the embeddings are
    # written to base_dir; confirm that split is intended.
    desfile = os.path.join(output_dir, 'DB-description.jsonl')
    outfile = os.path.join(base_dir, 'DB_drug_emb.npz')
    get_emb(model, desfile, outfile)
    # sys.exit() rather than the site-module exit(), which is intended for the
    # interactive shell and is missing when Python runs without `site` (-S).
    sys.exit()
