'''
   Recompute the top-k accuracy of old sampling runs, this time taking stereochemistry into account.
'''
import time
import os
import sys
import datetime
import pathlib
import warnings
import random
import numpy as np
import torch
import wandb
import hydra
import logging
import copy
from torch.profiler import profile, record_function, ProfilerActivity
from src.utils import io_utils, mol
import multiprocessing
from functools import partial
from tqdm import tqdm

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from src.utils import setup
from hydra.core.hydra_config import HydraConfig
from src.utils import setup
from datetime import date
import re
from rdkit import Chem

warnings.filterwarnings("ignore", category=PossibleUserWarning)

parent_path = pathlib.Path(os.path.realpath(__file__)).parents[1]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["WANDB__SERVICE_WAIT"] = "300"

def remove_atom_mapping_and_stereo(smi):
    """Return the canonical SMILES of ``smi`` with atom maps and stereochemistry removed.

    If RDKit cannot parse ``smi``, the string is returned unchanged (as
    ``undo_kekulize`` does). The function no longer raises ``AttributeError``
    on the ``None`` molecule.
    """
    m = Chem.MolFromSmiles(smi)
    if m is None:
        return smi
    for atom in m.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    Chem.RemoveStereochemistry(m)
    return Chem.MolToSmiles(m, canonical=True)

def remove_atom_mapping(smi):
    """Return the canonical SMILES of ``smi`` with atom-map numbers removed.

    Stereochemistry is kept. If RDKit cannot parse ``smi``, the string is
    returned unchanged instead of raising ``AttributeError`` on ``None``.
    """
    m = Chem.MolFromSmiles(smi)
    if m is None:
        return smi
    for atom in m.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(m, canonical=True)

def remove_atom_mapping_from_mol(mol):
    """Clear the ``molAtomMapNumber`` property on every atom of ``mol``, in place.

    Returns the same ``mol`` object so the call can be chained.
    """
    # A plain loop: the previous list comprehension built a throwaway list
    # purely for its side effects.
    for atom in mol.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return mol

def undo_kekulize(smi):
    """Re-canonicalize ``smi`` with RDKit; return it untouched if it cannot be parsed.

    RDKit's default sanitization on parse perceives aromaticity, so kekulized
    input comes back written in aromatic form.
    """
    parsed = Chem.MolFromSmiles(smi)
    return smi if parsed is None else Chem.MolToSmiles(parsed, canonical=True)

def get_rxn_with_stereo(old_rxn, true_rxns_without_stereo, true_rxns_with_stereo):
    """Look up the stereo-annotated version of ``old_rxn``.

    Args:
        old_rxn: reaction string compared for exact equality against
            ``true_rxns_without_stereo``.
        true_rxns_without_stereo: candidate reaction strings.
        true_rxns_with_stereo: list parallel to ``true_rxns_without_stereo``.

    Returns:
        The ``true_rxns_with_stereo`` entry at the first index whose
        ``true_rxns_without_stereo`` entry equals ``old_rxn``, or ``None``
        (after printing a diagnostic) when nothing matches. Empty inputs also
        return ``None``; they used to raise ``UnboundLocalError``.
    """
    for rxn_without_stereo, rxn_with_stereo in zip(true_rxns_without_stereo, true_rxns_with_stereo):
        if old_rxn == rxn_without_stereo:
            return rxn_with_stereo

    # Do not print a loop variable here: after an exhausted loop it is just
    # the last candidate (or unbound for empty input).
    print(f'no true reaction matches old rxn {old_rxn}\n')
    return None

def match_old_rxns(old_rxn, true_rxns_without_stereo, true_rxns_with_stereo, true_rxns_with_am_and_stereo):
    """Find ``old_rxn`` among the stereo-free true reactions and return its richer versions.

    The three ``true_rxns_*`` lists are parallel. A single scan finds the first
    index whose ``true_rxns_without_stereo`` entry equals ``old_rxn``.

    Returns:
        ``(true_rxns_with_stereo[idx], true_rxns_with_am_and_stereo[idx])`` for
        that index, or ``(None, None)`` (after printing a diagnostic) when
        there is no match.
    """
    candidates = zip(true_rxns_without_stereo, true_rxns_with_stereo, true_rxns_with_am_and_stereo)
    for rxn_without_stereo, rxn_with_stereo, rxn_with_am_and_stereo in candidates:
        if old_rxn == rxn_without_stereo:
            return rxn_with_stereo, rxn_with_am_and_stereo

    print(f'no true reaction matches old rxn {old_rxn}\n')
    return None, None

# Shared progress counter and its lock. Both stay None until init_globals()
# (used as a multiprocessing.Pool initializer) or main() assigns them.
counter, lock = None, None

def init_globals(c, l):
    """Publish ``c`` and ``l`` as this module's ``counter`` and ``lock`` globals.

    Passed as the ``initializer`` of a ``multiprocessing.Pool`` so every worker
    sees the same shared counter/lock objects.
    """
    globals().update(counter=c, lock=l)

def update_progress():
    """Atomically increment the shared ``counter`` and return its new value.

    Reads the module-level ``counter``/``lock`` set by ``init_globals``.
    """
    with lock:
        counter.value += 1
        new_value = counter.value
    return new_value

from collections import defaultdict

def group_indices(strings):
    """Map each distinct item of ``strings`` to the positions where it occurs.

    Keys appear in first-occurrence order and each position list is ascending.
    The result is a ``defaultdict(list)``.
    """
    positions_by_item = {}
    for position, item in enumerate(strings):
        positions_by_item.setdefault(item, []).append(position)
    return defaultdict(list, positions_by_item)

def _strip_atom_maps(smi):
    """Return canonical ``smi`` without atom-map numbers, or ``smi`` unchanged if RDKit cannot parse it."""
    m = Chem.MolFromSmiles(smi)
    if m is None:
        return smi
    for atom in m.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    return Chem.MolToSmiles(m, canonical=True)


def process_reaction(i, matched_old_true_rxns_with_am, matched_old_true_rxns, sampled_rxns, cfg):
    """Score the ranked samples for reaction ``i`` against its stereo-annotated ground truth.

    Inputs:
    i: index of the reaction
    matched_old_true_rxns_with_am: list of matched true reactions with atom mappings (and stereo)
    matched_old_true_rxns: list of matched true reactions without atom mappings (stereo kept)
    sampled_rxns: list of ranked sampled reactions per true reaction. Should have atom mappings!
    cfg: configuration object; ``cfg.test.topks`` lists the k values to evaluate

    When the true reaction carries stereo marks ('@', '/', '\\'), the
    product's stereochemistry is transferred onto each sample's reactants via
    the ``mol`` helpers before comparison. Otherwise the sample reactants are
    only stripped of atom maps and re-canonicalized. Either way, every sample is
    compared as a map-free string against ``matched_old_true_rxns[i]``.

    Returns:
    topk_local: dictionary of topk hits (0/1) for the current reaction
    topk_among_chiral_local: same hits if the reaction is chiral, else zeros
    topk_among_cistrans_local: same hits if the reaction has cis/trans bonds, else zeros
    chiral_reactions: number of chiral reactions (1/0)
    cistrans_reactions: number of cistrans reactions (1/0)
    progress: current progress counter
    """
    matched_rxn_with_am = matched_old_true_rxns_with_am[i]
    matched_rxn = matched_old_true_rxns[i]
    samples = sampled_rxns[i]

    matched_prod_mol = Chem.MolFromSmiles(matched_rxn_with_am.split('>>')[1])
    # Loop invariants: the scored product is always the true one, and the
    # stereo flags only depend on the true reaction.
    prod_smi_without_am = matched_rxn.split('>>')[1]
    has_chirality = "@" in matched_rxn_with_am
    has_bond_dirs = "/" in matched_rxn_with_am or "\\" in matched_rxn_with_am

    samples_with_chirality = []
    for sample in samples:
        sides = sample.split('>>')
        sample_rct_smi, sample_prod_smi = sides[0], sides[1]

        if has_chirality or has_bond_dirs:
            sample_rct_mol = Chem.MolFromSmiles(sample_rct_smi)
            sample_prod_mol = Chem.MolFromSmiles(sample_prod_smi)
            if sample_rct_mol is None or sample_prod_mol is None:
                # An unparseable sample can never match; score it as empty.
                r_smiles = ""
            else:
                # Collect the product-side maps before any helper touches sample_prod_mol.
                prod_side_ams = {a.GetAtomMapNum() for a in sample_prod_mol.GetAtoms()}
                Chem.RemoveStereochemistry(sample_rct_mol)  # This does some kind of sanitization, otherwise transferring the bond_dirs doesn't work reliably
                for a in sample_rct_mol.GetAtoms():  # remove atom mappings that are not on the product side
                    if a.GetAtomMapNum() not in prod_side_ams:
                        a.ClearProp('molAtomMapNumber')
                mol.match_atom_mapping_without_stereo(sample_prod_mol, matched_prod_mol)  # temporarily change the atom mapping in matched_prod_mol
                if has_chirality:
                    sample_rct_mol = mol.transfer_chirality_from_product_to_reactant(sample_rct_mol, matched_prod_mol)
                if has_bond_dirs:
                    sample_rct_mol = mol.transfer_bond_dir_from_product_to_reactant(sample_rct_mol, matched_prod_mol)
                remove_atom_mapping_from_mol(sample_rct_mol)
                r_smiles = Chem.MolToSmiles(sample_rct_mol, canonical=True)
        else:
            # Samples carry atom maps but matched_rxn does not. The raw reactant
            # string could therefore never equal the truth; strip maps first.
            r_smiles = _strip_atom_maps(sample_rct_smi)

        samples_with_chirality.append(r_smiles + ">>" + prod_smi_without_am)

    topks = cfg.test.topks
    hits = {k: int(matched_rxn in samples_with_chirality[:int(k)]) for k in topks}
    no_hits = {k: 0 for k in topks}

    chiral_reactions = int("@" in matched_rxn)
    cistrans_reactions = int("/" in matched_rxn or "\\" in matched_rxn)
    topk_local = dict(hits)
    topk_among_chiral_local = dict(hits) if chiral_reactions else dict(no_hits)
    topk_among_cistrans_local = dict(hits) if cistrans_reactions else dict(no_hits)

    # Update progress counter
    progress = update_progress()

    return topk_local, topk_among_chiral_local, topk_among_cistrans_local, chiral_reactions, cistrans_reactions, progress

def _load_true_rxns(raw_data_path):
    """Read the ground-truth reactions and canonicalize each one three ways.

    Args:
        raw_data_path: text file with one ``reactants>>products`` reaction per line.

    Returns:
        Three parallel lists, one entry per non-blank line:
        - the stripped raw line (atom mapping and stereo kept);
        - the reaction with atom mapping removed (stereo kept);
        - the reaction with both removed.
    """
    with open(raw_data_path, 'r') as f:
        raw_true_rxns = [line.strip() for line in f if line.strip()]

    true_rxns_with_am_and_stereo, true_rxns_with_stereo, true_rxns_without_stereo = [], [], []
    for rxn in raw_true_rxns:
        sides = rxn.split('>>')
        reactants, products = sides[0], sides[1]
        true_rxns_with_am_and_stereo.append(rxn)
        true_rxns_with_stereo.append(remove_atom_mapping(reactants) + '>>' + remove_atom_mapping(products))
        true_rxns_without_stereo.append(
            remove_atom_mapping_and_stereo(reactants) + '>>' + remove_atom_mapping_and_stereo(products))
    return true_rxns_with_am_and_stereo, true_rxns_with_stereo, true_rxns_without_stereo


def _rerank_samples(sample, sort_lambda_value):
    """Canonicalize, deduplicate and re-sort the samples generated for one product.

    Args:
        sample: list of ``sample_info`` entries. ``sample_info[0]`` is the
            sampled reaction SMILES. From ``sample_info[1]``, index 0 is read as
            the elbo and index -2 as the sample count.
        sort_lambda_value: weight of the elbo term versus the count term.

    Returns:
        The distinct canonicalized reactions, highest weighted probability first.
        Ties keep first-occurrence order, so the ranking is deterministic.
    """
    rxns = []
    for sample_info in sample:
        sides = sample_info[0].split('>>')  # sample_info also contains numbers like elbo etc
        rxns.append(undo_kekulize(sides[0]) + '>>' + undo_kekulize(sides[1]))
    if not rxns:
        return []

    # For the case where the 'samples' file was encoded in a way that contained
    # duplicates: merge them by summing counts and averaging elbos.
    grouped = group_indices(rxns)
    counts = {rxn: sum(int(sample[i][1][-2]) for i in idx) for rxn, idx in grouped.items()}
    elbos = {rxn: sum(float(sample[i][1][0]) for i in idx) / len(idx) for rxn, idx in grouped.items()}

    # softmax(-elbo) does not change when every elbo is shifted by the same
    # constant. Shifting by the minimum keeps np.exp from underflowing to 0.
    min_elbo = min(elbos.values())
    exp_elbos = {rxn: np.exp(-(elbo - min_elbo)) for rxn, elbo in elbos.items()}
    sum_exp_elbo = sum(exp_elbos.values())
    sum_counts = sum(counts.values())
    weighted_probs = {
        rxn: (exp_elbos[rxn] / sum_exp_elbo) * sort_lambda_value
        + (counts[rxn] / sum_counts) * (1 - sort_lambda_value)
        for rxn in grouped
    }
    # `grouped` keeps first-occurrence order and sorted() is stable (also with
    # reverse=True). Ties therefore no longer depend on hash-randomized set order.
    return sorted(grouped, key=weighted_probs.__getitem__, reverse=True)


def _normalize(counts, total):
    """Divide every value of ``counts`` by ``total``; NaN when ``total`` is 0."""
    if total == 0:
        return {k: float('nan') for k in counts}
    return {k: v / total for k, v in counts.items()}


@hydra.main(version_base='1.1', config_path='../configs', config_name='default')
def main(cfg: DictConfig):
    """Recompute top-k accuracy of an old samples file with stereochemistry restored.

    Steps:
    - Load the ground-truth test reactions.
    - Canonicalize and re-rank the old samples.
    - Match every old true reaction to its stereo-annotated ground truth.
    - Score each reaction with ``process_reaction``.
    - Print raw and normalized top-k counts, overall and among chiral and
      cis/trans reactions.
    """
    # Rebound below so the shared progress objects are visible in this process too.
    global counter, lock

    os.environ["OMP_NUM_THREADS"] = "1"  # To avoid interference with OpenMP

    raw_data_path = os.path.join(parent_path, 'data', 'uspto-50k-block-15', 'raw', 'test.csv')
    old_samples_path = os.path.join(parent_path, 'OLD_SAMPLE_PATH')  # placeholder: point at the old samples file

    # load true raw data reactions
    print(f'reading raw data from {raw_data_path}\n')
    true_rxns_with_am_and_stereo, true_rxns_with_stereo, true_rxns_without_stereo = _load_true_rxns(raw_data_path)

    print(f'reading old samples from {old_samples_path}\n')
    with open(old_samples_path, 'r') as f:
        sampled_rxns_blocks = io_utils.read_saved_reaction_data(f.read())

    # Old true reactions are looked up in true_rxns_without_stereo, so strip
    # stereo from them as well. Otherwise any stereo tag would make the match impossible.
    old_true_rxns = []
    for block in sampled_rxns_blocks:
        sides = block[0].split('>>')
        old_true_rxns.append(
            remove_atom_mapping_and_stereo(sides[0]) + '>>' + remove_atom_mapping_and_stereo(sides[1]))

    sampled_rxns = [_rerank_samples(block[1], cfg.test.sort_lambda_value) for block in sampled_rxns_blocks]

    # match old and true_rxns
    print('matching old and true reactions\n')
    partial_match_old_rxns = partial(
        match_old_rxns,
        true_rxns_without_stereo=true_rxns_without_stereo,
        true_rxns_with_stereo=true_rxns_with_stereo,
        true_rxns_with_am_and_stereo=true_rxns_with_am_and_stereo,
    )
    with multiprocessing.Pool(processes=8) as pool:
        match_results = pool.map(partial_match_old_rxns, old_true_rxns)
    matched_old_true_rxns = [rxn for rxn, _ in match_results]
    matched_old_true_rxns_with_am = [rxn_with_am for _, rxn_with_am in match_results]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    unmatched = [i for i, rxn in enumerate(matched_old_true_rxns) if rxn is None]
    if unmatched:
        raise ValueError(
            f'Some old reactions could not be matched with true reactions '
            f'({len(unmatched)} unmatched, first indices: {unmatched[:10]}).')

    t0 = time.time()
    # calculate topk; assumes each entry of sampled_rxns is ranked best-first
    print('calculating topk\n')
    topks = cfg.test.topks
    topk = {k: 0 for k in topks}
    topk_among_chiral = {k: 0 for k in topks}
    topk_among_cistrans = {k: 0 for k in topks}
    total_chiral_reactions = 0
    total_cistrans_reactions = 0
    # NOTE(review): deliberately a single worker; the reason is not visible
    # here. Confirm before raising this to multiprocessing.cpu_count().
    num_processes = 1
    print("num_processes", num_processes)

    partial_process_reaction = partial(
        process_reaction,
        matched_old_true_rxns_with_am=matched_old_true_rxns_with_am,
        matched_old_true_rxns=matched_old_true_rxns,
        sampled_rxns=sampled_rxns,
        cfg=cfg,
    )
    total_tasks = len(matched_old_true_rxns)
    results = []
    # The Manager context shuts its server process down once the pool is done.
    with multiprocessing.Manager() as manager:
        counter = manager.Value('i', 0)
        lock = manager.Lock()
        with multiprocessing.Pool(processes=num_processes, initializer=init_globals,
                                  initargs=(counter, lock)) as pool, tqdm(total=total_tasks) as pbar:
            for result in pool.imap(partial_process_reaction, range(total_tasks)):
                # Keep the five scoring fields; the last item is the worker progress counter.
                results.append(result[:5])
                # imap yields exactly one result per task, so count results directly.
                pbar.update(1)

    for topk_local, topk_among_chiral_local, topk_among_cistrans_local, chiral_reactions, cistrans_reactions in results:
        for k in topks:
            topk[k] += topk_local[k]
            topk_among_chiral[k] += topk_among_chiral_local[k]
            topk_among_cistrans[k] += topk_among_cistrans_local[k]
        total_chiral_reactions += chiral_reactions
        total_cistrans_reactions += cistrans_reactions

    print(f'time taken {time.time()-t0}\n')
    print(f'unnormalized topk {topk}\n')
    print(f'unnormalized topk_among_chiral {topk_among_chiral}\n')
    print(f'unnormalized topk_among_cistrans {topk_among_cistrans}\n')
    print(f'total_chiral_reactions {total_chiral_reactions}\n')
    print(f'total_cistrans_reactions {total_cistrans_reactions}\n')
    # NaN (instead of ZeroDivisionError) when a category has no reactions.
    topk = _normalize(topk, len(sampled_rxns))
    topk_among_chiral = _normalize(topk_among_chiral, total_chiral_reactions)
    topk_among_cistrans = _normalize(topk_among_cistrans, total_cistrans_reactions)
    print(f'normalized topk {topk}\n')
    print(f'normalized topk_among_chiral {topk_among_chiral}\n')
    print(f'normalized topk_among_cistrans {topk_among_cistrans}\n')

if __name__ == '__main__':
    # Wrap main() in a crash logger only when no debugger/tracer is attached.
    # Under a debugger, let the exception propagate so it stops at the error.
    # sys.gettrace always exists on CPython, so test its return value rather
    # than whether the attribute exists (the old check was never None).
    gettrace = getattr(sys, 'gettrace', None)
    if gettrace is None or gettrace() is None:
        try:
            main()
        except Exception as e:
            log.exception("main crashed. Error: %s", e)
    else:
        main()


