# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
import torch
import pickle
from tqdm import tqdm
# tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
# model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Local checkpoint directory (presumably a copy of facebook/esm2_t33_650M_UR50D,
# judging by the name). The tokenizer and the masked-LM model are both read
# from this same directory.
ESM_CKPT_DIR = "/data1/tianlong/LLaVA_ckpt/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM_CKPT_DIR)
model = AutoModelForMaskedLM.from_pretrained(ESM_CKPT_DIR)


# Protein tables to embed, in processing order. They are read with
# sep='\t' further down despite the .csv suffix.
csv_files = [
    f'datasets/{stem}.csv'
    for stem in ('split100', 'halogenase', 'multi', 'new', 'price')
]

# Load the pre-computed SMILES feature bundle. NOTE: pickle.load can execute
# arbitrary code, so this path must only ever point at a trusted, locally
# produced file.
with open('datasets/SMILES/first_features.pkl', 'rb') as fin:
    first_info = pickle.load(fin)

id_list = first_info['id']
first_feature_list = first_info['feature']
n_feature = len(first_feature_list)

# Number of leading dimensions kept from each feature and later compared by
# cosine similarity. The buffer is allocated from fit_dim (rather than a
# separate literal) so its width always matches the slice copied into it.
fit_dim = 768
all_feature = torch.zeros((n_feature, fit_dim))
for i, feat in enumerate(first_feature_list):
    # Only row 0 of each stored feature is used; assumes each entry is
    # indexable as (1, >= fit_dim) -- TODO confirm against the pickle producer.
    all_feature[i] = feat[0, :fit_dim]

# The candidate matrix is compared against every protein embedding, so keep
# it on the GPU when one is present; fall back to CPU otherwise.
all_feature = all_feature.to('cuda' if torch.cuda.is_available() else 'cpu')

# Maps each protein 'Entry' id to {"cos_logit": [...], "cos_index": [...]}:
# the cosine similarities and row indices (into all_feature) of its top-k
# nearest SMILES features. An Entry present in several tables keeps the
# result from the last table processed.
protein_id2smiles_id = {}

# Run the model on whatever device the candidate features live on, so the
# two can never end up on different devices.
device = all_feature.device
model = model.to(device)

# Up to 4 neighbours per protein; capped by the candidate count because
# Tensor.topk raises when k exceeds the size of the searched dimension.
top_k = min(4, all_feature.shape[0])

# Loop-invariant: the slice of candidate features that is compared.
candidates = all_feature[:, :fit_dim]
# fit_dim=480 -- alternative value left from the smaller ESM-2 checkpoint.

with torch.no_grad():
    for cf in csv_files:
        protein_data = pd.read_csv(cf, sep='\t')
        rows = zip(protein_data['Entry'], protein_data['Sequence'])
        for entry_id, prot_seq in tqdm(rows, total=len(protein_data), desc=cf):
            tokens = tokenizer(prot_seq, return_tensors="pt", max_length=512,
                               padding=True, truncation=True).to(device)
            # Last-layer hidden state at position 0 of the single sequence
            # (ESM's <cls> token, assuming default special-token handling),
            # reshaped to (1, hidden) for the batched similarity below.
            hidden = model(**tokens, output_hidden_states=True).hidden_states[-1]
            prot_feature = hidden[0, 0].unsqueeze(0)
            # NOTE(review): esm2_t33_650M has hidden size 1280, so only its
            # first fit_dim dims are compared with the SMILES features --
            # confirm this truncation is intended.
            cossim = torch.nn.functional.cosine_similarity(
                prot_feature[:, :fit_dim], candidates)
            cos_logit, cos_index = cossim.topk(top_k)
            protein_id2smiles_id[entry_id] = {
                "cos_logit": cos_logit.tolist(),
                "cos_index": cos_index.tolist(),
            }
            
# Persist the Entry -> nearest-SMILES mapping as a pickle for later stages.
output_path = "datasets/protid_to_smilesid.pkl"
with open(output_path, 'wb') as out_file:
    pickle.dump(protein_id2smiles_id, out_file)