from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
from tqdm import tqdm
import pickle
import torch
import pathlib

# Extract ChemBERTa last-layer hidden states for every SMILES string in the
# Rhea reaction TSV. Outputs:
#   * datasets/SMILES/<id>.pkl        -- full last-layer hidden states per row
#   * datasets/SMILES/first_features.pkl
#       -- {'id': [...], 'feature': [...]} with the position-0 vector per row

# Local checkpoint directory holding both the tokenizer and the model weights.
CKPT_DIR = "/data1/xxx/LLaVA_ckpt/ChemBERTa-zinc-base-v1"

tokenizer = AutoTokenizer.from_pretrained(CKPT_DIR)
model = AutoModelForMaskedLM.from_pretrained(CKPT_DIR)

# Columns read below: 'id' (used as the output file stem) and 'SMILE'.
smiles_tsv = pd.read_csv('notebook/rhea-reaction-smiles.tsv', sep='\t')

# Prefer the GPU, but fall back to CPU so the script still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
# from_pretrained already returns the model in eval mode; made explicit here
# because this script only runs inference.
model.eval()

# Longest SMILES string (in characters, before tokenization) seen so far;
# displayed in the progress bar.
max_len = 0

smiles_root = pathlib.Path("datasets/SMILES")
# parents=True so this also works when "datasets/" does not exist yet.
smiles_root.mkdir(parents=True, exist_ok=True)

# Walk both columns in lockstep. `.to_numpy()` yields numpy scalars for the
# ids, i.e. the same element types that `smiles_tsv['id'][i]` produced, so the
# pickled id list is unchanged.
bar = tqdm(
    zip(smiles_tsv['id'].to_numpy(), smiles_tsv['SMILE'].to_numpy()),
    total=len(smiles_tsv),
)

id_list = []
# (sic) name kept as-is for compatibility; holds one position-0 feature
# tensor per row, in the same order as id_list.
first_feauture_list = []

with torch.no_grad():
    for smiles_id, raw_smiles in bar:
        smiles = raw_smiles.strip()
        max_len = max(max_len, len(smiles))
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        encoded = tokenizer(
            smiles,
            return_tensors="pt",
            max_length=512,
            padding=True,
            truncation=True,
        ).to(device)
        # Last hidden layer, moved back to CPU so it can be pickled.
        # Per the HF docs this is (batch_size, sequence_length, hidden_size);
        # batch_size is 1 here since a single string is encoded.
        feature = model(**encoded, output_hidden_states=True).hidden_states[-1].cpu()
        id_list.append(smiles_id)
        # Position 0 is presumably the tokenizer's leading special token
        # (<s>/[CLS]) -- confirm for this tokenizer. clone() detaches the
        # slice from `feature`'s storage so pickling it later does not also
        # serialize the full sequence tensor.
        first_feauture_list.append(feature[:, 0].clone())
        with open(smiles_root / f"{smiles_id}.pkl", 'wb') as fout:
            pickle.dump(feature, fout)

        bar.set_description(f"Max SMILES:{max_len}")

with open(smiles_root / "first_features.pkl", 'wb') as fout:
    pickle.dump({'id': id_list, 'feature': first_feauture_list}, fout)