import functools

import torch
from transformers import BertTokenizer, BertModel


_DEFAULT_BERT_PATH = '/nfs196/hjc/pretrained_models/bert/'


@functools.lru_cache(maxsize=4)
def _load_bert(model_path):
    """Load and memoise the BERT tokenizer and encoder stored at *model_path*.

    Reading the weights from disk is expensive. Caching keyed on the path
    means repeated ``get_text_cond`` calls reuse the already-loaded objects
    instead of reloading them on every call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertModel.from_pretrained(model_path)
    return tokenizer, model


def _group_prompts(param_names, num_groups, group_size):
    """Join the first ``num_groups * group_size`` names into ' and '-separated prompts.

    Consecutive runs of ``group_size`` names form one prompt each. Any names
    beyond the first ``num_groups * group_size`` are ignored.

    Raises:
        ValueError: if ``num_groups`` or ``group_size`` is smaller than 1.
        IndexError: if fewer than ``num_groups * group_size`` names are given.
    """
    if num_groups < 1 or group_size < 1:
        raise ValueError(
            f'num_groups and group_size must be >= 1, got {num_groups} and {group_size}'
        )
    needed = num_groups * group_size
    if len(param_names) < needed:
        raise IndexError(
            f'expected at least {needed} parameter names '
            f'({num_groups} groups of {group_size}), got {len(param_names)}'
        )
    return [
        ' and '.join(param_names[start:start + group_size])
        for start in range(0, needed, group_size)
    ]


def get_text_cond(param_names, num_groups=4, group_size=4, *,
                  model_path=_DEFAULT_BERT_PATH, verbose=True):
    """Encode groups of parameter names into BERT [CLS] sentence embeddings.

    The first ``num_groups * group_size`` names are split into ``num_groups``
    consecutive groups. Each group is joined with ``' and '`` into one prompt.
    The prompts are encoded by the BERT checkpoint at ``model_path``. The
    defaults (4 groups of 4) reproduce the original fixed 16-name behaviour.

    Args:
        param_names: Sequence of strings. Only the first
            ``num_groups * group_size`` entries are used.
        num_groups: Number of prompts to build, one embedding row per prompt.
        group_size: Number of names joined into each prompt.
        model_path: Directory holding the pretrained BERT tokenizer and model.
        verbose: When True, print the inputs, prompts and output shape, as
            the original implementation always did.

    Returns:
        A tensor holding the final-layer hidden state of the first ([CLS])
        token for each prompt. Per the transformers ``last_hidden_state``
        layout this should be ``(num_groups, hidden_size)``; confirm against
        the checkpoint in use.

    Raises:
        ValueError: if ``num_groups`` or ``group_size`` is smaller than 1.
        IndexError: if ``param_names`` has fewer than
            ``num_groups * group_size`` entries.
    """
    if verbose:
        print(param_names)
    # Validate and build prompts before touching the (expensive) model.
    prompts = _group_prompts(param_names, num_groups, group_size)
    if verbose:
        print(prompts)

    tokenizer, model = _load_bert(model_path)

    # Batch-encode all prompts, padding to the longest and truncating at 512 tokens.
    inputs = tokenizer(prompts, return_tensors='pt', max_length=512,
                       truncation=True, padding=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Use the [CLS] token (position 0) of the last hidden layer as the sentence
    # representation. A mean over dim=1 would be an alternative pooling.
    sentence_embeddings = outputs.last_hidden_state[:, 0, :]

    if verbose:
        print('sentence_embeddings.shape', sentence_embeddings.shape)
    return sentence_embeddings
