import json
import glob

# Glob patterns naming the JSON dataset files to merge; each pattern is
# expanded with glob.glob() inside read_and_merge_files().
patterns = ['for_finetune_beam.json']
# Accumulator for the records loaded from every matched file.
merged_data = []



import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import QED
from rdkit.Chem import GraphDescriptors
from rdkit.Chem import Lipinski
from rdkit import RDLogger
import numpy as np
RDLogger.DisableLog('rdApp.warning')
RDLogger.DisableLog('rdApp.error')


import selfies as sf
from nltk.translate.bleu_score import corpus_bleu
def get_smiles(smiles_list, ground_truth):
    """Filter generated candidates against a ground-truth molecule.

    Despite the parameter name, every candidate (and ``ground_truth``) is
    passed through ``sf.decoder``, i.e. it is treated as a SELFIES string.
    A candidate is kept when, after removing the substrings ``'The'`` and
    ``'molecule'``:

    * it is non-empty and differs from ``ground_truth``,
    * its decoded form is accepted by ``Chem.MolFromSmiles``, and
    * the character-level BLEU score between the decoded ground truth and
      the decoded candidate is greater than 0.7.

    Args:
        smiles_list: Iterable of candidate strings.
        ground_truth: Reference string; always the first element returned.

    Returns:
        A list starting with ``ground_truth`` followed by the unique kept
        candidates in first-seen order (deterministic across runs).
    """
    # Decoded lazily so that nothing is decoded when there is no usable
    # candidate, and only once when there are several.
    reference_tokens = None
    # dict keys give order-preserving de-duplication (a set would shuffle).
    kept = {}
    for raw in smiles_list:
        candidate = raw.replace('The', '').replace('molecule', '')
        # The ground truth is re-inserted at the front below, and a duplicate
        # candidate would only repeat the same validation work.
        if not candidate or candidate == ground_truth or candidate in kept:
            continue
        try:
            decoded = sf.decoder(candidate)
            if Chem.MolFromSmiles(decoded) is None:
                continue
            if reference_tokens is None:
                reference_tokens = list(sf.decoder(ground_truth))
            bleu = corpus_bleu([[reference_tokens]], [list(decoded)])
        except Exception:
            # Best-effort filtering: a candidate that cannot be decoded or
            # scored is dropped. KeyboardInterrupt/SystemExit still propagate.
            continue
        if bleu > 0.7:
            kept[candidate] = None
    return [ground_truth, *kept]


import re
merged_data = []


def read_and_merge_files(pattern):
    """Load every JSON file matching ``pattern`` and merge it into ``merged_data``.

    A file whose top-level value is a list contributes each of its elements;
    any other top-level value is added as a single record. Each newly loaded
    record then has its ``'output'`` field rewritten: the ``<bom>...<eom>``
    spans are extracted, filtered through ``get_smiles`` against the record's
    ``'ground_truth'`` (with its ``<bom>``/``<eom>`` markers removed), and
    re-joined as ``<bom>a<eom> <bom>b<eom>...``.

    Only records loaded by this call are post-processed, so calling the
    function for several patterns does not re-filter earlier records.

    Args:
        pattern: Glob pattern of JSON files to read.

    Raises:
        KeyError: If a record lacks the ``'output'`` or ``'ground_truth'`` key.
        json.JSONDecodeError: If a matched file is not valid JSON.
    """
    loaded = []
    # glob order is arbitrary; sort so the merged dataset is reproducible.
    for path in sorted(glob.glob(pattern)):
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, list):
            loaded.extend(data)
        else:
            loaded.append(data)

    for entry in loaded:
        candidates = re.findall(r"<bom>(.*?)<eom>", entry['output'])
        ground_truth = entry['ground_truth'].replace('<bom>', '').replace('<eom>', "")
        kept = get_smiles(candidates, ground_truth)
        entry['output'] = '<bom>' + '<eom> <bom>'.join(kept) + '<eom>'

    merged_data.extend(loaded)


# Load and post-process the files matched by each configured pattern.
for dataset_pattern in patterns:
    read_and_merge_files(dataset_pattern)

# Write all accumulated records to one JSON file, indented by 4 spaces.
with open('SFT_dataset.json', 'w') as out_file:
    json.dump(merged_data, out_file, indent=4)