import hydra
import os
import sys
import torch
from unittest.mock import MagicMock

from syntheseus import Molecule

from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import get_batch
from multiguide.syntheseus.single_step_models.root_aligned_fixed import RootAlignedFixedModel
from multiguide.dataset.helpers import turn_results_to_mol_smiles
from multiguide.dataset.helpers import get_vocab_from_trained_model
from multiguide.property.property_predictor import PropertyPredictor

# Module-wide compute device: CUDA when torch reports it is available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# def safe_import_helpers():
#     # Mock the problematic modules first
#     sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
#     sys.modules['rdkit.Chem.Draw'] = MagicMock()
    
#     from multiguide.evaluation.helpers import get_retrosynthetic_results
#     return get_retrosynthetic_results

# # Use the safe import
# get_retrosynthetic_results = safe_import_helpers()

@hydra.main(config_path='../configs', config_name='config.yaml')
def main(config):
    """Smoke-test loading a trained property-predictor checkpoint.

    Builds a ``PropertyPredictor`` sized to the vocabulary returned by
    ``get_vocab_from_trained_model`` for
    ``config.classifier_guidance.onmt_checkpoint_path``. It then loads the
    ``'model_state_dict'`` entry of the checkpoint at
    ``<PROJECT_ROOT>/experiments/<experiment_name>/checkpoints/<checkpoint_path>``.
    Finally it runs one forward pass on a small dummy input and prints the
    result.

    Args:
        config: Hydra config composed from ``../configs/config.yaml``; reads
            ``config.classifier_guidance.{onmt_checkpoint_path,
            experiment_name, checkpoint_path}``.
    """
    vocab = get_vocab_from_trained_model(config.classifier_guidance.onmt_checkpoint_path)
    # TODO: override this
    #alphabet_size = alphabet_size + 4
    print(f'======= alphabet size: {len(vocab)}')

    property_model = PropertyPredictor(config, len(vocab))
    checkpoint_path = os.path.join(PROJECT_ROOT,
                                   'experiments',
                                   config.classifier_guidance.experiment_name,
                                   'checkpoints',
                                   config.classifier_guidance.checkpoint_path)
    print(f'======= loading property checkpoint from {checkpoint_path}')
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    property_checkpoint = torch.load(checkpoint_path, map_location=device)
    property_model.load_state_dict(property_checkpoint['model_state_dict'])
    # eval() switches off train-only behaviour (e.g. dropout, if the model has
    # any) so the smoke-test output is deterministic.
    property_model = property_model.to(device).eval()

    # Dummy batch: a (1, 2) tensor of token id 1.
    # NOTE(review): assumes the model accepts (batch, seq_len) token ids; confirm.
    seq = torch.ones((1, 2), dtype=torch.long, device=device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = property_model(seq)
    print(f'======= property output: {out}')
    # The retrosynthesis evaluation below is currently disabled; the function
    # simply returns here (rather than calling `exit()`), so Hydra can finish
    # the run normally.
    # all_rows = []
    # batch = get_batch(config)
    # # route_idx 2, last reaction
    # # old starting 'COC(=O)c1ccc(Cl)cc1N'
    # for i, (reactants, product, class_idx, conditional_starting_material, conditional_target) in enumerate(batch):
    #     print(f'Idx {i+int(config.single_step_evaluation.start_idx)}')
    #     config.single_step_evaluation.product_smi = product
    #     config.single_step_evaluation.true_reactants = reactants

        # reactant_predictions = get_retrosynthetic_results(config, config.single_step_evaluation.product_smi,
        #                                                 conditional_starting_material=conditional_starting_material,
        #                                                 conditional_target=conditional_target)

        # mol = Molecule(product)
        # retrosynthetic_model_dir = os.path.join(PROJECT_ROOT, 'checkpoints',  config.single_step_model.model_dir)
        # model = RootAlignedFixedModel(use_cache=True,
        #                                 num_augmentations=config.single_step_model.num_augmentations,
        #                                 default_num_results=config.single_step_model.default_num_results, # 10
        #                                 model_dir=retrosynthetic_model_dir,
        #                                 config=config,
        #                                 conditional_starting_material=conditional_starting_material,
        #                                 conditional_target=conditional_target)
        # results = model([mol], num_results=config.single_step_model.default_num_results)
        # results_smiles = turn_results_to_mol_smiles(results)
        # print(results_smiles)
                                                        
if __name__ == '__main__':
    # No arguments here: the @hydra.main decorator parses the command line,
    # composes the config and passes it to main itself.
    main()