import sys
import os
# Append the absolute path of the current working directory to the module
# search path. This runs before the project-local imports below
# (presumably so that modules living in the launch directory, such as
# configuration_llapa, can be resolved -- confirm against the repo layout).
sys.path.append(os.path.abspath('./'))
# Print the resulting search path to stdout.
print(sys.path)

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the first GPU
from configuration_llapa import LlapaConfig
import torch
from transformers.models.auto import AutoModel, AutoTokenizer
####################################################################################################################################

# Presumably the marker that introduces the assistant's reply in a prompt;
# it is not referenced anywhere in the visible part of this file.
ANSWER_PREFIX = " \n### Assistant: "

# Prefix prepended to every checkpoint and dataset path built in this file.
# While it is the empty string, those f-string paths start with '/'.
ROOT_DIR = ''

def get_model(args, device='cuda:0'):
    """Load the protein encoder and its tokenizer selected by the LLaPA config.

    The LLaPA configuration is read from ``{ROOT_DIR}/checkpoints/llapa_v1``.
    Its ``protein_config._name_or_path`` entry is then passed to
    ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Args:
        args: Parsed command-line arguments. Not read by this function; the
            parameter is kept so existing ``get_model(args)`` calls keep
            working.
        device: Device the encoder is moved to before it is returned.
            Defaults to ``'cuda:0'``, the value that used to be hard-coded.

    Returns:
        A ``(protein_encoder, protein_tokenizer)`` tuple.
    """
    # ---------------------------------  load model  --------------------------------------
    model_path = f'{ROOT_DIR}/checkpoints/llapa_v1'
    config = LlapaConfig.from_pretrained(model_path)

    # Tokenizer and encoder must come from the same checkpoint identifier.
    protein_source = config.protein_config._name_or_path
    protein_tokenizer = AutoTokenizer.from_pretrained(protein_source)
    protein_encoder = AutoModel.from_pretrained(protein_source)

    protein_encoder.to(device)

    return protein_encoder, protein_tokenizer

if __name__ == "__main__":
    # Script entry point: encode every protein of ``args.task`` with the
    # protein encoder, average the encoder output over the sequence axis and
    # save the stacked result as one tensor file.
    from load_data import load_proteins
    from utils import get_parse, setup_seed

    args = get_parse()
    setup_seed(args.seed)

    device = 'cuda:2'
    dtype = torch.bfloat16
    batch = 16
    max_length = 1024
    output_path = f'{ROOT_DIR}/dataset/processed_data/{args.task}_tensor.pt'

    protein_encoder, protein_tokenizer = get_model(args)
    # get_model() already placed the encoder on its own default device; move
    # it (and cast it) once so it lives on the same device as the inputs.
    protein_encoder.to(device=device, dtype=dtype)

    proteins = load_proteins(args.task)

    print("task: ", args.task)
    print("==========> number of proteins", len(proteins))

    # len() rather than truthiness: the container type returned by
    # load_proteins() is not visible here and may not define bool().
    if len(proteins) == 0:
        # Without this guard the loop below would not run and torch.cat()
        # would fail on an empty list with an unrelated-looking error.
        raise SystemExit(f"no proteins loaded for task {args.task!r}; nothing to embed")

    # Create the output directory before the expensive encoding pass so a
    # missing directory cannot discard the finished embeddings at save time.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    embedding_list = []
    for start in range(0, len(proteins), batch):
        proteins_list = proteins[start:start + batch]

        # Every sequence is padded or truncated to exactly max_length tokens,
        # so each batch converts to a rectangular tensor.
        protein_input_dict = protein_tokenizer(
            proteins_list,
            padding='max_length',
            max_length=max_length,
            truncation=True,
        )
        protein_input_ids = torch.tensor(protein_input_dict['input_ids']).to(device)
        protein_attention_mask = (
            torch.tensor(protein_input_dict['attention_mask']).to(torch.bool).to(device)
        )

        with torch.no_grad():
            outputs = protein_encoder(
                input_ids=protein_input_ids,
                attention_mask=protein_attention_mask,
            )

        # NOTE(review): the mean runs over every position along dim 1. If
        # outputs[0] holds per-token hidden states, padded positions are
        # averaged in as well. Left unchanged because consumers of the saved
        # tensor may depend on these exact values -- confirm before switching
        # to a mask-weighted mean.
        protein_features = outputs[0].mean(dim=1)
        embedding_list.append(protein_features.to(dtype).to('cpu'))

    embeddings = torch.cat(embedding_list)
    print("==========> shape of embeddings", embeddings.shape)

    torch.save(embeddings, output_path)

    
