#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import time
import joblib
#import tdc # expect some errors maybe?
from tdc import Oracle
import numpy as np
import pandas as pd
import torch
import yaml
from rdkit import rdBase
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import Draw
from rdkit import Chem
import random
from torch import nn
from rdkit.Chem import Descriptors, AllChem
from rdkit.Chem import QED
from rdkit.Chem.Scaffolds import MurckoScaffold
import math
import selfies as sf
from data_loader import \
    multiple_selfies_to_hot, multiple_smile_to_hot
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import ExpectedImprovement,UpperConfidenceBound
from botorch.optim.optimize import optimize_acqf_mixed
from botorch.optim import optimize_acqf
from botorch.acquisition import AcquisitionFunction
import xgboost as xgb
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from gpytorch.kernels import ScaleKernel, RBFKernel, ProductKernel
from botorch.models.kernels.categorical import CategoricalKernel
import gpytorch
import argparse
import subprocess
import re
import os
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import AllChem
from rdkit.DataStructs import TanimotoSimilarity, DiceSimilarity

import warnings

# Suppress all warnings: install a process-wide 'ignore' filter with no
# category or module restriction, so nothing is reported from here on.
warnings.filterwarnings('ignore')

# Command-line interface: three positional arguments select which trained
# CVAE / GP-training artefacts are used and which property is optimised.
# The parsed values are exposed as module-level globals (beta, z_dim,
# target_property) that the rest of the script reads.
parser = argparse.ArgumentParser(description='beta ld')
parser.add_argument(
    'beta', type=float,
    help='KL-divergence weight (beta) of the trained CVAE; also used to build file names')
parser.add_argument(
    'z_dim', type=int,
    help='Latent dimension of the trained CVAE; also used to build file names')
parser.add_argument(
    'target_property', type=str,
    help='Property to optimise: "logp", "qed", or any other name understood by tdc.Oracle')

args = parser.parse_args()
beta = args.beta
z_dim = args.z_dim
target_property = args.target_property



# Make sure the GP training artefacts for this (z_dim, beta, target_property)
# configuration exist; when one is missing, run the helper script that is
# expected to create it.
folder_path = "lsbo_cvae/gp_training/"

# 1) Full GP training table.
file_name = f"gp_cvae_y_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv"
file_path = os.path.join(folder_path, file_name)
if os.path.exists(file_path):
    print(f"The file {file_path} exists.")
else:
    # sys.executable keeps the child in the same interpreter/environment as
    # this script; check=True aborts here instead of failing later on a
    # missing file if the generator script exits with an error.
    command = [sys.executable, 'get_cvae_gp_training.py', str(beta), str(z_dim), target_property]
    subprocess.run(command, check=True)

# 2) Sorted table consumed by main().
file_name = f"gp_cvae_y_sorted_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv"
file_path = os.path.join(folder_path, file_name)
if os.path.exists(file_path):
    print(f"The file {file_path} exists.")
else:
    command = [sys.executable, 'get_cvae_gp_training_top100.py', str(beta), str(z_dim), target_property]
    subprocess.run(command, check=True)

# Turn off RDKit's 'rdApp.error' log channel.
rdBase.DisableLog('rdApp.error')
# Run on the GPU when CUDA is available, otherwise on the CPU. The sampling
# helpers and main() below read this module-level global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the VAE encoder and decoder as provided
class VAE_Encoder_Fully_connected_unit(nn.Module):
    """Small fully connected VAE encoder.

    The input is flattened and passed through ReLU -> Linear(input_dim, 4) ->
    Sigmoid; the 4 resulting features feed two linear heads that produce the
    mean and the log-variance of a ``z_n``-dimensional Gaussian.
    """

    def __init__(self, z_n, input_dim):
        super().__init__()
        # The layer order fixes the Sequential indices and therefore the
        # parameter names in the state dict (the Linear layer is "fc.2").
        trunk_layers = [
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(input_dim, 4),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*trunk_layers)
        self.fc_mean = nn.Linear(4, z_n)
        self.fc_log_var = nn.Linear(4, z_n)

    def forward(self, x):
        """Return ``(z, mean, log_var)`` where ``z`` is a reparameterized sample."""
        features = self.fc(x)
        mean = self.fc_mean(features)
        log_var = self.fc_log_var(features)
        return self.reparameterize(mean, log_var), mean, log_var

    @staticmethod
    def reparameterize(mean, log_var):
        """Sample ``mean + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return torch.mul(noise, std).add_(mean)



def _make_dir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)


def save_models(encoder, decoder, epoch):
    """Save the encoder and decoder objects under ``./saved_models/<epoch>/``.

    The encoder is written to ``E.pt`` and the decoder to ``D.pt`` with
    ``torch.save``; the directory is created first if needed.
    """
    out_dir = f'./saved_models/{epoch}'
    _make_dir(out_dir)
    for model, stem in ((encoder, 'E'), (decoder, 'D')):
        torch.save(model, f'{out_dir}/{stem}.pt')




class ConditionalBatchNorm1d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition.

    A non-affine ``BatchNorm1d`` normalises the input; two linear layers map
    the condition vector to a per-feature scale (gamma) and shift (beta).
    """

    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        self.gamma_embed = nn.Linear(num_conditions, num_features)
        self.beta_embed = nn.Linear(num_conditions, num_features)

    def forward(self, x, condition):
        """Return ``gamma(condition) * bn(x) + beta(condition)``."""
        normalized = self.bn(x)
        # Both projections are collapsed to 2-D (rows x num_features) before
        # being combined with the normalised activations.
        scale = self.gamma_embed(condition).view(-1, self.num_features)
        shift = self.beta_embed(condition).view(-1, self.num_features)
        return scale * normalized + shift


class VAEEncoder(nn.Module):
    """Conditional VAE encoder: three linear layers with conditional batch norm.

    The flattened input is concatenated with the condition embedding, passed
    through three ``Linear -> ConditionalBatchNorm1d -> ReLU`` stages and
    projected to the mean and log-variance of the latent Gaussian.

    Args:
        condition_embedding_dim: size of the condition vector.
        in_dimension: size of the (flattened) input, without the condition.
        layer_1d, layer_2d, layer_3d: widths of the three hidden layers.
        latent_dimension: size of the latent vector.
    """

    def __init__(self, condition_embedding_dim, in_dimension, layer_1d, layer_2d, layer_3d, latent_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension

        self.fc1 = nn.Linear(in_dimension + condition_embedding_dim, layer_1d)
        self.cbn1 = ConditionalBatchNorm1d(layer_1d, condition_embedding_dim)

        self.fc2 = nn.Linear(layer_1d, layer_2d)
        self.cbn2 = ConditionalBatchNorm1d(layer_2d, condition_embedding_dim)

        self.fc3 = nn.Linear(layer_2d, layer_3d)
        self.cbn3 = ConditionalBatchNorm1d(layer_3d, condition_embedding_dim)

        self.encode_mu = nn.Linear(layer_3d, latent_dimension)
        self.encode_log_var = nn.Linear(layer_3d, latent_dimension)

    @staticmethod
    def reparameterize(mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x, condition):
        """Encode ``x`` under ``condition`` and return ``(z, mu, log_var)``.

        ``condition`` is reshaped to one row per sample of ``x`` before it is
        concatenated to the input along dim 1.
        """
        condition = condition.view(x.size(0), -1)
        combined_input = torch.cat((x, condition), dim=1)

        # torch.relu replaces the previous `TF.relu`: no name `TF` is imported
        # in this file, so every forward pass raised NameError.
        x = torch.relu(self.cbn1(self.fc1(combined_input), condition))
        x = torch.relu(self.cbn2(self.fc2(x), condition))
        x = torch.relu(self.cbn3(self.fc3(x), condition))

        mu = self.encode_mu(x)
        log_var = self.encode_log_var(x)
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var


class VAEDecoder(nn.Module):
    """Conditional GRU decoder.

    The GRU consumes the latent vector concatenated with the condition
    embedding (hence ``latent_dimension + condition_embedding_dim`` inputs);
    its output goes through conditional batch norm and a linear projection to
    ``out_dimension`` logits.
    """

    def __init__(self, latent_dimension, gru_stack_size, gru_neurons_num, condition_embedding_dim, out_dimension):
        super().__init__()
        # NB: the stored value is the GRU input width, i.e. latent + condition.
        self.latent_dimension = latent_dimension + condition_embedding_dim
        self.gru_stack_size = gru_stack_size
        self.gru_neurons_num = gru_neurons_num

        self.decode_RNN = nn.GRU(
            input_size=self.latent_dimension,
            hidden_size=gru_neurons_num,
            num_layers=gru_stack_size,
            batch_first=False,
        )
        self.decode_cbn = ConditionalBatchNorm1d(gru_neurons_num, condition_embedding_dim)
        self.decode_FC = nn.Linear(gru_neurons_num, out_dimension)

    def init_hidden(self, batch_size=1):
        """Return a zero hidden state of shape (gru_stack_size, batch_size, gru_neurons_num).

        The tensor is created from the module's first parameter, so it shares
        that parameter's dtype and device.
        """
        reference = next(self.parameters())
        state_shape = (self.gru_stack_size, batch_size, self.gru_neurons_num)
        return reference.new_zeros(state_shape)

    def forward(self, z, hidden, condition):
        """Run one GRU pass and return ``(logits, new_hidden)``.

        ``z`` is sequence-first (``batch_first=False``), so ``z.size(1)`` is
        the batch size used to reshape ``condition``.
        """
        rnn_out, hidden = self.decode_RNN(z, hidden)

        # One condition row per batch element, repeated for every time step
        # of the GRU output before conditional batch normalisation.
        per_sample_condition = condition.view(z.size(1), -1)
        num_steps = rnn_out.size(0)
        tiled_condition = per_sample_condition.unsqueeze(0).expand(num_steps, -1, -1)
        normalized = self.decode_cbn(rnn_out, tiled_condition)

        return self.decode_FC(normalized), hidden


def is_correct_smiles(smiles):
    """
    Tell whether RDKit can parse and sanitize ``smiles``.

    Returns False for the empty string, for strings RDKit rejects (parser
    returns None) and whenever parsing raises; True otherwise.
    """
    if smiles == "":
        return False

    try:
        parsed = MolFromSmiles(smiles, sanitize=True)
    except Exception:
        return False
    return parsed is not None


def sample_latent_space_x(vae_encoder, vae_decoder, sample_len):
    """Greedily decode one random latent point into ``sample_len`` symbol indices.

    A single point is drawn from a standard normal of size
    ``vae_encoder.latent_dimension`` and fed to the decoder at every step;
    the arg-max symbol of each step is collected. Both models are put in
    eval mode for the duration and switched to train mode before returning.

    NOTE(review): the decoder is called with two arguments here, while
    ``VAEDecoder.forward`` in this file takes ``(z, hidden, condition)`` -
    confirm which decoder this helper is meant for.
    """
    vae_encoder.eval()
    vae_decoder.eval()

    latent_point = torch.randn(1, 1, vae_encoder.latent_dimension,
                               device=device)
    hidden = vae_decoder.init_hidden()
    softmax = nn.Softmax(0)

    symbol_indices = []
    # one decoder step per symbol position
    for _ in range(sample_len):
        logits, hidden = vae_decoder(latent_point, hidden)
        probabilities = softmax(logits.flatten().detach())
        best = probabilities.argmax(0)
        symbol_indices.append(best.data.cpu().tolist())

    vae_encoder.train()
    vae_decoder.train()

    return symbol_indices


def sample_latent_space_topk(vae_encoder, vae_decoder, sample_len, k=5):
    """Decode one random latent point with top-k sampling at every step.

    At each of the ``sample_len`` steps the ``k`` most probable symbols are
    kept and one of them is drawn with ``torch.multinomial`` weighted by
    their probabilities. Both models are put in eval mode for the duration
    and switched to train mode before returning the list of chosen indices.

    NOTE(review): the decoder is called with two arguments here, while
    ``VAEDecoder.forward`` in this file takes ``(z, hidden, condition)`` -
    confirm which decoder this helper is meant for.
    """
    vae_encoder.eval()
    vae_decoder.eval()

    latent_point = torch.randn(1, 1, vae_encoder.latent_dimension, device=device)
    hidden = vae_decoder.init_hidden()
    softmax = nn.Softmax(dim=0)

    chosen_symbols = []
    # one decoder step per symbol position
    for _ in range(sample_len):
        logits, hidden = vae_decoder(latent_point, hidden)
        probabilities = softmax(logits.flatten().detach())

        # Restrict to the k best symbols, then sample a position among them.
        topk_probs, topk_indices = torch.topk(probabilities, k)
        pick = torch.multinomial(topk_probs, 1)

        # Translate the position within the top-k back to a symbol index.
        chosen_symbols.append(topk_indices[pick].item())

    vae_encoder.train()
    vae_decoder.train()

    return chosen_symbols

def sample_latent_space(vae_encoder, vae_decoder, condition_vector, candidate, sample_len, k=1):
    """Decode a given latent candidate, under a condition, into symbol indices.

    Args:
        vae_encoder: only switched between eval and train mode here.
        vae_decoder: called as ``vae_decoder(z, hidden, condition)``.
        condition_vector: condition embedding; concatenated to ``candidate``
            along dim 1, so both are expected to be 2-D with matching rows.
        candidate: latent point to decode.
        sample_len: number of decoder steps (symbols) to generate.
        k: at each step one of the ``k`` most probable symbols is chosen
            uniformly at random (``k=1`` is greedy decoding).

    Returns:
        list[int]: one chosen symbol index per step.

    Both models are put in eval mode for the duration and switched to train
    mode before returning.
    """
    # NOTE(review): hard-coded sizes. The decoder input is replicated into a
    # batch of 100 and only the first 18 flattened logits are used;
    # presumably 18 is the alphabet size, so the slice is the first batch
    # row's logits - confirm both values against the trained model.
    num_replicas = 100
    num_logits_kept = 18

    vae_encoder.eval()
    vae_decoder.eval()

    combined_input = torch.cat((candidate.to(device), condition_vector.to(device)), dim=1).unsqueeze(0)
    hidden = vae_decoder.init_hidden(batch_size=num_replicas)

    # These inputs do not change between steps, so build them (and the
    # stateless Softmax module) once instead of once per generated symbol.
    combined_input_rpt = combined_input.repeat(1, num_replicas, 1)
    condition_vector_rpt = condition_vector.repeat(num_replicas, 1)
    soft = nn.Softmax(dim=0)

    gathered_atoms = []
    # one decoder step per symbol position
    for _ in range(sample_len):
        out_one_hot, hidden = vae_decoder(combined_input_rpt, hidden, condition_vector_rpt)
        logits = out_one_hot.flatten().detach()[:num_logits_kept]
        probabilities = soft(logits)

        # Keep the k most probable symbols ...
        _, topk_indices = torch.topk(probabilities, k)

        # ... and pick one of them with equal probability.
        out_index = torch.randint(0, k, (1,)).item()
        gathered_atoms.append(topk_indices[out_index].item())

    vae_encoder.train()
    vae_decoder.train()

    return gathered_atoms

def latent_space_quality(vae_encoder, vae_decoder, condition_vector, type_of_encoding,
                         alphabet, sample_num, sample_len):
    """Decode ``sample_num`` random latent points and count the valid molecules.

    Args:
        vae_encoder, vae_decoder: models handed to ``sample_latent_space``.
        condition_vector: condition embedding; assumed to have a single row
            so it can be concatenated with the sampled latent point -
            TODO confirm against callers.
        type_of_encoding: ``1`` means the decoded string is SELFIES and is
            converted to SMILES before validation.
        alphabet: maps a symbol index to its string.
        sample_num: number of latent points to decode.
        sample_len: number of symbols generated per point.

    Returns:
        tuple: ``(total_correct, number_of_distinct_correct, set_of_correct)``.
    """
    total_correct = 0
    all_correct_molecules = set()
    print(f"latent_space_quality:"
          f" Take {sample_num} samples from the latent space")

    #TODO: remove for loop
    #TODO: let AF pick the latent variable
    #TODO: send latent variable to sample_latent_space function
    #TODO: add substructure here
    #TODO: evaluate here
    #TODO: update GP
    #TODO: maybe move this function to main loop
    for _ in range(sample_num):
        # sample_latent_space requires an explicit latent candidate and
        # returns only the list of symbol indices. The previous call omitted
        # the candidate and unpacked two values, so it always raised
        # TypeError. Until an acquisition function supplies the point (see
        # TODOs), draw it from the standard-normal prior.
        candidate = torch.randn(1, vae_encoder.latent_dimension)
        gathered_atoms = sample_latent_space(
            vae_encoder, vae_decoder, condition_vector, candidate, sample_len)

        molecule = ''.join(alphabet[i] for i in gathered_atoms).replace(' ', '')

        if type_of_encoding == 1:  # if SELFIES, decode to SMILES
            molecule = sf.decoder(molecule)

        if is_correct_smiles(molecule):
            total_correct += 1
            all_correct_molecules.add(molecule)

    return total_correct, len(all_correct_molecules), all_correct_molecules


def get_selfie_and_smiles_encodings_for_dataset(file_path):
    """
    Load a SMILES dataset and derive both its SMILES and SELFIES encodings.

    input:
        path of a csv file with molecules; the column must be named 'smiles'.
    output (in this order):
        - list of SELFIES strings (one per molecule)
        - SELFIES alphabet, including the '[nop]' symbol
        - length, in symbols, of the longest SELFIES string
        - array of SMILES strings (the file content)
        - SMILES alphabet (single characters plus ' ' for padding)
        - length of the longest SMILES string
    """
    smiles_list = np.asanyarray(pd.read_csv(file_path).smiles)

    # Character-level alphabet; the blank is appended as the padding symbol.
    smiles_alphabet = list(set(''.join(smiles_list)))
    smiles_alphabet.append(' ')

    largest_smiles_len = len(max(smiles_list, key=len))

    print('--> Translating SMILES to SELFIES...')
    selfies_list = [sf.encoder(smiles) for smiles in smiles_list]

    selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
    selfies_symbols.add('[nop]')
    selfies_alphabet = list(selfies_symbols)

    largest_selfies_len = max(map(sf.len_selfies, selfies_list))

    print('Finished translating SMILES to SELFIES.')

    return (selfies_list, selfies_alphabet, largest_selfies_len,
            smiles_list, smiles_alphabet, largest_smiles_len)

def add_substructure(input_smiles, atom_index, substructure_smiles):
    """Attach a fragment to a molecule with a single bond and return the SMILES.

    The first atom of ``substructure_smiles`` is bonded to atom
    ``atom_index`` of ``input_smiles``. On any failure (unparsable SMILES,
    bond or sanitization error) a message is printed and ``input_smiles`` is
    returned unchanged.
    """
    try:
        base_mol = Chem.MolFromSmiles(input_smiles)
        if base_mol is None:
            raise ValueError("Invalid input SMILES")

        fragment_mol = Chem.MolFromSmiles(substructure_smiles)
        if fragment_mol is None:
            raise ValueError("Invalid substructure SMILES")

        editable_base = Chem.RWMol(base_mol)
        merged = Chem.RWMol(Chem.CombineMols(editable_base, fragment_mol))

        # Fragment atoms are appended after the base atoms, so its first
        # atom has index == number of base atoms.
        first_fragment_atom = editable_base.GetNumAtoms()

        merged.AddBond(atom_index, first_fragment_atom, Chem.BondType.SINGLE)
        Chem.SanitizeMol(merged)
        return Chem.MolToSmiles(merged)
    except Exception as e:
        print(f"Failed to add substructure: {e}")
        return input_smiles

def replace_substructure(input_smiles, atom_index, substructure_smiles):
    """Replace one atom of a molecule by a fragment and return the SMILES.

    Atom ``atom_index`` of ``input_smiles`` is removed and every atom that
    was bonded to it is connected, with a single bond, to the first atom of
    ``substructure_smiles``. On any failure (unparsable SMILES, bond or
    sanitization error) a message is printed and ``input_smiles`` is
    returned unchanged.
    """
    try:
        mol = Chem.MolFromSmiles(input_smiles)
        if mol is None:
            raise ValueError("Invalid input SMILES")

        substructure_mol = Chem.MolFromSmiles(substructure_smiles)
        if substructure_mol is None:
            raise ValueError("Invalid substructure SMILES")

        # Convert to editable molecule
        edited_mol = Chem.RWMol(mol)

        # Record the neighbours of the atom before it is removed.
        bonds = list(edited_mol.GetAtomWithIdx(atom_index).GetBonds())
        connected_atom_idxs = [bond.GetOtherAtomIdx(atom_index) for bond in bonds]

        edited_mol.RemoveAtom(atom_index)

        # Removing an atom shifts every higher atom index down by one, so the
        # neighbour indices recorded above must be renumbered; otherwise the
        # new bonds are attached to the wrong atoms (or to an invalid index).
        connected_atom_idxs = [
            idx - 1 if idx > atom_index else idx
            for idx in connected_atom_idxs
        ]

        # Add the substructure
        combo = Chem.CombineMols(edited_mol, substructure_mol)
        combo_mol = Chem.RWMol(combo)

        # The fragment's atoms follow the remaining atoms, so its first atom
        # sits at index == number of remaining atoms.
        new_atom_index = edited_mol.GetNumAtoms()

        # Reconnect the former neighbours to the first atom of the fragment.
        for idx in connected_atom_idxs:
            combo_mol.AddBond(idx, new_atom_index, Chem.BondType.SINGLE)

        Chem.SanitizeMol(combo_mol)
        return Chem.MolToSmiles(combo_mol)
    except Exception as e:
        print(f"Failed to replace substructure: {e}")
        return input_smiles


def get_atom_features(atom):
    """Collect six descriptors of ``atom`` in a dictionary.

    Keys, in insertion order: AtomicNum, Hybridization, Valence,
    FormalCharge, Degree, IsInRing. Hybridization and ring membership are
    converted to ``int``.
    """
    features = {}
    features['AtomicNum'] = atom.GetAtomicNum()
    features['Hybridization'] = int(atom.GetHybridization())
    features['Valence'] = atom.GetTotalValence()
    features['FormalCharge'] = atom.GetFormalCharge()
    features['Degree'] = atom.GetDegree()
    features['IsInRing'] = int(atom.IsInRing())
    return features

def draw_and_save_molecules(smiles_list, directory):
    """Render every parsable SMILES to ``<directory>/molecule_<n>.png``.

    ``n`` is the 1-based position in ``smiles_list``; entries RDKit cannot
    parse are skipped (their number is simply not used).
    """
    for position, smiles in enumerate(smiles_list, start=1):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        image = Draw.MolToImage(mol)
        image.save(f"{directory}/molecule_{position}.png")

def find_atoms_with_extra_bond(mol):
    """Return the indices of all atoms of ``mol`` accepted by ``can_have_extra_bond``."""
    return [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if can_have_extra_bond(mol, atom.GetIdx())
    ]

def can_have_extra_bond(mol, atom_idx):
    """Report whether the atom's degree is strictly below its total valence.

    The degree and valence are also printed for every atom inspected.
    """
    atom = mol.GetAtomWithIdx(atom_idx)
    degree, total_valence = atom.GetDegree(), atom.GetTotalValence()
    print(f"degree of this atom is {degree} and its valence is {total_valence}")
    return degree < total_valence

def optimize_and_map(bounds, EI, atom_ids_with_extra_bond):
    """Maximise the acquisition function and map its last coordinate to an atom.

    ``optimize_acqf`` is run with ``q=1``, 10 restarts and 100 raw samples.
    The last coordinate of the candidate is treated as a position in
    ``atom_ids_with_extra_bond``: it is floored, written back into the
    candidate in place, and used to look up the atom id.

    Returns:
        tuple: ``(candidate, atom_id)``.
    """
    candidate, _ = optimize_acqf(acq_function=EI, bounds=bounds, q=1,
                                 num_restarts=10, raw_samples=100)

    # Floor the continuous last coordinate to a list position ...
    last_coordinate = candidate[0, -1]
    atom_id_index = int(math.floor(last_coordinate.item()))
    # ... and store the floored value in the candidate itself as well.
    candidate[0, -1] = torch.floor(last_coordinate)

    print(f'CANDIDATE ATOM ID IS: {atom_id_index}')
    # Translate the list position into the actual atom id.
    return candidate, atom_ids_with_extra_bond[atom_id_index]

def get_scaffold(smiles):
    """Return the Bemis-Murcko scaffold of ``smiles`` as a SMILES string.

    When RDKit cannot parse the input, the literal string
    "Invalid SMILES string" is returned instead.
    """
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return "Invalid SMILES string"

    scaffold_mol = MurckoScaffold.GetScaffoldForMol(mol)
    return Chem.MolToSmiles(scaffold_mol)


def eval_similarity(original_smiles, modified_smiles, similarity_threshold=0.7):
    """
    Compare two molecules by the Dice similarity of their Morgan fingerprints.

    Args:
        original_smiles (str): SMILES of the first molecule.
        modified_smiles (str): SMILES of the second molecule.
        similarity_threshold (float): value the similarity is compared against.

    Returns:
        tuple: ``(meets_threshold, similarity)`` where ``meets_threshold`` is
        True when ``similarity >= similarity_threshold`` and ``similarity`` is
        the Dice similarity of the two fingerprints.
    """
    # Parse both molecules first, then fingerprint them (Morgan, radius 1,
    # 2048 bits).
    molecules = [Chem.MolFromSmiles(s) for s in (original_smiles, modified_smiles)]
    fp_first, fp_second = [
        AllChem.GetMorganFingerprintAsBitVect(m, radius=1, nBits=2048)
        for m in molecules
    ]

    similarity = DiceSimilarity(fp_first, fp_second)
    return bool(similarity >= similarity_threshold), similarity

def calculate_molecular_weight(smiles):
    """Return the exact molecular weight of ``smiles`` (RDKit ``CalcExactMolWt``).

    When RDKit cannot parse the input, the literal string
    "Invalid SMILES string" is returned instead of a number.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalid SMILES string"
    return rdMolDescriptors.CalcExactMolWt(mol)

def main() -> None:
    """Run the latent-space Bayesian optimisation experiment.

    Steps, as implemented below:
      1. Truncate 'logfile.dat' / 'results.dat' and read 'settings_cvae.yml'.
      2. Build the one-hot dataset (SMILES or SELFIES) to obtain the molecule
         length and alphabet size used to size the models.
      3. Instantiate the conditional VAE encoder/decoder and the small
         atom-feature encoder and load their saved weights.
      4. Load the GP training inputs (latent points + atom id) and rewards.
      5. For every scaffold in 'paper_prep_files/target_mols.csv' and every
         seed, run 100 optimisation steps: maximise a UCB acquisition
         function, decode the proposed latent point into a fragment, attach
         it to the current molecule, score the change in the target property
         and refit the GP on the new observation.
      6. Write the accumulated results to CSV after every step.

    Reads the module-level globals ``z_dim``, ``beta``, ``target_property``
    and ``device``. Returns early (None) if the settings file is missing or
    ``type_of_encoding`` is not 0 or 1.
    """
    # Opening in 'w' mode creates the files or truncates existing ones.
    content = open('logfile.dat', 'w')
    content.close()
    content = open('results.dat', 'w')
    content.close()

    if os.path.exists("settings_cvae.yml"):
        # NOTE(review): the file object passed to safe_load is never closed
        # explicitly.
        settings = yaml.safe_load(open("settings_cvae.yml", "r"))
    else:
        # NOTE(review): the message names settings.yml, but the file checked
        # above is settings_cvae.yml.
        print("Expected a file settings.yml but didn't find it.")
        return

    print('--> Acquiring data...')
    type_of_encoding = settings['data']['type_of_encoding']
    file_name_smiles = settings['data']['smiles_file']

    print('Finished acquiring data.')

    if type_of_encoding == 0:
        print('Representation: SMILES')
        _, _, _, encoding_list, encoding_alphabet, largest_molecule_len = \
            get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)

        print('--> Creating one-hot encoding...')
        data = multiple_smile_to_hot(encoding_list, largest_molecule_len,
                                     encoding_alphabet)
        print('Finished creating one-hot encoding.')

    elif type_of_encoding == 1:
        print('Representation: SELFIES')
        encoding_list, encoding_alphabet, largest_molecule_len, _, _, _ = \
            get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)
        # The alphabet derived from the dataset is replaced by this fixed,
        # ordered list of 18 symbols.
        # NOTE(review): presumably the symbol order the saved decoder was
        # trained with - confirm.
        encoding_alphabet = ['[=N]', '[Branch1]', '[#N]', '[=Ring1]', '[=C]', '[=Branch1]', '[NH1]', '[#Branch1]', '[Ring2]', '[=Branch2]', '[nop]', '[C]', '[=O]', '[N]', '[#C]', '[Ring1]', '[F]', '[O]']
        

        print('--> Creating one-hot encoding...')
        data = multiple_selfies_to_hot(encoding_list, largest_molecule_len,
                                       encoding_alphabet)
        print('Finished creating one-hot encoding.')

    else:
        print("type_of_encoding not in {0, 1}.")
        return

    # Sizes are taken from axes 1 and 2 of the one-hot data.
    len_max_molec = data.shape[1]
    len_alphabet = data.shape[2]
    len_max_mol_one_hot = len_max_molec * len_alphabet

    print(' ')
    print(f"Alphabet has {len_alphabet} letters, "
          f"largest molecule is {len_max_molec} letters.")

    data_parameters = settings['data']
    batch_size = data_parameters['batch_size']

    # The command-line z_dim / beta override the values from the settings file.
    encoder_parameter = settings['encoder']
    encoder_parameter['latent_dimension'] = z_dim
    decoder_parameter = settings['decoder']
    decoder_parameter['latent_dimension'] = z_dim
    training_parameters = settings['training']
    training_parameters['KLD_alpha'] = beta

    vae_encoder = VAEEncoder(condition_embedding_dim=training_parameters['condition_embedding_dim'],in_dimension=len_max_mol_one_hot,
                             **encoder_parameter).to(device)
    vae_decoder = VAEDecoder(**decoder_parameter,condition_embedding_dim=training_parameters['condition_embedding_dim'],
                             out_dimension=len(encoding_alphabet)).to(device)
    # NOTE(review): train_metrics is read but not used below; the epoch is
    # hard-coded on the next line.
    train_metrics = pd.read_csv(f'cvae_trials/results_cvae/results_cvae_minmax_cond_batch_n_ld_{z_dim}_beta_{beta}.csv')
    epoch_num = 137 #train_metrics.iloc[-1,1]
    print(f'THE EPOCH NUM IS {epoch_num}')
    vae_encoder.load_state_dict(torch.load(f'cvae_trials/saved_models_attention_cond_batch_n_at_least_20/{epoch_num}/E_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))
    vae_decoder.load_state_dict(torch.load(f'cvae_trials/saved_models_attention_cond_batch_n_at_least_20/{epoch_num}/D_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))
    
    print(f"VAE ENCODER LATENT DIM IS {vae_encoder.latent_dimension}")
    # Atom-feature encoder: 6 input features (the six values produced by
    # get_atom_features) embedded into 2 dimensions; its output is used as
    # the condition vector for decoding.
    input_dim = 6 # Number of features
    z_n = 2 # Latent dimension size
    encoder = VAE_Encoder_Fully_connected_unit(z_n, input_dim).to(device)
    encoder.load_state_dict(torch.load('embeddings_encoder_minmax.pt', map_location=torch.device(device)))
    scaler = joblib.load('minmaxscaler.joblib')
    df = pd.read_csv('paper_prep_files/target_mols.csv')
    scaffolds = df['scaffold']
    # `results` accumulates rows across all scaffolds and seeds.
    results = []
    seeds = [0,1,2,3,4,5,6,7,8,9]
    sorted_results_df = pd.read_csv(f'lsbo_cvae/gp_training/gp_cvae_y_sorted_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv')
    atom_ids_for_gp = pd.read_csv(f'lsbo_cvae/gp_training/gp_cvae_selected_atom_ids_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv') 
    # NOTE(review): `.head` is not called, so this prints the bound method's
    # repr rather than the first rows.
    print(atom_ids_for_gp.head)
    print(atom_ids_for_gp.shape)
    # The atom ids are taken from the second column (position 1) of the CSV.
    atom_ids_np = atom_ids_for_gp.iloc[:, 1].to_numpy()
    atom_ids_tensor = torch.tensor(atom_ids_np, dtype=torch.int).view(-1, 1).to(device)

    selected_latent_points = torch.load(f"lsbo_cvae/gp_training/gp_cvae_x_selected_ld_{z_dim}_beta_{beta}_target_property_{target_property}.pt", map_location=torch.device(device))
    #selected_atom_features_embed_all = torch.load(f"lsbo_cvae/gp_training/gp_cvae_x_atom_features_selected_ld_{z_dim}_beta_{beta}_target_property_{target_property}.pt")
    # GP inputs: latent coordinates with the atom id appended as last column.
    initial_X = torch.cat((selected_latent_points, atom_ids_tensor), dim=1)
    initial_Y = sorted_results_df['reward'].values
    #bounds = torch.tensor([[-3.0] * (z_dim), [3.0] * (z_dim)])
    #bounds = torch.tensor([[-3.0] * z_dim + [0.95] * 2, [3.0] * z_dim + [1.0] * 6])
    # Create the output folder and its per-property sub-folder.
    folder_name = 'lsbo_cvae/lsbo_results/lsbo_paper_clasmo_review_oracle_cnt_april_11/'
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
    folder_name = f'lsbo_cvae/lsbo_results/lsbo_paper_clasmo_review_oracle_cnt_april_11/{target_property}'
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)
    input_molecule_counter = 0
    for scaffold_input in scaffolds:
        original_scaffold = scaffold_input
        for seed in seeds:
            # Every seed restarts from the untouched scaffold.
            resulting_smiles = original_scaffold
            scaffold_input = original_scaffold
            scaffold_smiles_beginning = original_scaffold
            atom_ids_with_extra_bond = find_atoms_with_extra_bond(Chem.MolFromSmiles(scaffold_input))
            print(f'initial atom_ids_with_extra_bond is {atom_ids_with_extra_bond}')
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            torch.use_deterministic_algorithms(True)
            results_seed = []
            # The GP training set is rebuilt from the initial data per seed.
            train_X = torch.tensor(initial_X, dtype=torch.float64).to(device)
            print(f"train X shape is {train_X.shape}")
            train_Y = torch.tensor(initial_Y, dtype=torch.float64).view(-1, 1).to(device)
            categorical_cols = [train_X.shape[1] - 1]  # Assuming the last column is categorical
            categorical_kernel = CategoricalKernel(num_tasks=1)
            scale_cat_kernel = ScaleKernel(categorical_kernel)
            continuous_kernel = RBFKernel()
            scale_cont_kernel = ScaleKernel(continuous_kernel)
            # Combine the kernels
            product_kernel = ProductKernel(scale_cont_kernel, scale_cat_kernel)
            #standardize(train_Y)
            # The model's default covariance module is replaced by the
            # product kernel built above.
            gp_model = MixedSingleTaskGP(train_X, train_Y, cat_dims=categorical_cols)
            gp_model.covar_module = product_kernel
            mll = ExactMarginalLogLikelihood(gp_model.likelihood.to(device), gp_model.to(device))
            fit_gpytorch_model(mll)
            
            #bounds = torch.tensor([[-3.0] * z_dim + [1],[3.0] * z_dim + [len(atom_ids_with_extra_bond) - 1]])
            # Search box: [-5, 5] for every latent coordinate, and
            # [0, len(atom list) - 1] for the atom-position coordinate.
            continuous_bounds = [[-5.0, 5.0]] * z_dim
            #continuous_bounds = [[-6.0, 6.0]] * z_dim # previous versions

            discrete_bounds = [[0, (len(atom_ids_with_extra_bond)-1)]]
            # Transposed so that row 0 holds lower and row 1 upper bounds.
            bounds = torch.tensor(continuous_bounds + discrete_bounds).T
            print("Bounds for optimization:", bounds)
            best_similarity = 1
            oracle_name = target_property
            # Instantiate the oracle
            oracle = Oracle(name=oracle_name)
            j = 0
            #while j < 100: rejected paper version
            for j in range(0, 100):
                print(f"**** LSBO STEP {j} ****")
                # The variable is named EI but holds an upper-confidence-bound
                # acquisition function.
                EI = UpperConfidenceBound(gp_model, beta=3)
                #EI = UpperConfidenceBound(gp_model, beta=2) # previous versions

                candidate, atom_id = optimize_and_map(bounds, EI, atom_ids_with_extra_bond)
                # Embed the chosen atom's scaled features; the embedding is
                # the condition vector for decoding.
                mol = Chem.MolFromSmiles(scaffold_input)
                atom = mol.GetAtomWithIdx(atom_id)
                atom_features = get_atom_features(atom)
                atom_features = list(atom_features.values())
                feature_list = [atom_features]  # Wrap list in another list to make it 2D
                scaled_features = scaler.transform(feature_list)
                scaled_features = torch.tensor(scaled_features[0], dtype=torch.float).to(device)
                atom_features_embedded, _, _ = encoder(scaled_features.unsqueeze(0))
                molecule_pre = ''
                # Decode the latent part of the candidate (all columns but
                # the last) with k=1.
                gathered_atoms = sample_latent_space(vae_encoder, vae_decoder, atom_features_embedded, candidate[:, :-1], len_max_molec, 1)
                for i in gathered_atoms:
                    molecule_pre += encoding_alphabet[i]
                molecule = molecule_pre.replace(' ', '')

                if type_of_encoding == 1:  # if SELFIES, decode to SMILES
                    molecule = sf.decoder(molecule)

                # Reward assignment (the local name `eval` shadows the
                # builtin inside this function):
                #   invalid fragment                       -> -10
                #   attaching the fragment changed nothing -> -7.5
                #   similarity outside [0.5, 1)            -> -5
                #   otherwise: property(new) - property(current molecule)
                if is_correct_smiles(molecule):
                    resulting_smiles = add_substructure(scaffold_input, atom_id, molecule)
                    #resulting_smiles = replace_substructure(scaffold_input, atom_id, molecule)
                    if resulting_smiles != scaffold_input:
                        print('anan')
                        # Similarity is measured against the scaffold this
                        # seed started from; `flag` is not used afterwards.
                        flag, similarity = eval_similarity(resulting_smiles, scaffold_smiles_beginning, 0.7)
                        print(f'similarity: {similarity}')
                        mwr = calculate_molecular_weight(resulting_smiles)
                        mwo = calculate_molecular_weight(scaffold_smiles_beginning)
                        print(f'mol weight of original is {mwo}, mol weight of updated is {mwr}')
                        mw_delta = (mwr - mwo)/mwo
                        print(f'perc change in mol weight is {mw_delta}')
                        if similarity >= 0.5 and similarity < 1: #rejected paper version
                        #if similarity >= 0.5 and similarity < 1:
                            if target_property == "logp":
                                eval = Descriptors.MolLogP(Chem.MolFromSmiles(resulting_smiles)) - Descriptors.MolLogP(Chem.MolFromSmiles(scaffold_input))
                            elif target_property == "qed":
                                eval = QED.qed(Chem.MolFromSmiles(resulting_smiles)) - QED.qed(Chem.MolFromSmiles(scaffold_input))
                            else:
                                eval = oracle(resulting_smiles) - oracle(scaffold_input)
                            print(f'reward is {eval} AT step {j}')
                            # Incrementing j does not skip a loop iteration
                            # (the for statement reassigns it); it only makes
                            # the recorded 'opt_step' one larger than the
                            # step number printed above.
                            j = j+1
                            results.append({'Atom_ID': atom_id, 'opt_step':j, 'Input_SMILES': scaffold_input,'Input_SMILES_weight':mwo, 'Resulting_SMILES': resulting_smiles,'Resulting_SMILES_weight':mwr, 'reward': eval, 'similarity':similarity,  'mw_delta': mw_delta})
                            results_seed.append({'Atom_ID': atom_id, 'opt_step':j, 'Input_SMILES': scaffold_input,'Input_SMILES_weight':mwo, 'Resulting_SMILES': resulting_smiles,'Resulting_SMILES_weight':mwr, 'reward': eval, 'similarity':similarity, 'mw_delta': mw_delta})
                            if eval > 0:
                                # Improvement: continue from the new molecule
                                # and recompute the atom list and bounds.
                                best_similarity = similarity
                                scaffold_input = resulting_smiles
                                atom_ids_with_extra_bond = find_atoms_with_extra_bond(Chem.MolFromSmiles(scaffold_input))
                                discrete_bounds = [[0, (len(atom_ids_with_extra_bond)-1)]]
                                bounds = torch.tensor(continuous_bounds + discrete_bounds).T
                                print(f'****** SIMILARITY IS: {similarity} ***********')
                        else:
                            eval = -5
                    else:
                        eval = -7.5
                else:
                    eval = -10
                # Add the new observation and refit the surrogate.
                # NOTE(review): the refit uses SingleTaskGP with its default
                # kernel, whereas the initial model was a MixedSingleTaskGP
                # with the product kernel - confirm this is intended.
                train_X = torch.cat([train_X, candidate.to(device)])
                train_Y = torch.cat([train_Y, torch.tensor([[eval]]).to(device)])
                gp_model = SingleTaskGP(train_X, train_Y)
                mll = ExactMarginalLogLikelihood(gp_model.likelihood, gp_model)
                fit_gpytorch_model(mll) 
                # Both CSVs are rewritten after every step: one with all rows
                # collected so far, one with the rows of the current seed.
                results_df = pd.DataFrame(results)
                # rejected paper folder lsbo_paper_clasmo_review_oracle_cnt_dec24
                results_df.to_csv(f"lsbo_cvae/lsbo_results/lsbo_paper_clasmo_review_oracle_cnt_april_11/{target_property}/lsbo_cvae_y_ld_{z_dim}_beta_{beta}_target_property_{target_property}_mol_cnt_{input_molecule_counter}_all_seeds_scaffold.csv")
                results_df = pd.DataFrame(results_seed)
                results_df.to_csv(f"lsbo_cvae/lsbo_results/lsbo_paper_clasmo_review_oracle_cnt_april_11/{target_property}/lsbo_cvae_y_ld_{z_dim}_beta_{beta}_target_property_{target_property}_mol_cnt_{input_molecule_counter}_seed_{seed}_scaffold.csv")
                # With the break commented out, this condition only prints a
                # message; the loop always runs all 100 steps.
                if resulting_smiles != scaffold_input and best_similarity < 0.67:
                    print('tolerance threshold exceeded, procideire complete')
                    #break
        
        input_molecule_counter+=1

if __name__ == '__main__':
    # Script entry point. An AttributeError raised anywhere in main() is
    # reduced to its message on stdout; every other exception propagates.
    # NOTE(review): this hides the traceback of AttributeErrors - confirm
    # that is intended.
    try:
        main()
    except AttributeError as error_message:
        print(error_message)
