#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CLaSMO optimization code for QED task.

This script loads the trained CVAE model and optimizes the input scaffolds
listed in clasmo_input_data.csv using the proposed CLaSMO approach; accepted
improvements are written to clasmo_results_new_run.csv.

"""

import os
import sys
import joblib
import numpy as np
import pandas as pd
import torch
import yaml
from rdkit import rdBase
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import Draw
from rdkit import Chem
from torch import nn
from rdkit.Chem import Descriptors, AllChem
from rdkit.Chem import QED
from rdkit.Chem.Scaffolds import MurckoScaffold
import math
import random
import selfies as sf
from data_loader import \
    multiple_selfies_to_hot, multiple_smile_to_hot
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from gpytorch.kernels import ScaleKernel, RBFKernel, ProductKernel
from botorch.models.kernels.categorical import CategoricalKernel
import os
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import AllChem
from rdkit.DataStructs import DiceSimilarity

import warnings

# Suppress all warnings (this hides every Python warning raised in the
# process, including library deprecation warnings).
warnings.filterwarnings('ignore')

# Limit OpenMP / MKL thread pools to a single thread.
# NOTE(review): these variables are usually read when numpy/torch initialise
# their thread pools; numpy and torch are imported above, so setting them
# here may have no effect — confirm, or move these lines above the imports.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Assigned to training_parameters['KLD_alpha'] in main() and embedded in the
# checkpoint / GP-data file names loaded from clasmo_inputs/.
beta = 0.000001
# CVAE latent dimension (encoder/decoder 'latent_dimension'); also embedded in
# the checkpoint / GP-data file names.
z_dim = 2
# Property being optimised; embedded in the GP-data file names.
target_property = "qed"

# Silence RDKit error logging (e.g. messages from parsing invalid SMILES).
rdBase.DisableLog('rdApp.error')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the VAE encoder and decoder as provided
class VAE_Encoder_Fully_connected_unit(nn.Module):
    """Tiny fully connected VAE encoder (used here to embed atom features).

    The input is flattened and passed through ReLU -> Linear(input_dim, 4) ->
    Sigmoid, then projected to a ``z_n``-dimensional mean and log-variance from
    which a latent sample is drawn.
    """

    def __init__(self, z_n, input_dim):
        super(VAE_Encoder_Fully_connected_unit, self).__init__()
        # The module order inside ``fc`` is part of the checkpoint layout
        # (the Linear layer is ``fc.2``), so it must not change.
        trunk = [nn.Flatten(), nn.ReLU(), nn.Linear(input_dim, 4), nn.Sigmoid()]
        self.fc = nn.Sequential(*trunk)
        self.fc_mean = nn.Linear(4, z_n)
        self.fc_log_var = nn.Linear(4, z_n)

    def forward(self, x):
        """Return ``(z, mean, log_var)`` for the input batch ``x``."""
        features = self.fc(x)
        mean = self.fc_mean(features)
        log_var = self.fc_log_var(features)
        return self.reparameterize(mean, log_var), mean, log_var

    @staticmethod
    def reparameterize(mean, log_var):
        """Sample ``mean + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mean + sigma * noise

class ConditionalBatchNorm1d(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition.

    ``x`` is normalised by a ``BatchNorm1d`` without learned affine parameters;
    the per-feature scale (``gamma``) and shift (``beta``) are linear functions
    of ``condition`` and are applied as ``gamma * normalized + beta``.
    """

    def __init__(self, num_features, num_conditions):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features, affine=False)
        self.gamma_embed = nn.Linear(num_conditions, num_features)
        self.beta_embed = nn.Linear(num_conditions, num_features)

    def forward(self, x, condition):
        normalized = self.bn(x)
        # Flatten the predicted scale/shift to (-1, num_features) so they
        # broadcast against the normalised activations.
        scale = self.gamma_embed(condition).view(-1, self.num_features)
        shift = self.beta_embed(condition).view(-1, self.num_features)
        return scale * normalized + shift


class VAEEncoder(nn.Module):
    """Conditional VAE encoder.

    The flattened one-hot input is concatenated with the condition embedding
    and passed through three ``Linear -> ConditionalBatchNorm1d -> ReLU``
    stages (each normalisation conditioned on the same vector), followed by
    linear heads for the latent mean and log-variance.
    """

    def __init__(self, condition_embedding_dim, in_dimension, layer_1d, layer_2d, layer_3d, latent_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension

        # Attribute names are part of the checkpoint layout; keep them.
        self.fc1 = nn.Linear(in_dimension + condition_embedding_dim, layer_1d)
        self.cbn1 = ConditionalBatchNorm1d(layer_1d, condition_embedding_dim)

        self.fc2 = nn.Linear(layer_1d, layer_2d)
        self.cbn2 = ConditionalBatchNorm1d(layer_2d, condition_embedding_dim)

        self.fc3 = nn.Linear(layer_2d, layer_3d)
        self.cbn3 = ConditionalBatchNorm1d(layer_3d, condition_embedding_dim)

        self.encode_mu = nn.Linear(layer_3d, latent_dimension)
        self.encode_log_var = nn.Linear(layer_3d, latent_dimension)

    @staticmethod
    def reparameterize(mu, log_var):
        """Sample ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x, condition):
        """Return ``(z, mu, log_var)`` for input ``x`` under ``condition``."""
        condition = condition.view(x.size(0), -1)
        hidden = torch.cat((x, condition), dim=1)

        stages = ((self.fc1, self.cbn1), (self.fc2, self.cbn2), (self.fc3, self.cbn3))
        for linear, cond_norm in stages:
            hidden = nn.functional.relu(cond_norm(linear(hidden), condition))

        mu = self.encode_mu(hidden)
        log_var = self.encode_log_var(hidden)
        return self.reparameterize(mu, log_var), mu, log_var


class VAEDecoder(nn.Module):
    """GRU decoder conditioned by input concatenation and conditional BN.

    The GRU consumes ``[latent vector, condition embedding]`` (hence its input
    size is ``latent_dimension + condition_embedding_dim``); its outputs are
    passed through ``ConditionalBatchNorm1d`` and a linear layer producing
    ``out_dimension`` logits.
    """

    def __init__(self, latent_dimension, gru_stack_size, gru_neurons_num, condition_embedding_dim, out_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension + condition_embedding_dim
        self.gru_stack_size = gru_stack_size
        self.gru_neurons_num = gru_neurons_num

        # Attribute names are part of the checkpoint layout; keep them.
        self.decode_RNN = nn.GRU(
            input_size=self.latent_dimension,
            hidden_size=gru_neurons_num,
            num_layers=gru_stack_size,
            batch_first=False,
        )
        self.decode_cbn = ConditionalBatchNorm1d(gru_neurons_num, condition_embedding_dim)
        self.decode_FC = nn.Linear(gru_neurons_num, out_dimension)

    def init_hidden(self, batch_size=1):
        """Return an all-zero GRU hidden state with the module's dtype/device."""
        template = next(self.parameters())
        return template.new_zeros(self.gru_stack_size, batch_size, self.gru_neurons_num)

    def forward(self, z, hidden, condition):
        """Run one GRU pass; return ``(logits, new_hidden)``."""
        rnn_out, hidden = self.decode_RNN(z, hidden)

        # One condition row per batch element (dim 1 of the seq-first input),
        # broadcast over every sequence position of the GRU output.
        batch_condition = condition.view(z.size(1), -1)
        per_step_condition = batch_condition.unsqueeze(0).expand(rnn_out.size(0), -1, -1)
        normalized = self.decode_cbn(rnn_out, per_step_condition)

        return self.decode_FC(normalized), hidden


def is_correct_smiles(smiles):
    """Return True if RDKit parses and sanitizes ``smiles`` successfully.

    The empty string is rejected up front; any exception raised while parsing
    is treated as "not a valid molecule".
    """
    if smiles == "":
        return False

    try:
        parsed = MolFromSmiles(smiles, sanitize=True)
    except Exception:
        return False
    return parsed is not None


def sample_latent_space(vae_encoder, vae_decoder, condition_vector, candidate, sample_len):
    """Greedily decode a token sequence from a latent point and a condition.

    Args:
        vae_encoder: The CVAE encoder. It is not used for decoding; it is only
            switched to eval mode for the duration of the call, and its
            previous train/eval mode is restored afterwards.
        vae_decoder: The CVAE decoder; must provide ``init_hidden(batch_size)``
            and ``forward(z, hidden, condition) -> (logits, hidden)``.
        condition_vector: Condition embedding, shape ``(1, c)``.
        candidate: Latent vector, shape ``(1, d)``.
        sample_len: Number of tokens to decode.

    Returns:
        list[int]: ``sample_len`` indices into the decoder's output alphabet,
        one argmax per decoding step.
    """
    # NOTE(review): the decoder input is tiled to a batch of 100 and only the
    # first row's logits are used; presumably this matches the batch layout the
    # decoder's conditional batch-norm was trained with — confirm before changing.
    decode_batch = 100

    encoder_was_training = vae_encoder.training
    decoder_was_training = vae_decoder.training
    vae_encoder.eval()
    vae_decoder.eval()

    # Put the inputs on the decoder's device/dtype so the GRU accepts them.
    reference = next(vae_decoder.parameters())
    gathered_atoms = []
    try:
        with torch.no_grad():  # inference only: don't build an autograd graph
            latent = candidate.to(device=reference.device, dtype=reference.dtype)
            condition = condition_vector.to(device=reference.device, dtype=reference.dtype)

            # Loop-invariant decoder inputs: [latent, condition] as a length-1
            # sequence, repeated across the batch.
            step_input = torch.cat((latent, condition), dim=1).unsqueeze(0).repeat(1, decode_batch, 1)
            condition_rpt = condition.repeat(decode_batch, 1)
            hidden = vae_decoder.init_hidden(batch_size=decode_batch)

            for _ in range(sample_len):
                logits, hidden = vae_decoder(step_input, hidden, condition_rpt)
                # Logits of the first batch row, using the decoder's actual
                # output width (previously a hard-coded 18).
                first_row = logits.reshape(-1, logits.shape[-1])[0]
                probabilities = torch.softmax(first_row, dim=0)
                gathered_atoms.append(int(torch.argmax(probabilities).item()))
    finally:
        vae_encoder.train(encoder_was_training)
        vae_decoder.train(decoder_was_training)

    return gathered_atoms

def get_selfie_and_smiles_encodings_for_dataset(file_path):
    """Read SMILES from a CSV file and derive SMILES/SELFIES encodings.

    Args:
        file_path: Path to a CSV file whose ``smiles`` column holds the
            molecules.

    Returns:
        A 6-tuple:
            - selfies encoding (one SELFIES string per molecule)
            - selfies alphabet (all SELFIES symbols plus '[nop]' for padding)
            - length, in SELFIES symbols, of the longest SELFIES string
            - smiles encoding (the file's ``smiles`` column as an array)
            - smiles alphabet (characters, plus ' ' for padding)
            - length, in characters, of the longest SMILES string
    """
    smiles_list = np.asanyarray(pd.read_csv(file_path).smiles)

    # Character-level alphabet with a trailing ' ' used as padding.
    smiles_alphabet = [*set(''.join(smiles_list)), ' ']
    largest_smiles_len = max(len(entry) for entry in smiles_list)

    print('--> Translating SMILES to SELFIES...')
    selfies_list = [sf.encoder(entry) for entry in smiles_list]

    selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
    selfies_symbols.add('[nop]')  # SELFIES padding symbol
    selfies_alphabet = list(selfies_symbols)

    largest_selfies_len = max(map(sf.len_selfies, selfies_list))

    print('Finished translating SMILES to SELFIES.')

    return (selfies_list, selfies_alphabet, largest_selfies_len,
            smiles_list, smiles_alphabet, largest_smiles_len)

def add_substructure(input_smiles, atom_index, substructure_smiles):
    """Attach a substructure to a molecule through a new single bond.

    The first atom (index 0) of ``substructure_smiles`` is bonded to atom
    ``atom_index`` of ``input_smiles`` and the result is sanitized.

    Args:
        input_smiles: SMILES of the molecule to extend.
        atom_index: Index of the atom in ``input_smiles`` to attach to.
        substructure_smiles: SMILES of the fragment to attach.

    Returns:
        str: Canonical SMILES of the combined molecule, or ``input_smiles``
        unchanged if either SMILES is invalid, the bond cannot be added, or
        the result fails sanitization (callers detect failure by comparing
        the return value with ``input_smiles``).
    """
    try:
        mol = Chem.MolFromSmiles(input_smiles)
        if mol is None:
            raise ValueError("Invalid input SMILES")

        substructure_mol = Chem.MolFromSmiles(substructure_smiles)
        if substructure_mol is None:
            raise ValueError("Invalid substructure SMILES")

        combo_mol = Chem.RWMol(Chem.CombineMols(mol, substructure_mol))

        # CombineMols appends the substructure's atoms after the input's, so
        # the substructure's first atom has index mol.GetNumAtoms().
        new_atom_index = mol.GetNumAtoms()

        combo_mol.AddBond(atom_index, new_atom_index, Chem.BondType.SINGLE)
        Chem.SanitizeMol(combo_mol)
        return Chem.MolToSmiles(combo_mol)
    except Exception:
        # Best-effort by design: any parsing/bonding/sanitization failure
        # leaves the input untouched. Catching Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        return input_smiles

def get_atom_features(atom):
    """Return six integer features of an RDKit atom as an ordered dict.

    Key order matters: callers convert ``.values()`` into the condition
    vector that is scaled and embedded.
    """
    return dict(
        AtomicNum=atom.GetAtomicNum(),
        Hybridization=int(atom.GetHybridization()),
        Valence=atom.GetTotalValence(),
        FormalCharge=atom.GetFormalCharge(),
        Degree=atom.GetDegree(),
        IsInRing=int(atom.IsInRing()),
    )

def find_atoms_with_extra_bond(mol):
    """Return, in atom order, the indices of atoms that can take one more bond.

    "Can take one more bond" is decided by ``can_have_extra_bond``.
    """
    return [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if can_have_extra_bond(mol, atom.GetIdx())
    ]

def can_have_extra_bond(mol, atom_idx):
    """Return True if atom ``atom_idx`` of ``mol`` can accept another single bond.

    Substructures are attached with a single bond that replaces a hydrogen
    (see ``add_substructure``), so the atom must carry at least one implicit
    or explicit hydrogen.

    The previous test, ``degree < total_valence``, was also true for atoms
    with no hydrogens that take part in double/triple/aromatic bonds (e.g. a
    carbonyl O or a substituted aromatic C). Attaching to those always fails
    sanitization and wastes an optimisation step.
    """
    atom = mol.GetAtomWithIdx(atom_idx)
    return atom.GetTotalNumHs() > 0

def optimize_and_map(bounds, UCB_acq, atom_ids_with_extra_bond):
    """Maximise the acquisition function and map the result to an atom.

    Args:
        bounds: ``(2, d)`` box bounds for ``optimize_acqf``. The last column is
            expected to be ``[0, len(atom_ids_with_extra_bond) - 1]``, a
            continuous relaxation of an index into ``atom_ids_with_extra_bond``.
        UCB_acq: BoTorch acquisition function to maximise.
        atom_ids_with_extra_bond: Atom indices that may be modified.

    Returns:
        tuple: ``(candidate, atom_id)``. ``candidate`` is the ``(1, d)``
        optimiser output with its last entry snapped, in place, to the integer
        index actually used. ``atom_id`` is the selected atom index.

    Raises:
        ValueError: if ``atom_ids_with_extra_bond`` is empty.
    """
    if not atom_ids_with_extra_bond:
        raise ValueError("No atom can accept an extra bond; nothing to optimise.")

    candidate, _ = optimize_acqf(
        acq_function=UCB_acq,
        bounds=bounds,
        q=1,
        num_restarts=10,
        raw_samples=500,
    )

    # Round (rather than floor) the relaxed index: with bounds [0, n - 1],
    # flooring leaves the last atom reachable only at the exact upper bound.
    # Clamp as a guard against numerical spill outside the bounds.
    last_index = len(atom_ids_with_extra_bond) - 1
    relaxed_index = candidate[0, -1].item()
    atom_id_index = min(max(int(math.floor(relaxed_index + 0.5)), 0), last_index)

    # Snap the candidate so the GP is trained on the index that was used.
    candidate[0, -1] = atom_id_index
    atom_id = atom_ids_with_extra_bond[atom_id_index]

    return candidate, atom_id

def get_scaffold(smiles):
    """Return the Bemis-Murcko scaffold of ``smiles`` as a SMILES string.

    If RDKit cannot parse ``smiles``, the literal string
    ``"Invalid SMILES string"`` is returned instead of raising.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold)
    return "Invalid SMILES string"

def eval_similarity(original_smiles, modified_smiles, similarity_threshold=0.7):
    """Compare two molecules by Dice similarity of their Morgan fingerprints.

    Fingerprints are radius-1, 2048-bit Morgan bit vectors.

    Args:
        original_smiles (str): SMILES of the reference molecule.
        modified_smiles (str): SMILES of the molecule to compare against it.
        similarity_threshold (float): Minimum similarity for the first return
            value to be True.

    Returns:
        tuple[bool, float]: ``(similarity >= similarity_threshold, similarity)``.

    Raises:
        ValueError: if either SMILES string cannot be parsed by RDKit.
    """
    original_mol = Chem.MolFromSmiles(original_smiles)
    modified_mol = Chem.MolFromSmiles(modified_smiles)
    if original_mol is None or modified_mol is None:
        raise ValueError(
            f"Cannot compute similarity for invalid SMILES: "
            f"{original_smiles!r}, {modified_smiles!r}"
        )

    fp_original = AllChem.GetMorganFingerprintAsBitVect(original_mol, radius=1, nBits=2048)
    fp_modified = AllChem.GetMorganFingerprintAsBitVect(modified_mol, radius=1, nBits=2048)

    similarity = DiceSimilarity(fp_original, fp_modified)

    # True when the molecules are at least as similar as the threshold.
    return similarity >= similarity_threshold, similarity

def calculate_molecular_weight(smiles):
    """Return RDKit's exact molecular weight (``CalcExactMolWt``) of ``smiles``.

    If RDKit cannot parse ``smiles``, the literal string
    ``"Invalid SMILES string"`` is returned instead of a number.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return "Invalid SMILES string"
    return rdMolDescriptors.CalcExactMolWt(mol)

def _make_bounds(num_atom_choices):
    """Return the ``(2, z_dim + 1)`` box bounds for the acquisition optimiser.

    The first ``z_dim`` columns bound the CVAE latent vector to [-6, 6]. The
    last column bounds the relaxed index into the attachable-atom list to
    ``[0, num_atom_choices - 1]``.
    """
    continuous_bounds = [[-6.0, 6.0]] * z_dim
    discrete_bounds = [[0, num_atom_choices - 1]]
    # Keep the bounds on the same device as the GP training data (no-op on CPU).
    return torch.tensor(continuous_bounds + discrete_bounds).T.to(device)


def _embed_atom_condition(encoder, scaler, mol, atom_id):
    """Embed the scaled features of atom ``atom_id`` into the condition space.

    Returns the sampled embedding produced by ``encoder`` for a single atom.
    """
    atom_features = list(get_atom_features(mol.GetAtomWithIdx(atom_id)).values())
    # The scaler expects 2-D input: one row with the 6 features.
    scaled_features = scaler.transform([atom_features])
    scaled_features = torch.tensor(scaled_features[0], dtype=torch.float).to(device)
    atom_features_embedded, _, _ = encoder(scaled_features.unsqueeze(0))
    return atom_features_embedded


def main():
    """Run CLaSMO latent-space BO on every scaffold in clasmo_input_data.csv.

    Loads the settings file, the trained CVAE, the atom-condition embedding
    encoder and scaler, and the initial GP data from ``clasmo_inputs/``. For
    each scaffold and seed it then runs 100 UCB-driven LSBO steps. Each step
    attaches a decoded substructure to an attachable atom. A result is
    accepted greedily when its QED improves and its Dice similarity to the
    starting scaffold is in [0.25, 1). Accepted results are written to
    ``clasmo_results_new_run.csv`` after every step.
    """
    settings_path = "settings_cvae.yml"
    if not os.path.exists(settings_path):
        print("Expected a file settings_cvae.yml but didn't find it.")
        return
    with open(settings_path, "r") as settings_file:
        settings = yaml.safe_load(settings_file)

    print('--> Acquiring data...')
    file_name_smiles = settings['data']['smiles_file']

    print('Finished acquiring data.')

    print('Representation: SELFIES')
    encoding_list, _, largest_molecule_len, _, _, _ = \
            get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)
    # Fixed alphabet: decoder output index i maps to encoding_alphabet[i].
    # NOTE(review): presumably this must match the order used in training.
    encoding_alphabet = ['[=N]', '[Branch1]', '[#N]', '[=Ring1]', '[=C]', '[=Branch1]', '[NH1]', '[#Branch1]', '[Ring2]', '[=Branch2]', '[nop]', '[C]', '[=O]', '[N]', '[#C]', '[Ring1]', '[F]', '[O]']

    print('--> Creating one-hot encoding...')
    data = multiple_selfies_to_hot(encoding_list, largest_molecule_len,
                                       encoding_alphabet)
    print('Finished creating one-hot encoding.')

    len_max_molec = data.shape[1]
    len_alphabet = data.shape[2]
    len_max_mol_one_hot = len_max_molec * len_alphabet

    print(' ')
    print(f"Alphabet has {len_alphabet} letters, "
          f"largest molecule is {len_max_molec} letters.")

    encoder_parameter = settings['encoder']
    encoder_parameter['latent_dimension'] = z_dim
    decoder_parameter = settings['decoder']
    decoder_parameter['latent_dimension'] = z_dim
    training_parameters = settings['training']
    training_parameters['KLD_alpha'] = beta

    ### LOAD MODEL ###
    vae_encoder = VAEEncoder(condition_embedding_dim=training_parameters['condition_embedding_dim'], in_dimension=len_max_mol_one_hot,
                             **encoder_parameter).to(device)
    vae_decoder = VAEDecoder(**decoder_parameter, condition_embedding_dim=training_parameters['condition_embedding_dim'],
                             out_dimension=len(encoding_alphabet)).to(device)
    vae_encoder.load_state_dict(torch.load(f'clasmo_inputs/E_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))
    vae_decoder.load_state_dict(torch.load(f'clasmo_inputs/D_ld_{z_dim}_beta_{beta}.pt', map_location=torch.device(device)))
    input_dim = 6  # number of atom features (see get_atom_features)
    z_n = 2  # condition-embedding size

    ### LOAD CONDITION EMBEDDING MODEL ###
    encoder = VAE_Encoder_Fully_connected_unit(z_n, input_dim).to(device)
    encoder.load_state_dict(torch.load('clasmo_inputs/embeddings_encoder.pt', map_location=torch.device(device)))
    scaler = joblib.load('clasmo_inputs/minmaxscaler.joblib')
    df = pd.read_csv("clasmo_input_data.csv")  # input scaffolds
    scaffolds = df['scaffold']
    results = []
    seeds = [0]

    ### LOAD GP TRAINING DATA ###
    sorted_results_df = pd.read_csv(f'clasmo_inputs/gp_cvae_y_sorted_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv')
    atom_ids_for_gp = pd.read_csv(f'clasmo_inputs/gp_cvae_selected_atom_ids_ld_{z_dim}_beta_{beta}_target_property_{target_property}.csv')
    atom_ids_np = atom_ids_for_gp.iloc[:, 1].to_numpy()
    atom_ids_tensor = torch.tensor(atom_ids_np, dtype=torch.int).view(-1, 1).to(device)
    selected_latent_points = torch.load(f"clasmo_inputs/gp_cvae_x_selected_ld_{z_dim}_beta_{beta}_target_property_{target_property}.pt", map_location=torch.device(device))
    initial_X = torch.cat((selected_latent_points, atom_ids_tensor), dim=1)
    initial_Y = sorted_results_df['reward'].values

    ## iterate through input scaffolds
    for input_molecule_counter, original_scaffold in enumerate(scaffolds):
        print(f"CLaSMO is running for input scaffold number {input_molecule_counter}")
        for seed in seeds:
            scaffold_input = original_scaffold
            init_qed = QED.qed(Chem.MolFromSmiles(scaffold_input))
            # QED of the current (possibly already improved) scaffold.
            current_qed = init_qed
            scaffold_smiles_beginning = original_scaffold
            atom_ids_with_extra_bond = find_atoms_with_extra_bond(Chem.MolFromSmiles(scaffold_input))
            if not atom_ids_with_extra_bond:
                print(f"No attachable atoms in input scaffold number {input_molecule_counter}; skipping.")
                continue

            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            torch.use_deterministic_algorithms(True)

            train_X = initial_X.detach().clone().to(device=device, dtype=torch.float64)
            train_Y = torch.tensor(initial_Y, dtype=torch.float64).view(-1, 1).to(device)
            categorical_cols = [train_X.shape[1] - 1]  # last column is the atom index
            categorical_kernel = CategoricalKernel(num_tasks=1)
            scale_cat_kernel = ScaleKernel(categorical_kernel)
            continuous_kernel = RBFKernel()
            scale_cont_kernel = ScaleKernel(continuous_kernel)
            # NOTE(review): this product kernel has no active_dims, so both
            # factors act on ALL columns, and it replaces the mixed kernel that
            # MixedSingleTaskGP builds from cat_dims. Kept to preserve results;
            # confirm this is intended.
            product_kernel = ProductKernel(scale_cont_kernel, scale_cat_kernel)
            gp_model = MixedSingleTaskGP(train_X, train_Y, cat_dims=categorical_cols)
            gp_model.covar_module = product_kernel
            mll = ExactMarginalLogLikelihood(gp_model.likelihood.to(device), gp_model.to(device))
            fit_gpytorch_model(mll)
            bounds = _make_bounds(len(atom_ids_with_extra_bond))

            for j in range(100):
                print(f"**** LSBO STEP {j} ****")
                UCB_acq = UpperConfidenceBound(gp_model, beta=2.5)
                # CLaSMO picks the latent vector and the atom to modify.
                candidate, atom_id = optimize_and_map(bounds, UCB_acq, atom_ids_with_extra_bond)
                atom_features_embedded = _embed_atom_condition(
                    encoder, scaler, Chem.MolFromSmiles(scaffold_input), atom_id)

                ## generate substructure from the latent part of the candidate
                gathered_atoms = sample_latent_space(vae_encoder, vae_decoder, atom_features_embedded, candidate[:, :-1], len_max_molec)
                molecule_pre = ''.join(encoding_alphabet[i] for i in gathered_atoms)
                molecule = sf.decoder(molecule_pre.replace(' ', ''))

                if not is_correct_smiles(molecule):
                    reward = -10  # decoded string is not a valid molecule
                else:
                    resulting_smiles = add_substructure(scaffold_input, atom_id, molecule)
                    if resulting_smiles == scaffold_input:
                        reward = -7.5  # substructure could not be attached
                    else:
                        # In this example the similarity threshold is 0.25.
                        _, similarity = eval_similarity(resulting_smiles, scaffold_smiles_beginning, 0.25)
                        mwr = calculate_molecular_weight(resulting_smiles)
                        mwo = calculate_molecular_weight(scaffold_smiles_beginning)
                        mw_delta = (mwr - mwo) / mwo
                        res_qed = QED.qed(Chem.MolFromSmiles(resulting_smiles))
                        if not 0.25 <= similarity < 1:
                            # Too dissimilar, or identical radius-1 fingerprints.
                            reward = -5
                        else:
                            reward = res_qed - current_qed
                            if reward > 0:
                                print(f'y_delta is {reward} at CLaSMO step {j} for input molecule {input_molecule_counter}, QED is improved to {res_qed} from {init_qed}.')
                                results.append({'Atom_ID': atom_id, 'opt_step': j,
                                                'Input_SMILES': scaffold_input,
                                                'Input_SMILES_weight': mwo,
                                                'Resulting_SMILES': resulting_smiles,
                                                'Resulting_SMILES_weight': mwr,
                                                'reward': reward,
                                                'similarity': similarity,
                                                'mw_delta': mw_delta,
                                                'Input_Scaffold_QED': init_qed,
                                                'Resulting_QED': res_qed})
                                # Greedily accept the improved molecule.
                                scaffold_input = resulting_smiles
                                current_qed = res_qed
                                atom_ids_with_extra_bond = find_atoms_with_extra_bond(Chem.MolFromSmiles(scaffold_input))
                                if atom_ids_with_extra_bond:
                                    bounds = _make_bounds(len(atom_ids_with_extra_bond))

                train_X = torch.cat([train_X, candidate.to(device)])
                # Build the reward in train_Y's dtype; torch.tensor([[x]]) alone
                # would round a float reward to float32 before concatenation.
                train_Y = torch.cat([train_Y, torch.tensor([[reward]], dtype=train_Y.dtype, device=device)])
                # NOTE(review): the refit uses a plain SingleTaskGP, dropping the
                # mixed/categorical model of the initial fit — confirm intended.
                gp_model = SingleTaskGP(train_X, train_Y).to(device)
                mll = ExactMarginalLogLikelihood(gp_model.likelihood, gp_model)
                fit_gpytorch_model(mll)
                # Checkpoint the accepted improvements after every step.
                pd.DataFrame(results).to_csv("clasmo_results_new_run.csv")

                if not atom_ids_with_extra_bond:
                    print(f"No attachable atoms left for input scaffold number {input_molecule_counter}; stopping.")
                    break

if __name__ == '__main__':
    try:
        main()
    except AttributeError as err:
        # Report on stderr and exit non-zero so shells/pipelines can tell
        # that the run failed instead of seeing a successful exit status.
        print(err, file=sys.stderr)
        sys.exit(1)
