import streamlit as st
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, QED, Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import rdMolDescriptors
from rdkit.DataStructs import DiceSimilarity
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import joblib
import yaml
import math
import random
import selfies as sf
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import io
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from gpytorch.kernels import ScaleKernel, RBFKernel, ProductKernel
from botorch.models.kernels.categorical import CategoricalKernel
import warnings
import os
from data_loader import \
    multiple_selfies_to_hot, multiple_smile_to_hot
warnings.filterwarnings('ignore')

# Keep the OpenMP / MKL thread pools to a single thread.
# NOTE(review): torch and numpy are imported above this point, so pools they
# have already initialised may ignore these variables — confirm.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})

# Hyper-parameters that select which pretrained checkpoints / GP data files
# are loaded in main() (they are interpolated into the file names).
beta = 0.000001
z_dim = 2
target_property = "qed"
# Network definitions. Module attribute names must stay in sync with the checkpoints loaded via load_state_dict() in main().
class VAE_Encoder_Fully_connected_unit(nn.Module):
    """Tiny VAE encoder that maps a feature vector to a ``z_n``-dim latent sample.

    In main() it embeds the scaled atom-feature vector (``input_dim`` features)
    into the condition vector fed to the CVAE.

    forward(x) returns ``(z, mean, log_var)`` where ``z`` is a reparameterised
    sample from N(mean, exp(log_var)).
    """

    def __init__(self, z_n, input_dim):
        super().__init__()
        hidden_width = 4
        # Layer order fixes the Sequential indices and therefore the checkpoint
        # keys (the Linear layer is stored as ``fc.2``); do not reorder.
        # NOTE(review): ReLU is applied to the raw input before the Linear
        # layer, which clips negative features to 0 — presumably how the
        # checkpoint was trained.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(input_dim, hidden_width),
            nn.Sigmoid(),
        )
        self.fc_mean = nn.Linear(hidden_width, z_n)
        self.fc_log_var = nn.Linear(hidden_width, z_n)

    def forward(self, x):
        hidden = self.fc(x)
        mean, log_var = self.fc_mean(hidden), self.fc_log_var(hidden)
        return self.reparameterize(mean, log_var), mean, log_var

    @staticmethod
    def reparameterize(mean, log_var):
        """Draw ``mean + eps * exp(log_var / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        return mean + torch.randn_like(std) * std

class ConditionalBatchNorm1d(nn.Module):
    """Batch norm whose scale and shift are predicted from a condition vector.

    The inner ``BatchNorm1d`` has no learned affine parameters; instead two
    Linear layers map ``condition`` to a per-sample scale (``gamma_embed``) and
    shift (``beta_embed``) that are applied to the normalised activations.
    """

    def __init__(self, num_features, num_conditions):
        super(ConditionalBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        self.gamma_embed = nn.Linear(num_conditions, num_features)
        self.beta_embed = nn.Linear(num_conditions, num_features)

    def forward(self, x, condition):
        normalized = self.bn(x)
        # Flatten the predicted scale/shift to (-1, num_features) so they
        # broadcast against the normalised activations.
        scale = self.gamma_embed(condition).view(-1, self.num_features)
        shift = self.beta_embed(condition).view(-1, self.num_features)
        return scale * normalized + shift

class VAEEncoder(nn.Module):
    """Conditional VAE encoder: three Linear + conditional-BN + ReLU stages.

    The condition is concatenated to the flattened input and is also used to
    modulate every batch-norm layer. forward(x, condition) returns
    ``(z, mu, log_var)``.
    """

    def __init__(self, condition_embedding_dim, in_dimension, layer_1d, layer_2d, layer_3d, latent_dimension):
        super(VAEEncoder, self).__init__()
        self.latent_dimension = latent_dimension
        cond_dim = condition_embedding_dim

        # Attribute names are the checkpoint keys; keep them unchanged.
        self.fc1 = nn.Linear(in_dimension + cond_dim, layer_1d)
        self.cbn1 = ConditionalBatchNorm1d(layer_1d, cond_dim)
        self.fc2 = nn.Linear(layer_1d, layer_2d)
        self.cbn2 = ConditionalBatchNorm1d(layer_2d, cond_dim)
        self.fc3 = nn.Linear(layer_2d, layer_3d)
        self.cbn3 = ConditionalBatchNorm1d(layer_3d, cond_dim)

        self.encode_mu = nn.Linear(layer_3d, latent_dimension)
        self.encode_log_var = nn.Linear(layer_3d, latent_dimension)

    @staticmethod
    def reparameterize(mu, log_var):
        """Draw ``mu + eps * exp(log_var / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

    def forward(self, x, condition):
        # One condition row per input row.
        condition = condition.view(x.size(0), -1)
        h = torch.cat((x, condition), dim=1)
        stages = ((self.fc1, self.cbn1), (self.fc2, self.cbn2), (self.fc3, self.cbn3))
        for linear, norm in stages:
            h = torch.relu(norm(linear(h), condition))
        mu, log_var = self.encode_mu(h), self.encode_log_var(h)
        return self.reparameterize(mu, log_var), mu, log_var

class VAEDecoder(nn.Module):
    """Conditional GRU decoder producing per-step token logits.

    The GRU (``batch_first=False``, so inputs are (seq, batch, feature)) reads
    the latent code concatenated with the condition embedding; its output is
    passed through a conditional batch norm and a Linear layer to
    ``out_dimension`` logits.
    """

    def __init__(self, latent_dimension, gru_stack_size, gru_neurons_num, condition_embedding_dim, out_dimension):
        super(VAEDecoder, self).__init__()
        # GRU input width = latent code + condition embedding.
        self.latent_dimension = latent_dimension + condition_embedding_dim
        self.gru_stack_size = gru_stack_size
        self.gru_neurons_num = gru_neurons_num

        # Attribute names are the checkpoint keys; keep them unchanged.
        self.decode_RNN = nn.GRU(
            input_size=self.latent_dimension,
            hidden_size=gru_neurons_num,
            num_layers=gru_stack_size,
            batch_first=False,
        )
        self.decode_cbn = ConditionalBatchNorm1d(gru_neurons_num, condition_embedding_dim)
        self.decode_FC = nn.Linear(gru_neurons_num, out_dimension)

    def init_hidden(self, batch_size=1):
        """Return an all-zero GRU hidden state matching the module's dtype/device."""
        reference = next(self.parameters())
        return reference.new_zeros((self.gru_stack_size, batch_size, self.gru_neurons_num))

    def forward(self, z, hidden, condition):
        rnn_out, hidden = self.decode_RNN(z, hidden)
        # One condition row per batch element, repeated along the sequence axis.
        per_sample = condition.view(z.size(1), -1)
        cond_seq = per_sample.unsqueeze(0).expand(rnn_out.size(0), -1, -1)
        logits = self.decode_FC(self.decode_cbn(rnn_out, cond_seq))
        return logits, hidden

def is_correct_smiles(smiles):
    """Return True if RDKit can parse and sanitize *smiles*.

    The empty string, unparsable input, and any input that makes RDKit raise
    are all reported as False.
    """
    if smiles == "":
        return False
    try:
        parsed = Chem.MolFromSmiles(smiles, sanitize=True)
    except Exception:
        return False
    return parsed is not None

def sample_latent_space(vae_encoder, vae_decoder, condition_vector, candidate, sample_len):
    """Greedily decode ``sample_len`` token indices from a latent point.

    The decoder is fed the same input (latent point concatenated with the
    condition) at every step; only the GRU hidden state evolves. At each step
    the argmax token of the first batch copy is recorded.

    Args:
        vae_encoder: CVAE encoder; only toggled between eval/train mode here.
        vae_decoder: CVAE decoder (VAEDecoder).
        condition_vector: condition embedding; as called from main() a
            (1, condition_dim) tensor.
        candidate: latent point; as called from main() a (1, latent_dim) tensor.
        sample_len: number of tokens to decode.

    Returns:
        list[int]: chosen token index for each step.
    """
    # The decoder is run on `batch_size` identical copies and only the first
    # copy's logits are used.
    # NOTE(review): ConditionalBatchNorm1d receives the (seq, batch, hidden)
    # GRU output, and BatchNorm1d treats dim 1 as channels, so this batch size
    # appears to have to equal gru_neurons_num — confirm before changing it.
    batch_size = 100

    vae_encoder.eval()
    vae_decoder.eval()
    gathered_atoms = []
    try:
        # Inference only: avoid building an autograd graph through every step.
        with torch.no_grad():
            combined_input = torch.cat(
                (candidate.to(device), condition_vector.to(device)), dim=1
            ).unsqueeze(0)
            # Loop-invariant inputs, built once.
            combined_input_rpt = combined_input.repeat(1, batch_size, 1)
            condition_vector_rpt = condition_vector.repeat(batch_size, 1)
            hidden = vae_decoder.init_hidden(batch_size=batch_size)
            softmax = nn.Softmax(dim=0)

            for _ in range(sample_len):
                out_one_hot, hidden = vae_decoder(combined_input_rpt, hidden, condition_vector_rpt)
                # out_one_hot is (1, batch_size, out_dim); take the first copy.
                logits = out_one_hot[0, 0]
                gathered_atoms.append(torch.argmax(softmax(logits)).item())
    finally:
        # Restore training mode even if decoding fails.
        vae_encoder.train()
        vae_decoder.train()

    return gathered_atoms

def get_atom_features(atom):
    """Return the six scalar descriptors of an RDKit atom as an ordered dict.

    Key order matters: main() feeds ``list(features.values())`` straight into
    the fitted scaler, so the order must match the one used at fit time.
    """
    return dict([
        ('AtomicNum', atom.GetAtomicNum()),
        ('Hybridization', int(atom.GetHybridization())),
        ('Valence', atom.GetTotalValence()),
        ('FormalCharge', atom.GetFormalCharge()),
        ('Degree', atom.GetDegree()),
        ('IsInRing', int(atom.IsInRing())),
    ])

def add_substructure(input_smiles, atom_index, substructure_smiles):
    """Attach a fragment to atom *atom_index* of a molecule with a single bond.

    The bond joins *atom_index* of the input molecule to the first atom of the
    fragment. Returns the canonical SMILES of the sanitized result, or
    *input_smiles* unchanged if anything fails (unparsable input, bad index,
    sanitization error). This is deliberate best-effort: callers detect
    failure by comparing the result with the input.
    """
    try:
        base = Chem.MolFromSmiles(input_smiles)
        if base is None:
            raise ValueError("Invalid input SMILES")

        fragment = Chem.MolFromSmiles(substructure_smiles)
        if fragment is None:
            raise ValueError("Invalid substructure SMILES")

        # CombineMols appends the fragment's atoms after the base's, so the
        # fragment's first atom lands at index base.GetNumAtoms().
        merged = Chem.RWMol(Chem.CombineMols(Chem.RWMol(base), fragment))
        merged.AddBond(atom_index, base.GetNumAtoms(), Chem.BondType.SINGLE)
        Chem.SanitizeMol(merged)
        return Chem.MolToSmiles(merged)
    except Exception:
        return input_smiles

def find_atoms_with_extra_bond(mol):
    """Return, in atom order, the indices of atoms for which can_have_extra_bond() is true."""
    return [
        candidate.GetIdx()
        for candidate in mol.GetAtoms()
        if can_have_extra_bond(mol, candidate.GetIdx())
    ]

def can_have_extra_bond(mol, atom_idx):
    """Return True if a new single bond can be attached to atom *atom_idx*.

    add_substructure() adds a single bond and then re-sanitizes without
    touching explicit hydrogen counts, so the attachment only succeeds if the
    atom has at least one implicit hydrogen to give up.

    The previous test, ``degree < total_valence``, was also true for H-free
    atoms that carry a multiple bond (e.g. a carbonyl O: degree 1, valence 2).
    For those atoms the later sanitization always failed.
    """
    atom = mol.GetAtomWithIdx(atom_idx)
    return atom.GetNumImplicitHs() > 0

def optimize_and_map(bounds, UCB_acq, atom_ids_with_extra_bond, num_restarts=10, raw_samples=500):
    """Maximise the acquisition function and map the last coordinate to an atom id.

    The last column of the candidate is a continuous relaxation of an index
    into *atom_ids_with_extra_bond* (bounded by ``[0, n-1]`` in main()). It is
    rounded to the nearest valid index and written back into the candidate, so
    the GP is trained on the index that was actually used. (Flooring made the
    last index reachable only at the exact upper bound.)

    Args:
        bounds: 2 x d tensor of lower/upper bounds.
        UCB_acq: BoTorch acquisition function.
        atom_ids_with_extra_bond: candidate atom ids.
        num_restarts, raw_samples: forwarded to ``optimize_acqf``.

    Returns:
        (candidate, atom_id): the (1, d) candidate tensor (last column
        snapped to an integer) and the selected atom id.
    """
    candidate, _ = optimize_acqf(
        acq_function=UCB_acq,
        bounds=bounds,
        q=1,
        num_restarts=num_restarts,
        raw_samples=raw_samples,
    )
    last_index = len(atom_ids_with_extra_bond) - 1
    atom_id_index = min(max(int(round(candidate[0, -1].item())), 0), last_index)
    candidate[0, -1] = float(atom_id_index)
    atom_id = atom_ids_with_extra_bond[atom_id_index]
    return candidate, atom_id

def eval_similarity(original_smiles, modified_smiles, similarity_threshold=0.25):
    """Compare two molecules by Dice similarity of their Morgan fingerprints.

    Used as an acceptance filter during LSBO: a modified molecule is only
    scored if it is still similar enough to the molecule it came from.

    Args:
        original_smiles (str): SMILES of the reference molecule.
        modified_smiles (str): SMILES of the modified molecule.
        similarity_threshold (float): Minimum Dice similarity (inclusive) for
            the modified molecule to be accepted.

    Returns:
        tuple[bool, float]: ``(accepted, similarity)``, where ``accepted`` is
        True when ``similarity >= similarity_threshold``.

    Raises:
        ValueError: If either SMILES string cannot be parsed by RDKit.
    """
    original_mol = Chem.MolFromSmiles(original_smiles)
    modified_mol = Chem.MolFromSmiles(modified_smiles)
    if original_mol is None or modified_mol is None:
        raise ValueError("Invalid SMILES passed to eval_similarity")

    # Radius-1 Morgan fingerprints folded to 2048 bits.
    fp_original = AllChem.GetMorganFingerprintAsBitVect(original_mol, radius=1, nBits=2048)
    fp_modified = AllChem.GetMorganFingerprintAsBitVect(modified_mol, radius=1, nBits=2048)

    similarity = DiceSimilarity(fp_original, fp_modified)
    # Accept when the modification keeps the molecule at least this similar.
    return similarity >= similarity_threshold, similarity

def calculate_molecular_weight(smiles):
    """Return the exact molecular weight of *smiles* via ``CalcExactMolWt``.

    Note the mixed return type: a float for valid input, and the string
    ``"Invalid SMILES string"`` when RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalid SMILES string"
    return rdMolDescriptors.CalcExactMolWt(mol)

def draw_molecule(smiles):
    """Render *smiles* as a 300x300 PNG (Cairo) and return it as a PIL image.

    Raises:
        ValueError: If *smiles* cannot be parsed by RDKit. Previously None was
            passed on to ``Compute2DCoords``, which failed with an opaque error.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(mol)
    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(300, 300)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    return Image.open(io.BytesIO(drawer.GetDrawingText()))

def highlight_atoms(smiles, rectangle):
    """Find the first atom inside *rectangle* and render it highlighted.

    Atom positions come from a 300x300 drawing, the same size as the canvas in
    main(), so *rectangle* is in canvas pixel coordinates.

    Args:
        smiles (str): molecule to draw.
        rectangle: ``(x1, y1, x2, y2)`` corners; they may be given in any order.

    Returns:
        (PIL.Image.Image, int | None): the rendered image, and the index of the
        lowest-indexed atom inside the rectangle (None if there is none).

    Raises:
        ValueError: If *smiles* cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    Chem.rdDepictor.Compute2DCoords(mol)

    # First drawing is only used to learn where each atom lands on the canvas.
    probe = Draw.rdMolDraw2D.MolDraw2DCairo(300, 300)
    probe.DrawMolecule(mol)
    probe.FinishDrawing()
    atom_positions = []
    for idx in range(mol.GetNumAtoms()):
        point = probe.GetDrawCoords(idx)
        atom_positions.append((point.x, point.y))

    # Normalise the corners so a rectangle given "backwards" still matches.
    xa, ya, xb, yb = rectangle
    x1, x2 = min(xa, xb), max(xa, xb)
    y1, y2 = min(ya, yb), max(ya, yb)
    selected_atom = next(
        (i for i, (px, py) in enumerate(atom_positions) if x1 <= px <= x2 and y1 <= py <= y2),
        None,
    )

    drawer = Draw.rdMolDraw2D.MolDraw2DCairo(300, 300)
    drawer.DrawMolecule(mol, highlightAtoms=[selected_atom] if selected_atom is not None else [])
    drawer.FinishDrawing()
    img_data = io.BytesIO(drawer.GetDrawingText())

    return Image.open(img_data), selected_atom


def get_selfie_and_smiles_encodings_for_dataset(file_path):
    """
    Returns encoding, alphabet and length of largest molecule in SMILES and
    SELFIES, given a file containing SMILES molecules.

    input:
        csv file with molecules. Column's name must be 'smiles'.
    output:
        - selfies encoding
        - selfies alphabet (sorted, includes '[nop]')
        - length of the longest selfies string (in symbols)
        - smiles encoding (equivalent to file content)
        - smiles alphabet (character based, sorted, ' ' padding appended last)
        - length of the longest smiles string (in characters)

    Both alphabets are sorted so their order is identical across runs.
    Building them straight from a set made the order depend on per-process
    str hash randomisation.
    """

    df = pd.read_csv(file_path)

    smiles_list = np.asanyarray(df.smiles)

    smiles_alphabet = sorted(set(''.join(smiles_list)))
    smiles_alphabet.append(' ')  # for padding

    largest_smiles_len = len(max(smiles_list, key=len))

    print('--> Translating SMILES to SELFIES...')
    selfies_list = list(map(sf.encoder, smiles_list))

    all_selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
    all_selfies_symbols.add('[nop]')
    selfies_alphabet = sorted(all_selfies_symbols)

    largest_selfies_len = max(sf.len_selfies(s) for s in selfies_list)

    print('Finished translating SMILES to SELFIES.')

    return selfies_list, selfies_alphabet, largest_selfies_len, \
           smiles_list, smiles_alphabet, largest_smiles_len


def calculate_property(molecule, property_name):
    """Compute ``"LogP"`` (Crippen MolLogP) or ``"QED"`` for a SMILES string.

    Returns None if *molecule* cannot be parsed. Raises ValueError for any
    other *property_name* (checked only after the SMILES parses).
    """
    mol = Chem.MolFromSmiles(molecule)
    if mol is None:
        return None
    calculators = (("LogP", Descriptors.MolLogP), ("QED", QED.qed))
    for name, calculator in calculators:
        if property_name == name:
            return calculator(mol)
    raise ValueError("Invalid property name")


def _load_settings(path):
    """Parse the YAML settings file at *path*; return None if it does not exist."""
    if not os.path.exists(path):
        return None
    with open(path, "r") as handle:
        return yaml.safe_load(handle)


def _fit_mixed_gp(train_X, train_Y, cat_dims, covar_module):
    """Build a MixedSingleTaskGP, swap in *covar_module*, fit it, and return it.

    The same *covar_module* instance is passed on every refit, so its fitted
    hyper-parameters are the starting point of the next fit.
    """
    gp_model = MixedSingleTaskGP(train_X, train_Y, cat_dims=cat_dims)
    # NOTE(review): this replaces the mixed kernel built from cat_dims; the
    # RBF and categorical kernels passed in have no active_dims, so each sees
    # every column — confirm this is intended.
    gp_model.covar_module = covar_module
    mll = ExactMarginalLogLikelihood(gp_model.likelihood.to(device), gp_model.to(device))
    fit_gpytorch_model(mll)
    return gp_model


def main():
    """Streamlit entry point.

    Renders the input UI. Once the user boxes an atom that can take a new
    bond, it loads the CVAE, the atom-feature condition encoder and the GP
    warm-start data. It then runs ``max_trials`` rounds of UCB-driven
    latent-space BO: each decoded fragment is attached to the selected atom
    and scored by the improvement in the chosen property.
    """
    st.title("Molecular Design with CLaSMO and Bayesian Optimization")
    st.markdown("""
    This web application allows chemical experts to optimize molecular scaffolds by interactively selecting regions for modification. By combining Conditional Variational Autoencoders (CVAE) with Latent Space Bayesian Optimization (LSBO), the app helps users efficiently discover new molecules with desirable properties while maintaining real-world viability.
                
    To see example usage video, [click here](https://drive.google.com/file/d/1GH2zKhhiVSAhwpoLEkini3jFhs-KbOBY/view?usp=sharing).
    """)
    st.info("This application accepts the SMILES string of a molecule as input. The user first selects the property to optimize, then draws a rectangle around the atom where the substructure will be added. Selection of the region will automatically initiate the optimization process, and the results will be shown below.")

    _, center_col, _ = st.columns(3)
    with center_col:
        st.image("img.png", caption="Example selection.", width=150)

    # Initialize session state
    if 'smiles_input' not in st.session_state:
        st.session_state['smiles_input'] = "O=C(Cc1cnoc1)NC1CCCCC1"

    smiles_input = st.text_input('Enter a SMILES string:', st.session_state['smiles_input'])
    if not smiles_input:
        return

    # Validate before drawing: an unparsable SMILES used to crash draw_molecule.
    mol = Chem.MolFromSmiles(smiles_input)
    if mol is None:
        st.error("Invalid SMILES string. Please enter a valid molecule.")
        return
    mol_image = draw_molecule(smiles_input)

    property_to_optimize = st.selectbox("Select property to optimize:", ["QED", "LogP"])
    input_smiles_property_value = calculate_property(smiles_input, property_to_optimize)
    st.subheader(f"{property_to_optimize} Value of Input Molecule: {input_smiles_property_value:.2f}")

    st.write("Draw a rectangle on the molecule to select an atom. To remove the selection, click the trash bin icon.")
    st.info("CLaSMO will run for only 10 iterations.")
    if 'canvas_key' not in st.session_state:
        st.session_state['canvas_key'] = 0  # Bumped to force canvas reinitialization

    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#e00",
        background_image=mol_image,
        update_streamlit=True,
        width=300,
        height=300,
        drawing_mode="rect",
        key=f"canvas_{st.session_state['canvas_key']}",
    )

    json_data = canvas_result.json_data
    if json_data is None or "objects" not in json_data:
        return
    rectangles = [obj for obj in json_data["objects"] if obj["type"] == "rect"]
    if not rectangles:
        return

    rect = rectangles[0]
    x1, y1 = rect["left"], rect["top"]
    x2, y2 = rect["left"] + rect["width"], rect["top"] + rect["height"]
    highlighted_image, selected_atom_index = highlight_atoms(smiles_input, (x1, y1, x2, y2))

    col1, col2 = st.columns(2)
    with col1:
        st.image(highlighted_image, caption='Highlighted Molecule', width=300)

    if selected_atom_index is None:
        return
    st.write(f"Selected Atom Index: {selected_atom_index}")

    # Reject unusable selections before paying for model loading and GP fitting.
    if not can_have_extra_bond(mol, selected_atom_index):
        st.warning("Selected atom cannot form additional bonds.")
        return

    settings = _load_settings("settings_cvae.yml")
    if settings is None:
        st.error("Expected a file settings_cvae.yml but didn't find it.")
        return

    print('--> Acquiring data...')
    file_name_smiles = settings['data']['smiles_file']
    print('Finished acquiring data.')

    print('Representation: SELFIES')
    encoding_list, _, largest_molecule_len, _, _, _ = \
        get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)
    # Fixed token order for the decoder (out_dimension=len(encoding_alphabet));
    # presumably the order the checkpoints were trained with.
    encoding_alphabet = ['[=N]', '[Branch1]', '[#N]', '[=Ring1]', '[=C]', '[=Branch1]', '[NH1]', '[#Branch1]', '[Ring2]', '[=Branch2]', '[nop]', '[C]', '[=O]', '[N]', '[#C]', '[Ring1]', '[F]', '[O]']

    print('--> Creating one-hot encoding...')
    data = multiple_selfies_to_hot(encoding_list, largest_molecule_len, encoding_alphabet)
    print('Finished creating one-hot encoding.')

    len_max_molec = data.shape[1]
    len_alphabet = data.shape[2]
    len_max_mol_one_hot = len_max_molec * len_alphabet

    print(' ')
    print(f"Alphabet has {len_alphabet} letters, "
          f"largest molecule is {len_max_molec} letters.")

    encoder_parameter = settings['encoder']
    encoder_parameter['latent_dimension'] = z_dim
    decoder_parameter = settings['decoder']
    decoder_parameter['latent_dimension'] = z_dim
    condition_dim = settings['training']['condition_embedding_dim']

    ### LOAD MODEL ###
    vae_encoder = VAEEncoder(condition_embedding_dim=condition_dim, in_dimension=len_max_mol_one_hot,
                             **encoder_parameter).to(device)
    vae_decoder = VAEDecoder(**decoder_parameter, condition_embedding_dim=condition_dim,
                             out_dimension=len(encoding_alphabet)).to(device)
    vae_encoder.load_state_dict(torch.load(f'clasmo_inputs/E_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))
    vae_decoder.load_state_dict(torch.load(f'clasmo_inputs/D_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))

    ### LOAD CONDITION EMBEDDING MODEL ###
    input_dim = 6  # number of atom features produced by get_atom_features()
    z_n = 2  # latent size of the condition embedding
    encoder = VAE_Encoder_Fully_connected_unit(z_n, input_dim).to(device)
    encoder.load_state_dict(torch.load('clasmo_inputs/embeddings_encoder.pt', map_location=torch.device(device)))
    scaler = joblib.load('clasmo_inputs/minmaxscaler.joblib')

    ### LOAD GP TRAINING DATA ###
    # NOTE(review): the warm-start data is always the `target_property` set,
    # even when a different property is selected above — confirm.
    gp_suffix = f"ld_{z_dim}_beta_{beta}_target_property_{target_property}"
    sorted_results_df = pd.read_csv(f'clasmo_inputs/gp_cvae_y_sorted_{gp_suffix}.csv')
    atom_ids_for_gp = pd.read_csv(f'clasmo_inputs/gp_cvae_selected_atom_ids_{gp_suffix}.csv')
    atom_ids_np = atom_ids_for_gp.iloc[:, 1].to_numpy()
    atom_ids_tensor = torch.tensor(atom_ids_np, dtype=torch.int).view(-1, 1).to(device)
    selected_latent_points = torch.load(f"clasmo_inputs/gp_cvae_x_selected_{gp_suffix}.pt", map_location=torch.device(device))
    initial_X = torch.cat((selected_latent_points, atom_ids_tensor), dim=1)
    initial_Y = sorted_results_df['reward'].values

    # Initialize GP model; the last column of X is the categorical atom index.
    train_X = initial_X.to(device=device, dtype=torch.float64)
    train_Y = torch.tensor(initial_Y, dtype=torch.float64).view(-1, 1).to(device)
    categorical_cols = [train_X.shape[1] - 1]
    scale_cat_kernel = ScaleKernel(CategoricalKernel(num_tasks=1))
    scale_cont_kernel = ScaleKernel(RBFKernel())
    product_kernel = ProductKernel(scale_cont_kernel, scale_cat_kernel)
    gp_model = _fit_mixed_gp(train_X, train_Y, categorical_cols, product_kernel)

    # Define bounds
    continuous_bounds = [[-6.0, 6.0]] * z_dim
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    torch.use_deterministic_algorithms(True)
    # Only the atom the user selected is a candidate attachment point.
    atom_ids_with_extra_bond = [selected_atom_index]
    discrete_bounds = [[0, (len(atom_ids_with_extra_bond) - 1)]]
    bounds = torch.tensor(continuous_bounds + discrete_bounds).T

    # Start Bayesian Optimization
    max_trials = 10
    best_smiles = None
    best_property_value = input_smiles_property_value
    iteration_results = []

    for trial in range(1, max_trials + 1):
        UCB_acq = UpperConfidenceBound(gp_model, beta=2.5)
        candidate, atom_id = optimize_and_map(bounds, UCB_acq, atom_ids_with_extra_bond)
        candidate_latent = candidate[:, :-1]

        # Embed the selected atom's features as the CVAE condition vector.
        # NOTE(review): this uses the sampled z (not the mean), so the
        # condition is stochastic; it is reproducible only via the seeds above.
        atom_features = list(get_atom_features(mol.GetAtomWithIdx(atom_id)).values())
        scaled_features = scaler.transform([atom_features])
        scaled_features = torch.tensor(scaled_features[0], dtype=torch.float).to(device)
        with torch.no_grad():
            atom_features_embedded, _, _ = encoder(scaled_features.unsqueeze(0))

        # Generate substructure
        gathered_atoms = sample_latent_space(
            vae_encoder, vae_decoder, atom_features_embedded, candidate_latent, len_max_molec
        )
        molecule = ''.join(encoding_alphabet[i] for i in gathered_atoms).replace(' ', '')
        molecule = sf.decoder(molecule)

        # Reward: penalties for failures, otherwise improvement over the best so far.
        if not is_correct_smiles(molecule):
            eval_reward = -10
        else:
            modified_smiles = add_substructure(smiles_input, atom_id, molecule)
            if not is_correct_smiles(modified_smiles) or modified_smiles == smiles_input:
                eval_reward = -7.5
            else:
                similarity_check, _ = eval_similarity(smiles_input, modified_smiles, similarity_threshold=0.25)
                if not similarity_check:
                    eval_reward = -5
                else:
                    modified_property_value = calculate_property(modified_smiles, property_to_optimize)
                    eval_reward = modified_property_value - best_property_value
                    if eval_reward > 0:
                        # Record the iteration and the new best molecule.
                        iteration_results.append({
                            'Iteration': trial,
                            'Increase in Property': eval_reward,
                            'Modified SMILES': modified_smiles,
                            f'Modified SMILES {property_to_optimize}': modified_property_value,
                        })
                        best_smiles = modified_smiles
                        best_property_value = modified_property_value

        # Update GP model (explicit dtype/device instead of implicit promotion).
        train_X = torch.cat([train_X, candidate.to(train_X)])
        new_y = torch.tensor([[eval_reward]], dtype=train_Y.dtype, device=train_Y.device)
        train_Y = torch.cat([train_Y, new_y])
        gp_model = _fit_mixed_gp(train_X, train_Y, categorical_cols, product_kernel)

    # After optimization loop
    if best_smiles:
        with col2:
            st.image(draw_molecule(best_smiles), caption='Optimized Molecule', width=300)
        st.subheader(f"Optimized {property_to_optimize} Value: {best_property_value:.2f}")
        st.session_state['optimized_smiles'] = best_smiles
        # Display the table of results
        results_df = pd.DataFrame(iteration_results)
        st.write("**Iteration Results:**")
        st.table(results_df)
    else:
        st.warning("Could not find a better molecule after optimization.")
    if 'optimized_smiles' in st.session_state:
        if st.button('Use Optimized Molecule for Further Optimization'):
            # Update the default SMILES in the text input field and reset the canvas.
            st.session_state['smiles_input'] = st.session_state['optimized_smiles']
            st.session_state['canvas_key'] += 1
            st.experimental_rerun()  # Rerun the app with updated smiles_input

if __name__ == "__main__":
    # The app runs entirely on CPU. `device` is a module-level global read
    # by sample_latent_space() and main(), so it must be bound before main().
    device = torch.device(type="cpu")
    main()