import concurrent.futures
import json
import pickle
import time
import urllib.error
from collections import defaultdict

import pubchempy as pcp
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Hugging Face checkpoint shared by the tokenizer and the masked-LM model.
_CHEMBERTA_CHECKPOINT = "DeepChem/ChemBERTa-77M-MLM"

tokenizer = AutoTokenizer.from_pretrained(_CHEMBERTA_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHEMBERTA_CHECKPOINT)

# Index map loaded from JSON; only its 'molecule' section is used here.
with open('./idx_map.json') as idx_map_file:
    idx_map = json.load(idx_map_file)

# Invert the 'molecule' mapping so each stored value points back to its key
# (presumably CID -> index becomes index -> CID; verify against the writer).
molecule_idx2cid = {idx: cid for cid, idx in idx_map['molecule'].items()}

# Previously computed results, keyed by CID.
# NOTE: pickle executes arbitrary code on load; only use a trusted cache file.
with open('./cid2emb.pkl', 'rb') as cache_file:
    cid2emb = pickle.load(cache_file)

import urllib

def get_embeddings(cid, max_retries=5, retry_delay=1):
    """Fetch the isomeric SMILES for a PubChem CID and embed it with the model.

    The SMILES lookup is retried on PubChem/URL errors. The embedding is the
    mean over dimension 1 of the masked-LM ``logits`` for the tokenized SMILES.
    NOTE(review): these are output logits, not hidden states -- confirm that
    this is the intended representation.

    Args:
        cid: PubChem compound identifier passed to ``pcp.get_properties``.
        max_retries: Maximum number of lookup attempts.
        retry_delay: Seconds to sleep after each failed attempt.

    Returns:
        A ``(cid, smiles_string, embedding)`` tuple where ``embedding`` is a
        NumPy array, or ``(cid, None, None)`` when no SMILES could be
        obtained (all attempts failed, or PubChem returned no usable record).
    """
    smiles_string = None
    for _ in range(max_retries):
        # Only the network call can raise the errors we retry on.
        try:
            records = pcp.get_properties('IsomericSMILES', cid, 'cid')
        except (pcp.PubChemHTTPError, urllib.error.URLError):
            time.sleep(retry_delay)
            continue
        # An empty result or a record without the key is not retryable;
        # report it as "no data" instead of raising IndexError/KeyError,
        # which would abort the whole executor.map() iteration.
        if records:
            smiles_string = records[0].get('IsomericSMILES')
        break

    if not smiles_string:
        return cid, None, None

    # Tokenize the SMILES string; overly long inputs are re-tokenized from
    # the first 500 characters of the string.
    tokens = tokenizer(smiles_string, return_tensors="pt")
    if len(tokens['input_ids'][0]) > 500:
        tokens = tokenizer(smiles_string[:500], return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**tokens)

    # Average the logits over dimension 1 and take the single batch item.
    mean_embedding = outputs.logits.mean(dim=1)
    return cid, smiles_string, mean_embedding[0].detach().numpy()

# Only process CIDs that are not cached yet, so no PubChem request or model
# pass is spent on results that would be thrown away.
cids = [cid for cid in molecule_idx2cid.values() if cid not in cid2emb]

try:
    # Use a ThreadPoolExecutor to parallelize the processing
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(get_embeddings, cids)

        for cid, smiles_string, mean_embedding in tqdm(results, total=len(cids)):
            # get_embeddings signals failure with None values; skip those.
            if smiles_string is None or mean_embedding is None:
                continue

            # Assign a fresh record so this works whether cid2emb is a plain
            # dict or a defaultdict.
            cid2emb[cid] = {'smiles': smiles_string, 'emb': mean_embedding}

            # Checkpoint only right after an insertion, so a run of failed
            # items cannot trigger the same dump repeatedly.
            if len(cid2emb) % 10000 == 0:
                with open('./cid2emb_new.pkl', 'wb') as f:
                    pickle.dump(cid2emb, f)
finally:
    # Always persist what has been collected, even if the run is interrupted
    # or a worker raised an unexpected exception.
    with open('./cid2emb_new.pkl', 'wb') as f:
        pickle.dump(cid2emb, f)

