from ast import GtE
from unittest.mock import MagicMock
import sys
import os
import warnings
import time
import logging
from pathlib import Path
import json
import hydra
import torch
from rdkit import Chem
from rdkit import RDLogger
sys.path.insert(0, str(Path(__file__).parent.parent))
from multiguide.helpers import PROJECT_ROOT, set_seed
from multiguide.dataset.helpers import get_smiles_list, get_targets_and_reaction_type_from_routes
from multiguide.syntheseus.helpers import run_for_one_mol_desp, run_for_one_mol_retro_star, set_search_algorithm
from multiguide.evaluation.helpers import define_single_step_model

# Replace RDKit's drawing modules in sys.modules with mocks so that any later
# import of them gets a MagicMock instead of the real (problematic) module.
# NOTE(review): these assignments run after the rdkit / multiguide imports at
# the top of the file, so anything imported there that already pulled in
# rdkit.Chem.Draw keeps the real module -- confirm this ordering is intended.
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()
# NOTE(review): PYTHONWARNINGS is normally read at interpreter start-up, so
# setting it here presumably only affects child processes -- confirm.
os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
# Silence FutureWarning for modules whose name starts with "onmt".
warnings.filterwarnings("ignore", category=FutureWarning, module="onmt")
# Configure the root logger one level below DEBUG, so every record at DEBUG
# and above is emitted with a timestamp, logger name and level.
logging.basicConfig(
    level=logging.DEBUG - 1,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Turn off every RDKit log channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')  # This should disable all RDKit logging
# Use CUDA when available, otherwise the CPU.
# NOTE(review): `device` is not referenced anywhere else in the visible file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_node_depth(target_idx):
    '''
    Compute the node depth for the starting material of one target.

    Loads the reference routes from
    data/uspto_190/in_json/test_with_tanimoto_weight1.json under PROJECT_ROOT,
    looks up entry `target_idx`, and returns (len(route) - 1) * 2, where
    `route` is that entry's 'route' field. Nothing is mutated; the value is
    only returned.
    '''
    path_parts = (
        'data',
        'uspto_190',
        'in_json',
        'test_with_tanimoto_weight1.json',
    )
    routes_file = os.path.join(PROJECT_ROOT, *path_parts)
    with open(routes_file, 'r', encoding="utf-8") as handle:
        all_routes = json.load(handle)
    route = all_routes[target_idx]['route']
    # The search tree interleaves OR and AND nodes and is 0-indexed, hence the
    # factor of two. The starting material is meant to be enforced when
    # expanding the level before the last one, hence the "- 1".
    return 2 * (len(route) - 1)

# Limits shared by all algorithms (currently disabled; kept for reference)
# RXN_MODEL_CALL_LIMIT = 100 # 100
# TIME_LIMIT_S = 600 # 300

def patched_canonicalize(smiles, return_max_frag=True, synthon=None):
    '''
    Canonicalize a SMILES string with RDKit, stripping atom-map numbers.

    Dative bonds in the parsed molecule are converted to single bonds and the
    'molAtomMapNumber' property is cleared on every atom before the isomeric
    SMILES is written out.

    Args:
        smiles: input SMILES string.
        return_max_frag: when true, also return the canonical SMILES of the
            fragment ('.'-separated component) with the most atoms.
        synthon: when truthy, molecules are parsed with sanitize=False.

    Returns:
        If return_max_frag is true, a tuple (canonical_smiles,
        largest_fragment_smiles); otherwise the canonical SMILES string.
        Failures yield empty strings in the same shape ('' or ('', '')).
        A failure is any of: the input cannot be parsed, SMILES generation
        raises, or the input contains '<' or '>'. If no fragment can be
        re-parsed, the second tuple element is ''.

    Side effects:
        Prints the input SMILES on every call (including the recursive call
        for the largest fragment) and a notice when it contains '<' or '>'.
    '''
    # The failure value has the same shape as the success value.
    failure = ('', '') if return_max_frag else ''

    print(f'smiles {smiles}')
    has_angle_bracket = '<' in smiles or '>' in smiles
    if has_angle_bracket:
        print(f'Found < or > in smiles: {smiles}')

    mol = Chem.MolFromSmiles(smiles, sanitize=not synthon)
    if mol is None:
        return failure

    # Remove dative bonds by converting them back to single bonds.
    for bond in mol.GetBonds():
        if bond.GetBondType() == Chem.BondType.DATIVE:
            bond.SetBondType(Chem.BondType.SINGLE)
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    try:
        smi = Chem.MolToSmiles(mol, isomericSmiles=True)
    except Exception:
        # Narrower than a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate; any RDKit failure maps to the empty result.
        return failure
    # Inputs containing '<' or '>' are rejected even when RDKit parsed them.
    if has_angle_bracket:
        return failure

    if not return_max_frag:
        return smi

    # Collect (fragment_smiles, atom_count) for every fragment RDKit can parse.
    fragments = []
    for frag_smi in smi.split('.'):
        frag_mol = Chem.MolFromSmiles(frag_smi, sanitize=not synthon)
        if frag_mol is not None:
            fragments.append((frag_smi, len(frag_mol.GetAtoms())))
    if not fragments:
        return smi, ''

    # max() returns the first of equally sized fragments, matching a stable
    # descending sort followed by taking element 0.
    largest_smi = max(fragments, key=lambda item: item[1])[0]
    return smi, patched_canonicalize(largest_smi, return_max_frag=False, synthon=synthon)

@hydra.main(config_path='../configs', config_name='config.yaml')
def search(config):
    '''
    Run the configured search algorithm on every target in the route dataset.

    Seeds the run from config.general.seed, loads targets, ground-truth data
    and starting materials via get_targets_and_reaction_type_from_routes, then
    runs either the 'desp' or the 'retro_star' search for each target,
    printing the wall-clock time per molecule.

    Raises:
        ValueError: if config.search.type is neither 'desp' nor 'retro_star'.
    '''
    set_seed(config.general.seed)
    print(f'======== Seed: {config.general.seed}')

    # Each ground-truth entry is indexed with 'mol_to_rxn_type' below.
    # NOTE(review): an earlier comment also described a 'depth_to_types' key
    # (depth -> list of reaction types) -- confirm against the dataset helper.
    targets, ground_truth_data, starting_materials = get_targets_and_reaction_type_from_routes(config)
    print(f'======= {len(targets)} molecules to run')

    per_target = zip(targets, ground_truth_data, starting_materials)
    for offset, (smi, gt, starting_material_smi) in enumerate(per_target):
        print(f"smi {smi}, gt {gt['mol_to_rxn_type']}, starting_material_smi {starting_material_smi}")

        # NOTE: the model and search algorithm are rebuilt for every target on
        # purpose, to reset any inner counters used in evaluation / by search.
        # NOTE(review): the full `starting_materials` list is passed here, not
        # the per-target `starting_material_smi` -- confirm this is intended.
        model = define_single_step_model(
            config,
            conditional_starting_materials=starting_materials,
            conditional_targets=None
        )
        search_algorithm = set_search_algorithm(config, model, target_smi=smi)

        # Index of this target, shifted by the configured start index.
        target_idx = int(config.route_dataset.start_idx) + offset
        print(f'======= running for molecule {target_idx}, smi {smi}')
        start_time = time.time()

        # NOTE(review): overriding the enforced depth with
        # set_node_depth(target_idx) is not wired in here.
        search_type = config.search.type
        if search_type == 'desp':
            run_for_one_mol_desp(
                config,
                smi,
                target_idx,
                starting_material_smi=starting_material_smi
            )
        elif search_type == 'retro_star':
            run_for_one_mol_retro_star(
                config,
                search_algorithm,
                smi,
                target_idx,
                ground_truth_data=gt,
                starting_material_smi=starting_material_smi
            )
        else:
            raise ValueError(f'Invalid search type: {search_type}')

        elapsed = time.time() - start_time
        print(f'finished in {elapsed} seconds')

# Script entry point. search() is called without arguments; the @hydra.main
# decorator on it supplies the `config` argument.
if __name__ == "__main__":
    search()
    