import torch, json
from core.settings import get_settings
from transformers import AutoModel, AutoTokenizer
import os
# Turn off the Hugging Face `tokenizers` library's internal parallelism
# (presumably to avoid its fork-related warnings — confirm intent).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

settings = get_settings()

# Tag embedded in saved embedding file names: "pretrained" when the configured
# model path refers to the ChemBOMAS base model, "original" otherwise.
model_name = "pretrained" if "ChemBOMAS_Basemodel" in settings.model_path else "original"

def text_to_embedding(text: str, model, tokenizer, max_length: int = 3000) -> torch.Tensor:
    """
    Convert input text into a single embedding vector via masked mean pooling.

    The text is tokenized (truncated to ``max_length`` tokens) and run through
    ``model`` without gradient tracking. The last hidden states of all
    non-padding tokens are then averaged.

    Args:
        text: Input string to embed.
        model: Encoder whose output exposes ``last_hidden_state`` of shape
            [batch_size, seq_len, hidden_size] (e.g. a transformers AutoModel).
            If it has a ``device`` attribute, the inputs are moved there.
        tokenizer: Tokenizer callable returning tensors including
            ``attention_mask`` when called with ``return_tensors="pt"``.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        1-D tensor of shape [hidden_size]. The size depends on the model; the
        original note said 4096 for the model used here — TODO confirm.
    """
    inputs = tokenizer(text, max_length=max_length, padding='longest', truncation=True, return_tensors="pt")

    # Keep inputs on the same device as the model; otherwise a GPU-resident
    # model would fail on CPU input tensors.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

    with torch.no_grad():
        outputs = model(**inputs)

    # Last hidden layer: [batch_size, seq_len, hidden_size]
    last_hidden_states = outputs.last_hidden_state

    # Zero out padding positions; cast the mask to the hidden-state dtype so the
    # multiplication and token count stay in floating point.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_states.dtype)
    sum_embeddings = torch.sum(last_hidden_states * mask, dim=1)

    # Clamp so an input with zero real tokens yields a zero vector instead of NaN.
    num_tokens = torch.sum(mask, dim=1).clamp(min=1)
    sentence_embedding = sum_embeddings / num_tokens

    return sentence_embedding.squeeze(0)

# Cache for the default (model, tokenizer) pair so it is loaded at most once.
_default_model_and_tokenizer = None


def _load_default_model_and_tokenizer():
    """Load (once) and return the model and tokenizer at ``settings.model_path``."""
    global _default_model_and_tokenizer
    if _default_model_and_tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(settings.model_path, local_files_only=True)
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModel.from_pretrained(settings.model_path, local_files_only=True)
        model.eval()
        _default_model_and_tokenizer = (model, tokenizer)
    return _default_model_and_tokenizer


def get_embedding(file_path, save_path=None, model=None, tokenizer=None) -> dict:
    """
    Embed every record of a JSON file and optionally save the mapping.

    ``file_path`` must contain a JSON list of objects, each with a ``"name"``
    key. Each record is rendered as one ``"key: value"`` line per key
    (including ``name``), and that text is embedded with ``text_to_embedding``.

    Args:
        file_path: Path to the JSON input file (read as UTF-8).
        save_path: Optional output directory. When given, it is created if
            needed and the mapping is saved there with ``torch.save`` as
            ``{model_name}_dry_{file_stem}_embedding.pt``.
        model: Encoder to use; if omitted, the model at
            ``settings.model_path`` is loaded (once per process).
        tokenizer: Tokenizer to use; if omitted, the tokenizer at
            ``settings.model_path`` is loaded (once per process).

    Returns:
        Dict mapping each record's ``name`` to its embedding tensor. A record
        whose name repeats an earlier one overwrites it.

    Raises:
        KeyError: If a record has no ``"name"`` key.
    """
    # Base name without its final extension, so "a.b.json" -> "a.b".
    file_name = os.path.splitext(os.path.basename(file_path))[0]

    with open(file_path, "r", encoding="utf-8") as f:
        json_data = json.load(f)

    if model is None or tokenizer is None:
        default_model, default_tokenizer = _load_default_model_and_tokenizer()
        if model is None:
            model = default_model
        if tokenizer is None:
            tokenizer = default_tokenizer

    data_maps = {}
    for data in json_data:
        name = data["name"]
        att_string = "".join(f"{attr}: {value}\n" for attr, value in data.items())
        data_maps[name] = text_to_embedding(att_string, model, tokenizer)

    if save_path:
        os.makedirs(save_path, exist_ok=True)
        output_file = os.path.join(save_path, f"{model_name}_dry_{file_name}_embedding.pt")
        torch.save(data_maps, output_file)

    return data_maps

if __name__ == "__main__":
    import argparse

    # Paths are optional CLI arguments; the defaults preserve the original
    # hard-coded locations, so running without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Embed each record of a JSON file and save the name->embedding map."
    )
    parser.add_argument(
        "file_path",
        nargs="?",
        default="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_emcoder/suzuki_base.json",
        help="JSON file containing a list of records, each with a 'name' key.",
    )
    parser.add_argument(
        "save_path",
        nargs="?",
        default="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/info_emcoder/saved_data_maps",
        help="Directory where the embedding .pt file is written.",
    )
    args = parser.parse_args()
    data_maps = get_embedding(args.file_path, args.save_path)
