'''
    Helper functions for handling data.
'''
import os
from ast import literal_eval
from rxn_insight.reaction import Reaction
from rxnmapper import RXNMapper
os.environ['PYTHONWARNINGS'] = 'ignore'
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", module="rdkit")
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) 
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

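# Replace RDKit's drawing modules with MagicMock stubs in sys.modules (presumably to avoid
# import errors from the drawing backend). NOTE(review): rxn_insight and rxnmapper are imported
# above this point — confirm they have not already loaded rdkit.Chem.Draw, or the stub may not take effect.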
import sys
from unittest.mock import MagicMock

# Mock the problematic RDKit drawing modules
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()

from pathlib import Path
import zipfile
import requests
import time
import difflib
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from rdkit.Chem import rdChemReactions
import numpy as np
import torch
import re
import os
import json
import sys
from itertools import chain
from typing import List
import pandas as pd
import pickle
import random
from matplotlib import pyplot as plt
from rdchiral import template_extractor
from itertools import permutations
from sklearn.model_selection import train_test_split
import Levenshtein
from rdkit import Chem
from rdkit.Chem import DataStructs
from rdkit.Chem import rdFMCS
# sys.path.append(os.path.join(os.environ['CONDA_PREFIX'],'share','RDKit','Contrib'))
# from SA_Score import sascorer
# from NP_Score import npscorer      

from syntheseus.search.graph.and_or import AndNode, OrNode

from multiguide.syntheseus.visualize import visualize_andor
#from rxnmapper import RXNMapper

from multiguide.helpers import PROJECT_ROOT
from multiguide.property.property_predictor import PropertyPredictor
from multiguide.desp.DESP import DESP
from multiguide.desp.retro_predictor import RetroPredictor

# Torch device used by helpers in this module (e.g. turn_seq_to_ids builds its
# id tensor here): CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Space-separated list of property/toxicity model names.
# NOTE(review): presumably the model identifiers consumed by PropertyPredictor — confirm.
ALL_MODELS = "acute_tox dili neuro nephro respi cardio carcino immuno mutagen cyto bbb eco clinical nutri nr_ahr nr_ar nr_ar_lbd nr_aromatase nr_er nr_er_lbd nr_ppar_gamma sr_are sr_hse sr_mmp sr_p53 sr_atad5 mie_thr_alpha mie_thr_beta mie_ttr mie_ryr mie_gabar mie_nmdar mie_ampar mie_kar mie_ache mie_car mie_pxr mie_nadhox mie_vgsc mie_nis CYP1A2 CYP2C19 CYP2C9 CYP2D6 CYP3A4 CYP2E1"

# Groups of the model names above. Entries such as 'sr_', 'nr_', 'mie_' and
# 'CYP' are not full names in ALL_MODELS and look like name prefixes —
# TODO confirm how callers match them.
toxicity_groups = {
    'tox_endpoints': ['carcino', 'immuno', 'mutagen', 'cyto', 'eco', 'nutri', 'clinical', 'bbb'],
    'tox_pathways': ['sr_', 'nr_'],
    'mie': ['mie_'],
    'metabolism': ['CYP'],
    'organ': ['dili', 'cardio', 'nephro', 'respi', 'neuro']
}

# Human-readable names of 10 reaction types.
# NOTE(review): names and order differ from the keys of class_to_idx below;
# do not index one with the other.
reaction_types = [
    'Heteroatom alkylation and arylation',
    'Acylation and related processes', 
    'C-C bond formation',
    'Heterocycle formation',
    'Protections',
    'Reductions',
    'Oxidations', 
    'Deprotections',
    'Functional group interconversion (FGI)',
    'Functional group addition (FGA)'
]

# Maps Rxn-Insight reaction class names (the 'CLASS' value read by
# extract_reaction_type) to integer indices used as reaction_type.
# The '-1' key maps to -1 and appears to be a sentinel for unclassified
# reactions — confirm.
class_to_idx = {
    'Heteroatom Alkylation and Arylation': 0,
    'Acylation': 1,
    'C-C Coupling': 2,
    'Aromatic Heterocycle Formation': 3,
    'Protection': 4,
    'Deprotection': 5,
    'Reduction': 6,
    'Oxidation': 7,
    'Functional Group Addition': 8,
    'Functional Group Interconversion': 9,
    'Miscellaneous': 10,
    '-1': -1
}

from dataclasses import dataclass
from typing import Optional, List, Dict, Any


@dataclass
class ReactionData:
    """Container for one reaction record from a batch.

    Built by parse_batch_to_reaction_data (with keyword arguments) from the
    row tuples of a batch. Values are whatever the batch supplied; dataclass
    annotations are not enforced at runtime.
    """
    reactants: str  # reactant SMILES ('.'-joined)
    product: str  # product SMILES
    # NOTE(review): annotated str, but get_batch's processed-data branch passes
    # the 'reaction_type' column (presumably integers from class_to_idx) — confirm.
    class_idx: str
    original_starting_material: Optional[str]  # starting material without model special tokens
    # NOTE(review): get_batch fills this from the 'most_sm' column.
    most_similar_reactants: Optional[str]
    # NOTE(review): get_batch fills this from the 'least_sm' column.
    least_similar_reactants: Optional[str]
    # get_batch: 'most_sm_to_reactants_similarity_max'
    most_similar_reactants_similarity: Optional[float]
    # get_batch: 'least_sm_to_reactants_similarity_max'
    least_similar_reactants_similarity: Optional[float]
    # -1 when get_batch had to synthesize conditional_target itself
    similarity_to_target: Optional[float]
    # starting material wrapped in model tokens (e.g. '<s>' + SMILES + separator)
    conditional_starting_material: str
    original_target: Optional[str]  # target SMILES (get_batch passes 'main_target')
    conditional_target: Optional[str]  # target wrapped as separator + SMILES + '</s>'
    batch_index: int  # Original index in dataset

def parse_batch_to_reaction_data(
    batch: List[tuple],
    start_idx: int
) -> List[ReactionData]:
    """
    Convert raw batch tuples to structured ReactionData objects.

    Two tuple layouts are accepted:

    * 12 fields: (reactants, product, class_idx, original_starting_material,
      conditional_starting_material, original_target, conditional_target,
      most_similar_reactants, least_similar_reactants,
      most_similar_reactants_similarity, least_similar_reactants_similarity,
      similarity_to_target)
    * 7 fields: the first seven of the layout above, as produced by the
      uspto_190 branches of get_batch. The five similarity-related fields
      are then set to None.

    Args:
        batch: list of tuples (or other iterables) in one of the layouts above.
        start_idx: dataset index of ``batch[0]``; each record gets
            ``batch_index = start_idx + position``.

    Returns:
        One ReactionData per input tuple, in the same order.

    Raises:
        ValueError: if a tuple has neither 7 nor 12 fields.
    """
    full_len = 12
    short_len = 7
    reactions = []
    for i, record in enumerate(batch):
        record = tuple(record)
        if len(record) == short_len:
            # The short layout is a prefix of the full one; pad the optional tail.
            record = record + (None,) * (full_len - short_len)
        elif len(record) != full_len:
            raise ValueError(
                f'Batch record {start_idx + i} has {len(record)} fields; '
                f'expected {short_len} or {full_len}'
            )
        (reactants, product, class_idx, original_starting_material,
         conditional_starting_material, original_target,
         conditional_target, most_similar_reactants,
         least_similar_reactants, most_similar_reactants_similarity,
         least_similar_reactants_similarity, similarity_to_target) = record
        reactions.append(
            ReactionData(
                reactants=reactants,
                product=product,
                class_idx=class_idx,
                original_starting_material=original_starting_material,
                conditional_starting_material=conditional_starting_material,
                original_target=original_target,
                conditional_target=conditional_target,
                batch_index=start_idx + i,
                most_similar_reactants=most_similar_reactants,
                least_similar_reactants=least_similar_reactants,
                most_similar_reactants_similarity=most_similar_reactants_similarity,
                least_similar_reactants_similarity=least_similar_reactants_similarity,
                similarity_to_target=similarity_to_target
            )
        )
    return reactions

def update_config_for_reaction(config, reaction: ReactionData) -> None:
    """Copy the evaluation-relevant fields of ``reaction`` onto ``config`` in place.

    Writes to ``config.single_step_evaluation``: product_smi, true_reactants,
    original_starting_material, original_target and rxn_class; writes to
    ``config.classifier_guidance``: target_class_index. The reaction's
    class_idx is written to both sections.
    """
    evaluation = config.single_step_evaluation
    evaluation.product_smi = reaction.product
    evaluation.true_reactants = reaction.reactants
    evaluation.original_starting_material = reaction.original_starting_material
    evaluation.original_target = reaction.original_target
    config.classifier_guidance.target_class_index = reaction.class_idx
    evaluation.rxn_class = reaction.class_idx


def flatten_predictions(
    df_dict: Dict[str, Any],
    reaction: ReactionData,
    separator: str
) -> List[Dict[str, Any]]:
    """
    Expand a prediction dictionary into one row dictionary per prediction.

    The row count is ``len(df_dict['reactant_predictions'])``. For every key,
    a list whose length equals that count is distributed element-wise over
    the rows; any other value (scalars, or lists of another length) is copied
    unchanged into each row. Every row then receives four metadata keys,
    overwriting same-named prediction keys: ``starting_material_separator``,
    ``original_starting_material``, ``original_target`` and ``batch_index``.

    Args:
        df_dict: prediction results (values may be per-prediction lists or scalars).
        reaction: record supplying the metadata fields.
        separator: starting-material separator string stored in every row.

    Returns:
        List of row dictionaries; empty when ``df_dict`` is falsy or lacks
        ``'reactant_predictions'``.
    """
    if not df_dict or 'reactant_predictions' not in df_dict:
        return []

    n_rows = len(df_dict['reactant_predictions'])
    shared_metadata = {
        'starting_material_separator': separator,
        'original_starting_material': reaction.original_starting_material,
        'original_target': reaction.original_target,
        'batch_index': reaction.batch_index
    }

    def _build_row(idx):
        # Per-prediction lists are indexed; everything else is broadcast.
        row = {
            key: value[idx] if isinstance(value, list) and len(value) == n_rows else value
            for key, value in df_dict.items()
        }
        row.update(shared_metadata)
        return row

    return [_build_row(idx) for idx in range(n_rows)]

def get_sorted_cano_smiles(smilies_list):
    '''
    Join SMILES into one '.'-separated string in a fixed order.

    Each entry has its atom mapping removed with clear_atom_map; the results
    are sorted lexicographically before joining. The same collection of
    map-cleared strings therefore yields the same output regardless of input order.
    '''
    cleaned = [clear_atom_map(smi) for smi in smilies_list]
    cleaned.sort()
    return '.'.join(cleaned)

def remove_duplicated_reactions(df, print_duplicates=False):
    '''
    Drop repeated reactions, keeping the first occurrence of each.

    Rows count as duplicates when their sorted, atom-map-free reactant and
    product strings match. These strings are computed by get_sorted_cano_reactions
    from the 'reactants>reagents>production' column, which also adds the
    sorted_cano_* columns to ``df`` itself. Prints the duplicate indices and
    their count; with ``print_duplicates`` it also prints the duplicate rows.
    '''
    key_columns = ['sorted_cano_reactants', 'sorted_cano_products']
    df = get_sorted_cano_reactions(df, reaction_column_name='reactants>reagents>production')
    is_repeat = df.duplicated(subset=key_columns, keep='first')
    repeats = df[is_repeat]
    print(f'Indices of duplicates: {repeats.index.tolist()}')
    print(f'Number of duplicates: {len(repeats)}')
    if print_duplicates and len(repeats) > 0:
        print('\nDuplicate rows:')
        print(repeats[key_columns])
    return df.drop_duplicates(subset=key_columns)

def get_sorted_cano_reactions(df, reaction_column_name):
    '''
    Add order-independent, atom-map-free reaction columns to ``df``.

    Each value of ``reaction_column_name`` is a reaction SMILES. Both
    ``reactants>>products`` and ``reactants>agents>products`` forms are
    accepted: the first '>'-separated field is taken as the reactants and the
    last as the products, so any agents section is ignored. (Splitting on
    '>>' raised IndexError as soon as agents were present.)

    Adds, in place on ``df`` (which is also returned):
        sorted_cano_reactants: get_sorted_cano_smiles of the reactant molecules
        sorted_cano_products: get_sorted_cano_smiles of the product molecules
        sorted_cano_reactions: '<sorted_cano_reactants>>><sorted_cano_products>'
    '''
    reactions = df[reaction_column_name]
    df['sorted_cano_reactants'] = reactions.apply(
        lambda rxn: get_sorted_cano_smiles(rxn.split('>')[0].split('.'))
    )
    df['sorted_cano_products'] = reactions.apply(
        lambda rxn: get_sorted_cano_smiles(rxn.split('>')[-1].split('.'))
    )
    df['sorted_cano_reactions'] = df['sorted_cano_reactants'] + '>>' + df['sorted_cano_products']
    return df

def remove_overlaps_with_subset(df, cfg, reference_subset):
    '''
    Drop reactions from ``df`` that also occur in another dataset split.

    The reference split is read from
    PROJECT_ROOT/data/<cfg.reaction_dataset.data_dir>/raw/<reference_subset>.csv.
    Reactions are compared via the ``sorted_cano_reactions`` strings built by
    get_sorted_cano_reactions from the 'reactants>reagents>production'
    column; for ``df`` they are computed only if that column is missing.
    Prints the number of overlapping rows.

    Returns:
        ``df`` restricted to rows whose reaction is absent from the reference split.
    '''
    rxn_column = 'reactants>reagents>production'
    reference_csv = os.path.join(
        PROJECT_ROOT, 'data', cfg.reaction_dataset.data_dir, 'raw', f'{reference_subset}.csv'
    )
    reference = get_sorted_cano_reactions(
        pd.read_csv(reference_csv), reaction_column_name=rxn_column
    )
    if 'sorted_cano_reactions' not in df.columns:
        df = get_sorted_cano_reactions(df, reaction_column_name=rxn_column)
    overlapping = df['sorted_cano_reactions'].isin(reference['sorted_cano_reactions'])
    print(f'Number of overlaps: {overlapping.sum()}')
    return df[~overlapping]

def remove_overlaps_with_other_subsets(df, cfg):
    '''
    Remove reactions that also appear in earlier splits.

    Behaviour depends on cfg.reaction_dataset.subset:
      * 'train': ``df`` is returned unchanged (a message is printed);
      * 'val':   rows overlapping 'train' are removed;
      * 'test':  rows overlapping 'train', then 'val', are removed.
    Any other value returns ``df`` unchanged and prints nothing.
    '''
    subset = cfg.reaction_dataset.subset
    if subset == 'train':
        print('Keeping all train reactions')
        return df
    if subset == 'val':
        reference_subsets = ('train',)
    elif subset == 'test':
        reference_subsets = ('train', 'val')
    else:
        reference_subsets = ()
    for reference in reference_subsets:
        print(f'Removing overlaps with {reference}')
        df = remove_overlaps_with_subset(df, cfg=cfg, reference_subset=reference)
    return df

def extract_reaction_type(rxn_insight_info):
    '''
    Return the 'CLASS' entry of an Rxn-Insight info record.

    The record may be a dict or its string form (e.g. as read back from a
    CSV), which is parsed with ast.literal_eval. None yields None.
    '''
    if rxn_insight_info is None:
        return None
    if isinstance(rxn_insight_info, dict):
        info = rxn_insight_info
    else:
        info = literal_eval(rxn_insight_info)
    return info['CLASS']

def get_rxn_insight_info(rxn_smi):
    '''
    Classify a reaction SMILES with Rxn-Insight.

    Returns ``Reaction(rxn_smi).get_reaction_info()``. Any exception is
    printed and swallowed, and None is returned, so a row-wise ``apply`` over
    a DataFrame is not aborted by a single bad reaction.
    '''
    try:
        return Reaction(rxn_smi).get_reaction_info()
    except Exception as err:
        print(f'error: {err}')
        return None

def add_rxn_insight_information(df):
    '''
    Annotate ``df`` with Rxn-Insight classification results.

    Adds, in place on ``df`` (which is also returned):
        rxn_insight_info: get_rxn_insight_info for each value of the
            'reactants>reagents>production' column (None when classification failed).
        reaction_type: integer class index via class_to_idx. Rows whose
            classification failed get class_to_idx['-1'] instead of raising
            KeyError on ``class_to_idx[None]``.

    Raises:
        KeyError: if Rxn-Insight reports a class name that is not a key of class_to_idx.
    '''
    def _class_index(info):
        reaction_class = extract_reaction_type(info)
        if reaction_class is None:
            # get_rxn_insight_info returns None on failure; use the '-1' sentinel.
            return class_to_idx['-1']
        return class_to_idx[reaction_class]

    df['rxn_insight_info'] = df['reactants>reagents>production'].apply(get_rxn_insight_info)
    df['reaction_type'] = df['rxn_insight_info'].apply(_class_index)
    return df

def find_reactant_by_similarity_rank(reactants, product, cfg, ranking='most'):
    '''
    Pick the reactant most ('most') or least ('least') similar to ``product``.

    Each reactant is scored with get_similarity_for_one_pair(reactant, product, ...)
    using cfg.reaction_dataset.similarity_type and combination_weight. Ties go
    to the earliest reactant in ``reactants``.

    NOTE(review): the reactant is passed as ``main_target`` and the product as
    ``starting_material``. That only matters for 'fms', which normalises by the
    second molecule's atom count — confirm the intended direction.

    Returns:
        (reactant, similarity) tuple.
    Raises:
        ValueError: for an unknown ``ranking`` (raised after scoring), or if
        ``reactants`` is empty.
    '''
    scores = [
        get_similarity_for_one_pair(
            reactant,
            product,
            similarity_type=cfg.reaction_dataset.similarity_type,
            combination_weight=cfg.reaction_dataset.combination_weight
        )
        for reactant in reactants
    ]
    if ranking == 'most':
        pick = max
    elif ranking == 'least':
        pick = min
    else:
        raise ValueError(f'Invalid ranking: {ranking}. Must be "most" or "least"')
    chosen = scores.index(pick(scores))
    return reactants[chosen], scores[chosen]

def assign_similarity(df, cfg):
    '''
    Add most/least similar reactant columns to ``df`` (in place; also returned).

    For every row, the '.'-split ``sorted_cano_reactants`` are ranked against
    ``sorted_cano_products`` with find_reactant_by_similarity_rank. The
    similarity measure is whatever cfg.reaction_dataset configures (not
    necessarily Tanimoto).

    Adds: most_similar_reactants, most_similar_reactants_similarity,
          least_similar_reactants, least_similar_reactants_similarity.
    '''
    def _rank_rows(ranking):
        # One (reactant, similarity) tuple per row.
        return df.apply(
            lambda row: find_reactant_by_similarity_rank(
                row['sorted_cano_reactants'].split('.'),
                row['sorted_cano_products'],
                cfg,
                ranking=ranking
            ),
            axis=1
        )

    most = _rank_rows('most')
    least = _rank_rows('least')
    df['most_similar_reactants'] = most.apply(lambda pair: pair[0])
    df['most_similar_reactants_similarity'] = most.apply(lambda pair: pair[1])
    df['least_similar_reactants'] = least.apply(lambda pair: pair[0])
    df['least_similar_reactants_similarity'] = least.apply(lambda pair: pair[1])
    return df

def get_token_prefix_similarity_vectorized(generated_seqs, target_seq):
    """
    Batched get_token_prefix_similarity for token-id tensors.

    For each row of ``generated_seqs``, count how many leading tokens agree
    with ``target_seq``, comparing only the first min(seq_len, target_len)
    positions. Divide that count by the full target length.

    Args:
        generated_seqs: tensor of shape (batch_size, seq_len).
        target_seq: tensor of shape (target_len,).

    Returns:
        Float tensor of shape (batch_size,). All zeros, on the device of
        ``generated_seqs``, when either length is 0.
    """
    n_rows, gen_len = generated_seqs.shape
    tgt_len = target_seq.shape[0]
    overlap = min(gen_len, tgt_len)
    if overlap == 0:
        return torch.zeros(n_rows, device=generated_seqs.device)

    same = generated_seqs[:, :overlap] == target_seq[:overlap].unsqueeze(0)
    # The running product stays 1 only until the first mismatch, so its
    # row-wise sum is the common-prefix length.
    prefix_len = torch.cumprod(same.long(), dim=1).sum(dim=1)
    return prefix_len.float() / tgt_len

def get_token_prefix_similarity(generated_tokens, target_tokens):
    """
    Token-level prefix similarity - most appropriate for autoregressive generation.

    Length of the longest common prefix of ``generated_tokens`` and
    ``target_tokens`` (tokens compared with ``==``), divided by
    ``len(target_tokens)``.

    Returns:
        float in [0, 1]. Returns 0.0 for an empty ``target_tokens`` instead of
        raising ZeroDivisionError, matching get_token_prefix_similarity_vectorized,
        which returns zeros in that case.
    """
    # len() rather than truthiness so 1-D tensors are accepted as well.
    target_len = len(target_tokens)
    if target_len == 0:
        return 0.0
    common_prefix_len = 0
    for gen_tok, tgt_tok in zip(generated_tokens, target_tokens):
        if gen_tok == tgt_tok:
            common_prefix_len += 1
        else:
            break
    return common_prefix_len / target_len

def get_similarity_for_one_pair(main_target, starting_material,
                                similarity_type, combination_weight=1):
    '''
    Score how similar ``starting_material`` is to ``main_target`` (both SMILES).

    similarity_type:
        'tanimoto'     -> get_tanimoto(main_target, starting_material)
        'fms'          -> get_fms(main_target, starting_material)
        'fms_tanimoto' -> (1 - combination_weight) * tanimoto + combination_weight * fms
    Raises:
        ValueError: for any other similarity_type.
    '''
    if similarity_type == 'tanimoto':
        return get_tanimoto(main_target, starting_material)
    if similarity_type == 'fms':
        return get_fms(main_target, starting_material)
    if similarity_type == 'fms_tanimoto':
        tanimoto_part = (1 - combination_weight) * get_tanimoto(main_target, starting_material)
        return tanimoto_part + combination_weight * get_fms(main_target, starting_material)
    raise ValueError(f'Similarity type {similarity_type} not supported')

def get_similarity(main_target, starting_material, similarity_type, combination_weight=1):
    '''
    Score every candidate starting material against one target.

    Args:
        main_target (str): target SMILES.
        starting_material (iterable of str): candidate starting-material SMILES.
        similarity_type, combination_weight: forwarded to get_similarity_for_one_pair.
    Returns:
        list of (main_target, starting_material_smiles, similarity) tuples, in input order.
    '''
    return [
        (main_target, sm,
         get_similarity_for_one_pair(main_target, sm, similarity_type, combination_weight))
        for sm in starting_material
    ]

def get_starting_material_from_route(route):
    '''
    Collect the building blocks (starting materials) used in a retro route.

    Args:
        route (list): reaction SMILES in retro direction; the text after '>>'
            is split on '.' and treated as the reactants.
    Returns:
        list: reactant SMILES that appear in the 'mol' column of
        PROJECT_ROOT/data/desp_data/origin_dict.csv, in route order
        (duplicates kept).
    '''
    bb_mol2idx = os.path.join(PROJECT_ROOT,
                              'data',
                              'desp_data',
                              'origin_dict.csv')
    df = pd.read_csv(bb_mol2idx, index_col=0)
    # Building blocks (bbs) are the starting materials. A set gives O(1)
    # membership tests instead of scanning the whole list for every reactant.
    bbs = set(df['mol'].tolist())
    starting_material_in_route = []
    for reaction in route:
        reactants = reaction.split('>>')[1].split('.')  # assumed to work with retro reactions
        starting_material_in_route.extend(r for r in reactants if r in bbs)
    return starting_material_in_route


def turn_seq_to_ids(seq, onmt_checkpoint_path, use_unk=False):
    '''
    Tokenize a SMILES string and map its tokens to vocabulary ids.

    The vocabulary comes from get_vocab_from_trained_model(onmt_checkpoint_path);
    an id is the token's position in it (``vocab.index``). Tokens come from
    the space-separated string returned first by smi_tokenizer.

    Args:
        seq: SMILES string.
        onmt_checkpoint_path: checkpoint used to load the vocabulary.
        use_unk: map out-of-vocabulary tokens to '<unk>' instead of failing.
    Returns:
        1-D tensor of ids on the module-level ``device``. Returns None, after
        printing a message, when a token is out of vocabulary and ``use_unk`` is False.
    '''
    vocab = get_vocab_from_trained_model(onmt_checkpoint_path)
    token_string, _ = smi_tokenizer(seq)
    ids = []
    for tok in token_string.split(' '):
        if tok in vocab:
            ids.append(vocab.index(tok))
        elif use_unk:
            ids.append(vocab.index('<unk>'))
        else:
            print(f'token {tok} not in vocab for sequence {seq}')
            return None
    return torch.tensor(ids, device=device)
    
def get_tanimoto(main_target, sm):
    '''
    Tanimoto similarity between the RDKit fingerprints of two molecules.

    Args:
        main_target (str): SMILES of the target molecule.
        sm (str): SMILES of a single starting material.
    Returns:
        DataStructs.TanimotoSimilarity of the Chem.RDKFingerprint bit vectors,
        or -1 if either SMILES fails to parse.
    '''
    sm_mol, target_mol = (Chem.MolFromSmiles(smi) for smi in (sm, main_target))
    if sm_mol is None or target_mol is None:
        return -1
    sm_fp = Chem.RDKFingerprint(sm_mol)
    target_fp = Chem.RDKFingerprint(target_mol)
    return DataStructs.TanimotoSimilarity(sm_fp, target_fp)

def choose_closest_starting_material(reactants, product):
    '''
    Return the reactant whose fingerprint Tanimoto similarity to ``product`` is highest.

    Ties go to the earliest reactant. Unparsable SMILES score -1 (see
    get_tanimoto). Raises ValueError if ``reactants`` is empty.
    '''
    scores = [get_tanimoto(reactant, product) for reactant in reactants]
    best_position = max(range(len(scores)), key=scores.__getitem__)
    return reactants[best_position]
    
def get_fms(main_target, sm):
    '''
    Fraction of the starting material's atoms covered by its maximum common
    substructure (MCS) with the main target.

    The MCS is found with rdFMCS.FindMCS using chirality matching, exact
    bond orders, strict ring fusion and complete rings only.

    Args:
        main_target (str): SMILES of the target molecule.
        sm (str): SMILES of a single starting material.
    Returns:
        ``mcs.numAtoms / sm_num_atoms`` (non-negative). Returns -1 if either
        SMILES fails to parse (the same sentinel get_tanimoto uses), and 0.0
        if ``sm`` parses to a molecule with no atoms. Before this check, those
        inputs crashed FindMCS or divided by zero.
    '''
    sm_mol = Chem.MolFromSmiles(sm)
    main_target_mol = Chem.MolFromSmiles(main_target)
    if sm_mol is None or main_target_mol is None:
        return -1
    sm_num_atoms = sm_mol.GetNumAtoms()
    if sm_num_atoms == 0:
        return 0.0
    # Find the Maximum Common Substructure (MCS) between the query and precursor molecules
    mcs_result = rdFMCS.FindMCS(
        mols=[main_target_mol, sm_mol],
        matchChiralTag=True,
        bondCompare=rdFMCS.BondCompare.CompareOrderExact,
        ringCompare=rdFMCS.RingCompare.StrictRingFusion,
        completeRingsOnly=True
    )
    # Compute the weighted MCS score based on the fraction of matching atoms
    return max(0, (mcs_result.numAtoms / sm_num_atoms))
    
def get_levenstein_similarity(str1, str2):
    '''
    Normalised Levenshtein similarity: 1 - distance / max(len(str1), len(str2)).

    Two empty strings are identical and return 1.0; previously this case
    divided by zero.
    '''
    longest = max(len(str1), len(str2))
    if longest == 0:
        return 1.0
    distance = Levenshtein.distance(str1, str2)
    return 1 - (distance / longest)

def find_differences_and_matches(str1, str2, min_length=3):
    '''
    Split two (SMILES) strings into shared substrings and leftover fragments.

    Matching blocks come from difflib.SequenceMatcher (isjunk=None). Only
    blocks of at least ``min_length`` characters count as matches; shorter
    agreements end up inside the difference fragments.

    Returns:
        dict with
          'matches':    list of (substring, length) per kept block, ordered by
                        position in ``str1``;
          'str1_diffs': fragments of ``str1`` lying outside the kept blocks;
          'str2_diffs': fragments of ``str2`` lying outside the kept blocks.
    '''
    blocks = difflib.SequenceMatcher(None, str1, str2).get_matching_blocks()
    kept = sorted(
        (blk for blk in blocks if blk.size >= min_length),
        key=lambda blk: blk.a,
    )

    gaps1, gaps2 = [], []
    pos1 = pos2 = 0
    for blk in kept:
        if blk.a > pos1:
            gaps1.append(str1[pos1:blk.a])
        if blk.b > pos2:
            gaps2.append(str2[pos2:blk.b])
        pos1, pos2 = blk.a + blk.size, blk.b + blk.size
    # Whatever follows the last kept block is also a difference.
    if pos1 < len(str1):
        gaps1.append(str1[pos1:])
    if pos2 < len(str2):
        gaps2.append(str2[pos2:])

    shared = [str1[blk.a:blk.a + blk.size] for blk in kept]
    return {
        'matches': [(sub, len(sub)) for sub in shared],
        'str1_diffs': gaps1,
        'str2_diffs': gaps2,
    }

def compare_multicomponent_smiles(true_smiles, pred_smiles, min_length=3):
    '''
    Align a predicted multi-component SMILES with the true one.

    Every ordering of the '.'-separated components of ``pred_smiles`` is
    compared with ``true_smiles`` via find_differences_and_matches. The
    ordering with the largest total matched length wins; on ties, the first
    ordering wins. Cost grows factorially with the number of predicted components.

    Returns:
        The winning comparison dict, or None when every ordering has a total
        matched length of 0.
    '''
    best_result, best_score = None, 0
    for ordering in permutations(pred_smiles.split('.')):
        candidate = find_differences_and_matches(true_smiles, '.'.join(ordering), min_length)
        total = sum(length for _, length in candidate['matches'])
        if total > best_score:
            best_result, best_score = candidate, total
    return best_result

def compare_reactant_smiles(smiles1, smiles2):
    '''
    Return True when two '.'-joined SMILES strings contain the same set of
    components after atom-map removal (order and repeats ignored).
    '''
    def _component_set(smiles):
        return {clear_atom_map(part) for part in smiles.split('.')}
    return _component_set(smiles1) == _component_set(smiles2)

def get_batch(config):
    '''
    Load a slice of the evaluation dataset as a list of per-reaction tuples.

    The CSV at PROJECT_ROOT/data/<data_dir>/<subset> is read (both values come
    from ``config.single_step_evaluation``) and missing helper columns are
    derived. Rows [start_idx:end_idx] are then turned into tuples whose
    layout depends on ``data_dir``:

    * 'uspto_50k/processed', 'uspto_50k_debug/processed',
      'uspto_190/processed', 'route_similarity_data/processed': 12 fields —
      sorted_cano_reactants, sorted_cano_products, reaction_type,
      original_starting_material, conditional_starting_material, main_target,
      conditional_target, most_sm, least_sm,
      most_sm_to_reactants_similarity_max,
      least_sm_to_reactants_similarity_max, similarity_to_target.
    * 'uspto_190/first_reactions', 'uspto_190/train_unique_reactions',
      'uspto_190/first_reactions_with_targets',
      'uspto_190/first_reactions_nonlinear_with_targets',
      'uspto_190/reactions_with_starting_material': 7 fields — reactant,
      product, rxn_insight_class, original_starting_material,
      conditional_starting_material, original_target, conditional_target.
    * 'uspto_50k/no_solutions': 5 fields — reactants and product (atom maps
      cleared), true_class, conditional_starting_material, conditional_target.

    Raises:
        ValueError: for an unknown ``data_dir`` or a missing required column.
    '''
    eval_cfg = config.single_step_evaluation
    data_dir = eval_cfg.data_dir

    def _load():
        # One CSV per split, named by ``subset``, inside data/<data_dir>.
        return pd.read_csv(os.path.join(PROJECT_ROOT, 'data', data_dir, eval_cfg.subset))

    def _records(frame, make_record):
        # Apply ``make_record`` to each row of the [start_idx:end_idx] window.
        window = frame.iloc[eval_cfg.start_idx:eval_cfg.end_idx]
        if window.empty:
            # Row-wise DataFrame.apply on an empty frame may not return a
            # Series (and so has no .tolist()); an empty window has no records.
            return []
        return window.apply(make_record, axis=1).tolist()

    def _seven_field_record(x):
        # Layout shared by the uspto_190 first_reactions / *_with_targets /
        # reactions_with_starting_material files.
        return (x['reactant'], x['product'], x['rxn_insight_class'],
                x['original_starting_material'], x['conditional_starting_material'],
                x['original_target'], x['conditional_target'])

    if data_dir in ('uspto_50k/processed', 'uspto_50k_debug/processed',
                    'uspto_190/processed', 'route_similarity_data/processed'):
        df = _load()
        if 'conditional_starting_material' not in df.columns:
            dataset_cfg = config.classifier_guidance.dataset
            df['original_starting_material'] = df[dataset_cfg.conditional_starting_material_column]
            df['conditional_starting_material'] = df['original_starting_material'].apply(
                lambda x: '<s>' + x + dataset_cfg.separator)
        if 'main_target' not in df.columns:
            # main_target is part of every record; default it to the product side.
            df['main_target'] = df['sorted_cano_products']
        if 'conditional_target' not in df.columns:
            # TODO: conditional_target is not used at all now
            df['conditional_target'] = df['main_target'].apply(
                lambda x: config.classifier_guidance.dataset.separator + x + '</s>')
            # NOTE: placeholder similarity to target (-1); could be relevant when
            # comparing samples to ground truth
            df['similarity_to_target'] = -1
        elif 'similarity_to_target' not in df.columns:
            # Same placeholder when conditional_target came from the file.
            df['similarity_to_target'] = -1

        def _processed_record(x):
            return (
                x['sorted_cano_reactants'],
                x['sorted_cano_products'],
                x['reaction_type'],
                x['original_starting_material'],
                x['conditional_starting_material'],
                x['main_target'],
                x['conditional_target'],
                x['most_sm'],
                x['least_sm'],
                x['most_sm_to_reactants_similarity_max'],
                x['least_sm_to_reactants_similarity_max'],
                x['similarity_to_target'],
            )

        batch = _records(df, _processed_record)
    elif data_dir in ('uspto_190/first_reactions', 'uspto_190/train_unique_reactions',
                      'uspto_190/first_reactions_with_targets',
                      'uspto_190/first_reactions_nonlinear_with_targets'):
        df = _load()
        # Both columns are read below; fail early with a clear message if one
        # is missing. (The first_reactions check used to be inverted and raised
        # when the column WAS present.)
        # TODO: change main_target to conditional_target for consistency
        for required in ('main_target', 'conditional_starting_material'):
            if required not in df.columns:
                raise ValueError(f'{required} not in df.columns')
        separator = config.classifier_guidance.dataset.separator
        df['original_target'] = df['main_target']
        # Series.apply forwards extra kwargs to the function, so no axis= here.
        df['conditional_target'] = df['original_target'].apply(
            lambda x: separator + x + '</s>')
        df['original_starting_material'] = df['conditional_starting_material']
        df['conditional_starting_material'] = df['original_starting_material'].apply(
            lambda x: separator + x + '</s>')
        batch = _records(df, _seven_field_record)
    elif data_dir == 'uspto_50k/no_solutions':
        df = _load()
        if 'conditional_starting_material' not in df.columns:
            reactions = df['reactants>reagents>production']
            if config.classifier_guidance.with_reactants_as_starting_material:
                df['conditional_starting_material'] = reactions.apply(
                    lambda x: '<s>' + x.split('>>')[0] + '</s>')
            else:
                df['conditional_starting_material'] = reactions.apply(
                    lambda x: '<s>' + choose_closest_starting_material(
                        x.split('>>')[0].split('.'),
                        x.split('>>')[1]
                    ) + config.classifier_guidance.dataset.separator)
        if 'conditional_target' not in df.columns:
            # TODO: this is not used at all now
            df['conditional_target'] = df['product'].apply(
                lambda x: config.classifier_guidance.dataset.separator + x + '</s>')

        def _no_solutions_record(x):
            # 'reaction' is 'reactants>>product'; atom maps (if any) are cleared.
            sides = x['reaction'].split('>>')
            return (clear_atom_map(sides[0]), clear_atom_map(sides[1]), x['true_class'],
                    x['conditional_starting_material'], x['conditional_target'])

        batch = _records(df, _no_solutions_record)
    elif data_dir == 'uspto_190/reactions_with_starting_material':
        df = _load()
        df['original_target'] = None
        df['conditional_target'] = None
        df['original_starting_material'] = df['conditional_starting_material']
        df['conditional_starting_material'] = df['original_starting_material'].apply(
            lambda x: '<s>' + x + '.')
        batch = _records(df, _seven_field_record)
    else:
        raise ValueError(f'Invalid data directory: {data_dir}')
    return batch
    
def compute_class_info(df):
    '''
    Fraction of rows whose predicted class equals the true class.

    Compares df['pred_class'] with df['true_class'] element-wise and returns
    the mean of the boolean matches.
    '''
    class_hits = df['pred_class'].eq(df['true_class'])
    return class_hits.mean()

def compute_average_topk_and_coverage(df, topk_key='new_topk', coverage_key='round_trip_accuracy'):
    '''
        Compute top-k exact-match accuracy and round-trip coverage per product.

        Rows are grouped by 'product_smi'; within each group the row order is taken as
        the prediction rank (first row = rank 1). A product counts as a hit at k if ANY of
        its first k predictions has the respective column equal to 1. Each product is
        counted at most once, so every value lies in [0, 1].

        Args:
            df: evaluation dataframe with 'product_smi', `topk_key` and `coverage_key` columns.
            topk_key: column flagging an exact match with the true reactants.
            coverage_key: column flagging a successful round trip.
        Returns:
            (topk, coverage): dicts mapping k -> fraction of distinct products with a hit in
            the top k. Both are all zeros when df has no non-null product.
    '''
    topk = {1: 0, 3: 0, 5: 0, 10: 0, 50: 0, 100: 0}
    coverage = {1: 0, 3: 0, 5: 0, 10: 0}
    num_products = df['product_smi'].nunique()
    if num_products == 0:
        return topk, coverage

    # Reset the index so boolean masks and group keys align even if df was built by
    # concatenating files with overlapping indices.
    ranked = df[df['product_smi'].notna()].reset_index(drop=True)
    # 0-based rank of each prediction within its product, in the original row order.
    ranks = ranked.groupby('product_smi', sort=False).cumcount()

    def _first_hit_rank(key):
        # Best (lowest) rank per product at which `key` == 1; products without a hit are absent.
        hit = (ranked[key] == 1).to_numpy()
        return ranks[hit].groupby(ranked['product_smi'][hit]).min()

    first_exact = _first_hit_rank(topk_key)
    for k in topk:
        topk[k] = float((first_exact < k).sum()) / num_products

    first_round_trip = _first_hit_rank(coverage_key)
    for k in coverage:
        coverage[k] = float((first_round_trip < k).sum()) / num_products
    return topk, coverage

def not_necessary_with_latest_evaluation_files(df, true_reactions_path):
    '''
        Backfill evaluation columns that older evaluation files lack.

        Adds 'new_topk' and 'round_trip_accuracy' to df in place, left-merges the reaction
        'class' from the CSV at true_reactions_path (keyed by atom-map-free product SMILES)
        as 'true_class', and derives 'pred_class' from the stored 'rxn_insight_info'.
        Should not be necessary with latest evaluation files.

        Returns:
            The merged dataframe (a new object, because of the merge).
    '''
    def _exact_match(row):
        return compare_reactant_smiles(row['true_reactants'], row['reactant_predictions'])

    def _round_trip(row):
        # NOTE(review): if 'round_trip_results' is a string read back from CSV, `in` is a
        # substring test rather than list membership — confirm the column type.
        return int(row['product_smi'] in row['round_trip_results']
                   or compare_reactant_smiles(row['reactant_predictions'], row['true_reactants']))

    df['new_topk'] = df.apply(_exact_match, axis=1)
    df['round_trip_accuracy'] = df.apply(_round_trip, axis=1)

    # Ground-truth classes, keyed by the product side of each reaction.
    true_df = pd.read_csv(true_reactions_path)
    true_df['product_smi'] = (true_df['reactants>reagents>production']
                              .apply(get_reaction_smiles)
                              .apply(lambda rxn: clear_atom_map(rxn.split('>>')[-1])))
    df = (pd.merge(df, true_df[['product_smi', 'class']], on='product_smi', how='left')
          .rename(columns={'class': 'true_class'}))

    def _pred_class(info):
        if not pd.notna(info):
            return None
        # NOTE(review): eval() executes arbitrary code from the evaluation file; only use
        # on trusted files (ast.literal_eval would be the safe alternative for dict reprs).
        return class_to_idx[eval(info)['CLASS']]

    df['pred_class'] = df['rxn_insight_info'].apply(_pred_class)
    return df

def read_batches_from_experiment(experiment_name, start_batch=None, end_batch=None):
    '''
        Read and concatenate the per-batch CSV files of an experiment.

        CSV names must contain '_start<N>_end'; files are ordered by N. start_batch and
        end_batch index into that ordered list and end_batch is INCLUSIVE. By default
        every file is read. Each file keeps its own index in the result (no ignore_index).
    '''
    csv_dir = os.path.join(PROJECT_ROOT, experiment_name)

    def _start_offset(file_name):
        return int(file_name.split('_start')[1].split('_end')[0])

    csvs = sorted((name for name in os.listdir(csv_dir) if name.endswith('.csv')), key=_start_offset)
    print(f'====== {len(csvs)} files found ======')
    first = 0 if start_batch is None else start_batch
    last = len(csvs) + 1 if end_batch is None else end_batch
    selected_paths = [os.path.join(csv_dir, name) for name in csvs[first:last + 1]]
    print(f'====== processing {len(selected_paths)} files ======')
    return pd.concat([pd.read_csv(path) for path in selected_paths])

def extract_reaction_smiles(node):
    '''
    Return (product_smiles, reactant_smiles) for a reaction.

    `node` is either a 'reactants>>products' string or an object exposing `.reactants`
    and `.products` iterables of molecules with a `.smiles` attribute; multiple
    molecules on one side are joined with '.'.
    '''
    if not isinstance(node, str):
        # NOTE: assumes syntheseus type node
        reactant_side = '.'.join(mol.smiles for mol in node.reactants)
        product_side = '.'.join(mol.smiles for mol in node.products)
        return product_side, reactant_side
    sides = node.split('>>')
    return sides[1], sides[0]

def get_ground_truth_for_step(step_idx, ground_truth_route, ground_truth_classes):
    '''
        Return (target_class, true_reactants) for one step of a ground-truth route.

        Returns (-1, None) when step_idx is past the end of ground_truth_classes.
    '''
    if step_idx >= len(ground_truth_classes):
        return -1, None
    # NOTE: ground truth routes are given in retro form, so the reactants follow '>>'
    return ground_truth_classes[step_idx], ground_truth_route[step_idx].split('>>')[-1]

def remove_dative_bonds(reactant_predictions):
    '''
        Apply remove_dative_bonds_one_molecule to every SMILES in a list.

        Delegates to the single-molecule helper so both functions share one implementation.
        Returns a new list in the same order; SMILES RDKit cannot parse are passed through unchanged.
    '''
    return [remove_dative_bonds_one_molecule(prediction) for prediction in reactant_predictions]

def remove_dative_bonds_one_molecule(reactant_prediction):
    '''
        Rewrite one SMILES with every dative bond turned into a single bond.

        Returns the input unchanged if RDKit cannot parse it. Otherwise returns
        Chem.MolToSmiles of the modified molecule, so the text may differ from the input
        even when it contained no dative bonds.
    '''
    molecule = Chem.MolFromSmiles(reactant_prediction)
    if molecule is None:
        return reactant_prediction
    dative_bonds = [bond for bond in molecule.GetBonds()
                    if bond.GetBondType() == Chem.BondType.DATIVE]
    for bond in dative_bonds:
        bond.SetBondType(Chem.BondType.SINGLE)
    return Chem.MolToSmiles(molecule)

def turn_ids_to_seq(ids, onmt_checkpoint_path):
    '''
        Decode token ids into a string using the vocabulary of a trained checkpoint.

        Every id is looked up (special tokens are NOT filtered out) and the tokens are
        concatenated without separators. The vocabulary is reloaded on each call.
    '''
    vocab = get_vocab_from_trained_model(onmt_checkpoint_path)
    return ''.join(vocab[token_id] for token_id in ids)

def turn_results_to_mol_smiles(results):
    '''
        Convert nested prediction results into nested reactant SMILES strings.

        `results` holds one sequence of predictions per product; each prediction has a
        `.reactants` attribute. Returns a list (per product) of lists of '.'-joined SMILES.
    '''
    return [
        [mols_to_str_smiles(prediction.reactants) for prediction in product_predictions]
        for product_predictions in results
    ]

def turn_results_to_rxn_smiles(results, product_smi):
    '''
        Build one 'reactants>>product' reaction SMILES for each prediction in `results`.

        Each prediction must have a `.reactants` attribute; product_smi is appended verbatim.
    '''
    return [mols_to_str_smiles(prediction.reactants) + '>>' + product_smi
            for prediction in results]


def mols_to_str(mols) -> str:
    '''Join the `.smiles` of each molecule with " + " for human-readable display.'''
    separator = " + "
    return separator.join(molecule.smiles for molecule in mols)

def mols_to_str_smiles(mols) -> str:
    '''Join the `.smiles` of each molecule with '.', i.e. as a multi-component SMILES.'''
    separator = "."
    return separator.join(molecule.smiles for molecule in mols)

def print_results(results) -> None:
    '''Print each prediction's reactants as a 1-based numbered list, e.g. "1: CC + O".'''
    for rank, prediction in enumerate(results, start=1):
        print(f"{rank}: " + mols_to_str(prediction.reactants))

def get_property_score(smi, path, config):
    '''
        Predict a scalar property for one SMILES with a trained PropertyPredictor checkpoint.

        The model and checkpoint are loaded from disk on every call.

        Args:
            smi: input SMILES string.
            path: checkpoint file holding 'model_state_dict', 'target_mean' and 'target_std'.
            config: experiment config; provides classifier_guidance.dataset.vocab_file.
        Returns:
            prediction.item() after undoing the target standardisation stored in the
            checkpoint (assumes the model emits a single-element tensor — TODO confirm).
    '''
    vocab_path = os.path.join(PROJECT_ROOT, 'data', config.classifier_guidance.dataset.vocab_file)
    model = PropertyPredictor(config, get_vocab_size_from_config(config))
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    # Add a batch dimension of size 1.
    token_ids = tokenize(smi, vocab_path).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(token_ids)
    prediction = prediction * checkpoint['target_std'] + checkpoint['target_mean']
    return prediction.item()

def compute_properties(route, toxicity_checkpoint_path, yield_checkpoint_path, config):
    '''
        Score every molecule (OrNode) of a route with several properties.

        Args:
            route: iterable of syntheseus AndNode / OrNode objects.
            toxicity_checkpoint_path, yield_checkpoint_path: PropertyPredictor checkpoints
                passed to get_property_score.
            config: experiment config forwarded to get_property_score.
        Returns:
            dict of parallel lists with one entry per OrNode in iteration order:
            'step' (number of AndNodes seen so far), 'synthesis_accessibility' (RDKit SA
            score), 'natural_product_likeness' (RDKit NP score), 'toxicity', 'yield'.
        Side effects:
            Resets n.data to {} on every AndNode and OrNode of the route.
        Raises:
            ImportError if RDKit's Contrib SA_Score / NP_Score modules cannot be found.
    '''
    # SA_Score and NP_Score ship in RDKit's Contrib directory rather than as importable
    # packages. The module-level imports for them are commented out in this file, so
    # without this every call failed with NameError.
    from rdkit.Chem import RDConfig
    for contrib_subdir in ('SA_Score', 'NP_Score'):
        contrib_path = os.path.join(RDConfig.RDContribDir, contrib_subdir)
        if contrib_path not in sys.path:
            sys.path.append(contrib_path)
    import sascorer
    import npscorer

    route_properties = {'step': [],
                        'synthesis_accessibility': [],
                        'natural_product_likeness': [],
                        'toxicity': [],
                        'yield': []}
    route = list(route)
    fscore = npscorer.readNPModel()
    step = 0
    for n in route:
        # TODO: figure out a way to differentiate the levels in the search tree?
        if isinstance(n, AndNode):
            n.data = {}
            step += 1
        elif isinstance(n, OrNode):
            m = Chem.MolFromSmiles(n.mol.smiles)
            sa_score = sascorer.calculateScore(m)
            np_score = npscorer.scoreMol(m, fscore)
            # NOTE: get_property_score reloads its checkpoint from disk on every call.
            toxicity_score = get_property_score(n.mol.smiles,
                                                path=toxicity_checkpoint_path,
                                                config=config)
            yield_score = get_property_score(n.mol.smiles,
                                             path=yield_checkpoint_path,
                                             config=config)
            # Per-node property storage is currently disabled; scores are only returned
            # in route_properties.
            # TODO: choose which properties to store on n.data
            n.data = {}
            route_properties['step'].append(step)
            route_properties['synthesis_accessibility'].append(sa_score)
            route_properties['natural_product_likeness'].append(np_score)
            route_properties['toxicity'].append(toxicity_score)
            route_properties['yield'].append(yield_score)
        # TODO: add yield for and nodes? maybe also forward reaction prediction (NLL under transformer)
    return route_properties

def visualize_routes(config, output_graph, routes, smi_idx):
    '''
        Render each route of a search graph to its own PDF via visualize_andor.

        Output: <PROJECT_ROOT>/experiments/<experiment_name>/graphs_for_mol<smi_idx>/route_<i>.pdf,
        with i counting from 1.
    '''
    for route_number, route in enumerate(routes, start=1):
        print(f'======= visualizing route {route_number}')
        pdf_path = os.path.join(PROJECT_ROOT,
                                "experiments",
                                config.general.experiment_name,
                                f'graphs_for_mol{smi_idx}',
                                f"route_{route_number}.pdf")
        print(f'======= saving to {pdf_path}')
        visualize_andor(output_graph, filename=pdf_path, nodes=route)

def atom_map_reactions_with_rxnmapper(reactions,
                                      return_mapping_confidence=False):
    '''
        Atom-map a list of reaction SMILES with RXNMapper.

        A new RXNMapper is instantiated on every call.

        Args:
            reactions: list of reaction SMILES.
            return_mapping_confidence: if True, return RXNMapper's raw result dicts;
                otherwise return only each dict's 'mapped_rxn' string.
    '''
    mapper = RXNMapper()
    mapping_results = mapper.get_attention_guided_atom_maps(reactions)
    if not return_mapping_confidence:
        return [entry['mapped_rxn'] for entry in mapping_results]
    return mapping_results

def get_template_from_reaction(reaction_smiles, cfg):
    '''
        Extract an rdchiral reaction template (SMARTS) from a route reaction.

        The reaction is re-atom-mapped with RXNMapper first. Route reactions are written in
        retro form (product>>reactants), so the two sides are swapped before extraction.
        `cfg` is currently unused.

        NOTE(review): rdchiral's extract_from_reaction may return a dict without
        'reaction_smarts' when extraction fails, in which case this raises KeyError — confirm.
    '''
    mapped_rxn = atom_map_reactions_with_rxnmapper([reaction_smiles])[0]
    retro_left, retro_right = get_reactant_and_product_from_reaction_smiles(mapped_rxn, return_as_str=True)
    extraction = template_extractor.extract_from_reaction({
        '_id': 0,
        'ReactionSmiles': mapped_rxn,
        # the right-hand side of a retro-form reaction holds the forward reactants
        'reactants': retro_right,
        'products': retro_left,
    })
    return extraction['reaction_smarts']

def get_routes_in_json(cfg):
    '''
        Load the route dataset for cfg.route_dataset.subset.

        Returns the cached JSON file (data/<route_dir>/<processed_dir>/<subset>.json) if it
        exists. Otherwise loads data/<route_dir>/original/<subset>.pkl, writes it to that
        JSON cache and returns the unpickled object. On that first call, container types
        may differ from later JSON-loaded calls (e.g. tuples vs lists).

        Raises:
            ValueError if neither file exists.
    '''
    json_file = os.path.join(PROJECT_ROOT, 'data', cfg.route_dataset.route_dir,
                             cfg.route_dataset.processed_dir, f'{cfg.route_dataset.subset}.json')
    if os.path.exists(json_file):
        with open(json_file, encoding='utf-8') as f:
            return json.load(f)
    original_file = os.path.join(PROJECT_ROOT, 'data', cfg.route_dataset.route_dir, 'original', f'{cfg.route_dataset.subset}.pkl')
    if not os.path.exists(original_file):
        raise ValueError(f'No routes found in {json_file} or {original_file}')
    # NOTE(review): pickle.load can execute arbitrary code; only use with trusted files.
    with open(original_file, 'rb') as f:
        routes = pickle.load(f)
    # Serialise before opening the cache so a non-JSON-serialisable object cannot leave a
    # truncated cache file behind, which would break every later call.
    serialized = json.dumps(routes, indent=4)
    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    with open(json_file, 'w', encoding='utf-8') as f:
        f.write(serialized)
    return routes

def get_reaction_type_from_50k(reaction_smiles, cfg):
    '''
        Look up the USPTO-50k class of a route reaction.

        Route reactions are in retro form, so the sides are swapped before matching. A dataset
        row matches when the query's reactant set is a subset of its 'ReactantsSet' and the
        query's product set is a subset of its 'ProductsSet'. The whole dataset is loaded via
        get_50k_dataset on every call.

        Returns:
            The 'class' of the first matching row, or None if nothing matches.
    '''
    dataset = get_50k_dataset(cfg=cfg)
    retro_left, retro_right = get_reactant_and_product_from_reaction_smiles(reaction_smiles)
    query_reactants = turn_reaction_side_to_set(retro_right)
    query_products = turn_reaction_side_to_set(retro_left)
    reactant_hits = np.array([query_reactants.issubset(row_set) for row_set in dataset['ReactantsSet']])
    product_hits = np.array([query_products.issubset(row_set) for row_set in dataset['ProductsSet']])
    matches = dataset[reactant_hits & product_hits]
    if len(matches) == 0:
        return None
    return matches.iloc[0]['class']

def download_50k_dataset(raw_dir):
    '''
        Download the dataset from gln's dropbox url and extract it to the raw directory.
        Then rename the files from raw_train to train, raw_test to test, and raw_val to val.

        Network errors (requests.RequestException, including timeouts) are reported and
        followed by manual download instructions instead of being raised.
    '''
    # Convert to direct download URL
    url = "https://www.dropbox.com/scl/fo/swuggv6qf8ombw914yxh8/AEwUgTxowsq2vrnv0D2xRNg/schneider50k?dl=1&rlkey=1ed5tqauj7udn5n2olvw1looi"

    os.makedirs(raw_dir, exist_ok=True)
    download_path = Path(raw_dir) / 'dataset.zip'

    try:
        print(f'Downloading dataset from {url} to {download_path}')
        # requests never times out by default; without a timeout a stalled connection
        # would hang forever. The context manager releases the connection when done.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            # Stream to disk in chunks instead of holding the archive in memory
            with open(download_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Extract
        with zipfile.ZipFile(download_path, 'r') as zip_ref:
            zip_ref.extractall(raw_dir)

        # rename files from raw_train to train, same for test and val
        for split in ('train', 'test', 'val'):
            os.rename(os.path.join(raw_dir, f'raw_{split}.csv'), os.path.join(raw_dir, f'{split}.csv'))

        download_path.unlink()  # Remove zip file
        print(f'Dataset downloaded and extracted to {raw_dir}')
    except requests.RequestException as e:
        print(f"Download failed: {e}")
        _manual_download_instructions()
    
def _manual_download_instructions():
    '''Print how to fetch the USPTO-50k archive by hand when the automatic download fails.'''
    instructions = (
        "Manual download instructions:",
        "1. Go to the following URL:",
        "https://www.dropbox.com/scl/fo/swuggv6qf8ombw914yxh8/AEwUgTxowsq2vrnv0D2xRNg/schneider50k?dl=1&rlkey=1ed5tqauj7udn5n2olvw1looi",
        "2. Download the dataset and save it as 'dataset.zip'",
    )
    for line in instructions:
        print(line)


def get_50k_dataset(cfg):
    '''
        Return a pandas dataframe of the 50k dataset with a unified nomenclature (column names).

        Reads the cached processed CSV if present; otherwise builds it with
        process_50k_dataset (only supported for processed_dir == 'with_nomenclature').
        'ReactantsSet' and 'ProductsSet' are always returned as Python sets.

        Raises:
            ValueError for an unsupported processed_dir when no cached file exists.
    '''
    nomenclature_file_path = os.path.join(PROJECT_ROOT, 'data',
                                          cfg.reaction_dataset.data_dir,
                                          cfg.reaction_dataset.processed_dir,
                                          f'{cfg.reaction_dataset.subset}.csv')
    if os.path.exists(nomenclature_file_path):
        df = pd.read_csv(nomenclature_file_path)
    elif cfg.reaction_dataset.processed_dir == 'with_nomenclature':
        df = process_50k_dataset(cfg)
    else:
        raise ValueError(f'Invalid processed directory: {cfg.reaction_dataset.processed_dir}')

    # CSV cells hold set reprs such as "{'CCO', 'O'}", but a freshly processed dataframe
    # already holds real sets. Calling eval() on a set raised TypeError, so only strings are
    # parsed. literal_eval accepts literals only, so cell contents cannot execute code.
    def _as_set(value):
        return literal_eval(value) if isinstance(value, str) else value

    for column in ('ReactantsSet', 'ProductsSet'):
        df[column] = df[column].apply(_as_set)
    return df

def read_original_50k_dataset(cfg):
    '''
        Read data/<data_dir>/original/<subset>.csv and return it as a dataframe.

        Raises:
            ValueError if no CSV named after the subset exists in that directory.
    '''
    data_dir = os.path.join(PROJECT_ROOT, 'data', cfg.reaction_dataset.data_dir, 'original')
    # TODO: if folder does not exist, download from gln link
    subset = cfg.reaction_dataset.subset
    available_subsets = [name.split('.')[0] for name in os.listdir(data_dir) if name.endswith('.csv')]
    if subset not in available_subsets:
        raise ValueError(f'Invalid subset: {subset}')
    return pd.read_csv(os.path.join(data_dir, f'{subset}.csv'))

def reaction_has_invalid_molecules(reactants, products):
    '''
        Return True if RDKit fails to parse any SMILES in reactants + products.

        Stops at the first unparsable molecule.
    '''
    return any(Chem.MolFromSmiles(smiles) is None for smiles in reactants + products)

def turn_reaction_side_to_set(reaction_side):
    '''
        Return the set of atom-map-free SMILES for one side of a reaction.

        Raises:
            ValueError unless reaction_side is exactly a list (subclasses and tuples are rejected).
    '''
    if type(reaction_side) is not list:
        raise ValueError(f'Invalid reaction side: {reaction_side}')
    return {clear_atom_map(molecule) for molecule in reaction_side}

def process_50k_dataset(cfg):
    '''
        Process the 50k dataset and return a pandas dataframe with a unified nomenclature (column names).

        Rows containing a SMILES RDKit cannot parse are dropped. The result is also written
        to data/<data_dir>/with_nomenclature/<subset>.csv. In the returned dataframe
        'ReactantsSet' / 'ProductsSet' hold Python sets; the yield / patent / year columns
        are placeholders filled with None.
    '''
    df = read_original_50k_dataset(cfg)
    # TODO: might have to use chunk logic here too
    added_columns = ('Reactants', 'Products', 'ReactantsSet', 'ProductsSet', 'TextMinedYield',
                     'CalculatedYield', 'PatentNumber', 'ParagraphNum', 'Year')
    for column in added_columns:
        df[column] = None

    rows_to_drop = []
    for row_label, record in df.iterrows():
        # drop surrounding double quotes, if any
        rxn = record['reactants>reagents>production'].strip('"')
        reactant_list, product_list = get_reactant_and_product_from_reaction_smiles(rxn)
        if reaction_has_invalid_molecules(reactant_list, product_list):
            rows_to_drop.append(row_label)
            continue
        df.at[row_label, 'reactants>reagents>production'] = rxn
        df.at[row_label, 'Reactants'] = '.'.join(reactant_list)
        df.at[row_label, 'Products'] = '.'.join(product_list)
        df.at[row_label, 'ReactantsSet'] = turn_reaction_side_to_set(reactant_list)
        df.at[row_label, 'ProductsSet'] = turn_reaction_side_to_set(product_list)

    df['reactants>reagents>production'] = df['reactants>reagents>production'].astype(str)
    df = df.drop(rows_to_drop, axis=0)
    kept_columns = ['id', 'class', 'reactants>reagents>production', 'Reactants', 'Products',
                    'ReactantsSet', 'ProductsSet', 'TextMinedYield', 'CalculatedYield',
                    'PatentNumber', 'ParagraphNum', 'Year']
    df = df[kept_columns]

    output_dir = os.path.join(PROJECT_ROOT, 'data', cfg.reaction_dataset.data_dir, 'with_nomenclature')
    os.makedirs(output_dir, exist_ok=True)
    df.to_csv(os.path.join(output_dir, f'{cfg.reaction_dataset.subset}.csv'), index=False)
    return df

def get_data_chunk(config):
    '''
        Load the USPTO applications and grants reaction TSV files, stack them (applications
        first) and return rows [start_idx:end_idx] as configured in config.reaction_dataset.
    '''
    column_names = ['ReactionSmiles', 'PatentNumber', 'ParagraphNum', 'Year', 'TextMinedYield', 'CalculatedYield']
    data_root = os.path.join(PROJECT_ROOT, 'data')
    applications_path = os.path.join(data_root, config.reaction_dataset.path_applications)
    grants_path = os.path.join(data_root, config.reaction_dataset.path_grants)

    def _read_tsv(path):
        # The first line is skipped and explicit column names are used instead of a header.
        return pd.read_csv(path, sep='\t', skiprows=1, header=None, names=column_names)

    # cat or merge?
    combined = pd.concat([_read_tsv(applications_path), _read_tsv(grants_path)], ignore_index=True)
    return combined[config.reaction_dataset.start_idx:config.reaction_dataset.end_idx]

def parse_yield(yield_str):
    '''
        Parse a text-mined yield such as '85%', '>95%', '<5%', '~70%' or '60 to 70%'
        into a float holding the number as written (e.g. '85%' -> 85.0).

        Comparison prefixes ('>', '<', '~') are dropped and for a range ('a to b') the lower
        bound is used. None and empty / whitespace-only strings return NaN; a float NaN
        passes through as NaN.

        Raises:
            ValueError if the remaining text is not a number.
    '''
    if yield_str is None:
        return np.nan
    yield_str = str(yield_str).strip()
    if not yield_str:
        # Previously set to np.nan and then crashed on `'>' in nan`.
        return np.nan
    # The number follows '>' / '<' (e.g. '>95%'); keeping the left side left '' and failed.
    if '>' in yield_str:
        yield_str = yield_str.split('>')[-1]
    if '%' in yield_str:
        yield_str = yield_str.split('%')[0]
    if ' to ' in yield_str:
        yield_str = yield_str.split(' to ')[0]
    if '~' in yield_str:
        yield_str = yield_str.split('~')[-1]
    if '<' in yield_str:
        yield_str = yield_str.split('<')[-1]
    try:
        return float(yield_str)
    except ValueError as err:
        raise ValueError(f'Invalid yield: {yield_str}') from err
    
def get_reaction_smiles(reaction):
    '''
        Normalise a reaction SMILES to the two-part 'reactants>>products' form.

        Accepts 'reactants>>products' (returned unchanged) or the three-part
        'reactants>reagents>products' form. In the latter case the reagents are appended to the
        reactants ('reactants.reagents>>products'), so right now reagents are treated as
        reactants. A reaction must use exactly one of the two separator styles.

        Raises:
            ValueError: if the string contains neither separator, mixes them, or does not
                split into exactly two ('>>') or three ('>') parts.
    '''
    if '>>' in reaction:
        parts = reaction.split('>>')
        # A second '>>' or a stray single '>' would otherwise silently drop part of the reaction.
        if len(parts) != 2 or '>' in parts[0] or '>' in parts[1]:
            raise ValueError(f'Reaction {reaction} contains >')
        reactants, products = parts
        rxn = f'{reactants}>>{products}'
    elif '>' in reaction:
        print(f'Reaction {reaction} contains >')
        parts = reaction.split('>')
        # 'A>B' used to raise IndexError; 'A>R>B>C' used to silently drop 'C'.
        if len(parts) != 3:
            raise ValueError(f'Reaction {reaction} must have the form reactants>reagents>products')
        reactants, reagents, products = parts
        rxn = f'{reactants}.{reagents}>>{products}'
    else:
        raise ValueError(f'Reaction {reaction} does not contain >> or >')
    return rxn

def load_uspto_50k(cfg):
    '''
        Read the original USPTO-50k subset CSV and add derived reaction columns.

        Adds 'rxn_smiles' (normalised 'reactants>>products', see get_reaction_smiles) and
        'ReactantsSet' / 'ProductsSet' (sets of atom-map-free SMILES for each side).
    '''
    csv_path = os.path.join(PROJECT_ROOT, 'data', cfg.reaction_dataset.data_dir, 'original', f'{cfg.reaction_dataset.subset}.csv')
    uspto = pd.read_csv(csv_path)
    uspto['rxn_smiles'] = uspto['reactants>reagents>production'].apply(get_reaction_smiles)
    reactant_lists = uspto['rxn_smiles'].apply(lambda rxn: rxn.split('>>')[0].split('.'))
    uspto['ReactantsSet'] = reactant_lists.apply(turn_reaction_side_to_set)
    product_lists = uspto['rxn_smiles'].apply(lambda rxn: rxn.split('>>')[1].split('.'))
    uspto['ProductsSet'] = product_lists.apply(turn_reaction_side_to_set)
    return uspto

def collect_chunks():
    '''
        Concatenate every 'chunks*' CSV in data/uspto_full/raw/chunks into one dataframe.

        'ReactantsSet' / 'ProductsSet' are parsed back from their string reprs.
    '''
    chunks_path = os.path.join(PROJECT_ROOT, 'data', 'uspto_full', 'raw', 'chunks')
    frames = [pd.read_csv(os.path.join(chunks_path, file_name))
              for file_name in os.listdir(chunks_path)
              if file_name.startswith('chunks')]
    print(f'Found {len(frames)} chunks.')
    df = pd.concat(frames, ignore_index=True)
    # NOTE(review): eval() executes arbitrary code from the CSV; only run on trusted files
    # (ast.literal_eval would parse set literals safely).
    for column in ('ReactantsSet', 'ProductsSet'):
        df[column] = df[column].apply(eval)
    return df

def get_reactant_and_product_from_reaction_smiles(reaction_smiles, return_as_str=False):
    '''
        Split a reaction SMILES into its reactant and product sides.

        For 'reactants>reagents>products' the reagents are appended to the reactants.
        With return_as_str=True both sides are returned as raw strings. Otherwise each side
        is cut at the first space (dropping e.g. trailing CXSMILES extensions) and split
        into a list of molecule SMILES on '.'.

        Raises:
            ValueError if the string contains no '>'.
    '''
    if '>>' in reaction_smiles:
        halves = reaction_smiles.split('>>')
        left, right = halves[0], halves[1]
    elif '>' in reaction_smiles:
        pieces = reaction_smiles.split('>')
        # add reagents to reactants
        left = pieces[0] + '.' + pieces[1]
        right = pieces[-1]
    else:
        raise ValueError(f'Invalid reaction smiles: {reaction_smiles}')
    if return_as_str:
        return left, right

    def _molecules(side):
        return side.split(' ')[0].split('.')

    return _molecules(left), _molecules(right)

def convert_percentage(x):
    '''
        Convert a percentage such as '85%' or 85 into a fraction (0.85).

        Missing values and anything that cannot be parsed as a number return NaN.
    '''
    if pd.isna(x):
        return np.nan
    try:
        percent = float(str(x).rstrip('%'))
    except (ValueError, AttributeError):
        return np.nan
    # a yield of 0 gives 0.0, not NaN
    return percent / 100.0

def get_full_length(seq):
    '''Return the number of whitespace-separated tokens in smi_tokenizer(seq)[0].'''
    token_string = smi_tokenizer(seq)[0]
    return len(token_string.split())

def canonicalize_rxn(rxn, remove_atom_map=True):
    '''
        Normalise a reaction SMILES by stripping atom maps from every molecule.

        Accepts 'reactants>>products' and 'reactants>reagents>products' and returns the same
        layout. Each side is whitespace-stripped; when remove_atom_map is True every molecule
        on every non-empty side is passed through clear_atom_map.
        Note: the parameter `remove_atom_map` shadows the module-level function of the same
        name inside this function.

        Raises:
            ValueError if rxn does not split into exactly three parts on '>'
            (i.e. is neither of the two accepted layouts).
    '''
    # NOTE: need to remove atom_map always because otherwise it's not canonical
    # TODO: also need to check if rdkit canonicalization cannot happen with atom mapping?
    # 'A>>B'.split('>') == ['A', '', 'B'], so both layouts yield three parts. Splitting on
    # '>' alone used to return 'reactants>reagents' and drop the products.
    parts = rxn.split('>')
    if len(parts) != 3:
        raise ValueError(f"Invalid reaction format: {rxn}")

    def _clean_side(side):
        side = side.strip()
        if not side:
            # e.g. the empty reagent slot of 'A>>B'
            return side
        molecules = side.split('.')
        if remove_atom_map:
            molecules = [clear_atom_map(molecule) for molecule in molecules]
        return '.'.join(molecules)

    return '>'.join(_clean_side(part) for part in parts)
    
def remove_atom_map(smi):
    '''
        Parse smi and return an RDKit Mol with every atom-map number reset to 0.

        Note: returns a Mol object, not a SMILES string.

        Raises:
            ValueError if RDKit cannot parse smi. Previously this surfaced as an opaque
            AttributeError on None.
    '''
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f'Invalid SMILES: {smi}')
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    return mol

def log(msg, quiet=False):
    '''Print msg unless quiet is True.'''
    if quiet:
        return
    print(msg)

def _retry_after_seconds(headers, default=5):
    '''
        Seconds to wait before the next ProTox request.

        Uses the integer Retry-After header plus one second of slack. Falls back to `default`
        when the header is missing or not an integer (e.g. an HTTP date).
    '''
    try:
        return int(headers['Retry-After']) + 1
    except (KeyError, TypeError, ValueError):
        return default


def request_data(inputs, models, input_type='smiles'):
    '''
        Enqueue a ProTox-3 prediction request and return its task id.

        Args:
            inputs: query string (a SMILES when input_type='smiles').
            models: list of requested model strings, sent JSON-encoded as 'requested_data'.
            input_type: ProTox input type.
        Returns:
            The task id (response body) on HTTP 200, otherwise ''. After a 200 or 429
            response it sleeps for the server's Retry-After time (+1 s), or 5 s by default.
        Exits the process with status 0 when the daily quota is exceeded (HTTP 403).
    '''
    log("Enqueueing request "+inputs+", with models :")
    log(models)
    task_id = ''
    r = requests.post("https://tox.charite.de/protox3/src/api_enqueue.php",
                      data={'input_type': input_type,
                            'input': inputs,
                            'requested_data': json.dumps(models)}  # encode array, the rest are single strings
                      )
    if r.status_code == 200:  # Data response
        log("Recieved qualified response with id")
        log(r.text)
        task_id = r.text
        # Wait as long as the server asks before the next query
        wait_time = _retry_after_seconds(r.headers)
        log("Waiting for "+str(wait_time)+" s till next request")
        time.sleep(wait_time)
    elif r.status_code == 429:  # Too many requests. Slow down/wait
        log("Server responds : Too many requests. Slowing down query slightly")
        # The header value is a string; adding 1 to it (and then concatenating an int into
        # the log message) used to crash this branch with TypeError.
        wait_time = _retry_after_seconds(r.headers)
        log("Waiting for "+str(wait_time)+" s till next request")
        time.sleep(wait_time)
    elif r.status_code == 403:  # Daily quota exceeded. Aborting operation for now.
        print("Daily Quota Exceeded. Terminating process.")
        sys.exit(0)
    else:  # Server gone away or a different issue; report it
        print("ERROR : Server issue")
        print(r.status_code, r.reason)

    return task_id

def result_retrieval(task_id, models, smi, outfile):
    '''
        Fetch the ProTox results of an enqueued task and save them as one wide row.

        Args:
            task_id: id returned by request_data.
            models: list whose first element is the space-separated model string requested.
            smi: the queried SMILES, stored in the 'input' column.
            outfile: CSV path the reshaped row is written to.
        Returns:
            The one-row DataFrame from reshape_df, or None when nothing usable came back
            (empty body, 404 / not finished yet, or no table with a 'Target' column).
            reshape_df cannot handle an empty table.
        Exits the process (sys.exit()) on any other HTTP status.
    '''
    log("Asking for " + task_id)
    data = pd.DataFrame()
    r = requests.post("https://tox.charite.de/protox3/src/api_retrieve.php",
                      data={'id': task_id})
    if r.status_code == 200:  # Data response
        if r.text == "":
            print("Warning : Empty response")
        else:
            tox_data = pd.DataFrame()
            model_data = pd.DataFrame()
            target_data = pd.DataFrame()
            if "acute_tox" in models[0]:
                tox_data = pd.read_csv("https://tox.charite.de/protox3/csv/"+task_id+"_tox_class.csv",
                                       sep='\t')
                tox_data.insert(loc=0, column='type', value="acute toxicity")
            if len(set(ALL_MODELS.split()) & set(models[0].split())) > 0:
                model_data = pd.read_csv("https://tox.charite.de/protox3/csv/"+task_id+"_result.csv",
                                         sep='\t')
                model_data.drop(model_data.columns[0], axis=1, inplace=True)
                model_data.insert(loc=0, column='type', value="toxicity model")
            if "tox_targets" in models[0]:
                print("tox_targets found!")
                target_data = pd.read_csv("https://tox.charite.de/protox3/csv/"+task_id+"_tox_targets.csv",
                                          sep='\t')
                target_data.insert(loc=0, column='type', value="toxicity target")
            data = pd.concat([tox_data, model_data, target_data], ignore_index=True)
            data.insert(loc=0, column='input', value=smi)
    elif r.status_code == 404:  # Not found, not computed or finished yet. Do nothing
        print("No response yet. Likely cause: computation unfinished (retrying...)")
    else:  # Other codes are not permitted
        print("Unexpected return from server")
        print(r.status_code, r.reason)
        sys.exit()

    # Previously an empty table reached reshape_df and crashed on df.Target / iloc[0].
    if data.empty or 'Target' not in data.columns:
        log("No ProTox results to save for task " + task_id)
        return None
    reshaped_df = reshape_df(data)
    reshaped_df.to_csv(outfile, encoding='utf-8')
    print("Completed all operations. Your results are in " + outfile)
    return reshaped_df

def reshape_df(df):
    '''
        Pivot a long ProTox result table (one row per Target) into a single wide row.

        Output columns: 'smiles' (from the first 'input' value), then '<target>_pred' and
        '<target>_prob' for each target in order of first appearance. The row has index 0.

        Raises:
            ValueError (from Series.item) if a target appears in more than one row;
            IndexError if df is empty.
    '''
    targets = df['Target'].unique().tolist()
    target_cols = list(chain.from_iterable((f'{target}_pred', f'{target}_prob') for target in targets))
    out_df = pd.DataFrame(columns=['smiles'] + target_cols)
    row = {'smiles': df['input'].iloc[0]}
    # Iterate targets directly. Recovering them via col.split('_pred') broke for target
    # names that themselves contain '_pred'.
    for target in targets:
        target_rows = df[df['Target'] == target]
        row[f'{target}_pred'] = target_rows['Prediction'].item()
        row[f'{target}_prob'] = target_rows['Probability'].item()
    out_df.loc[0] = row
    return out_df

def run_protox_for_chosen_smiles():
    '''
    Run ProTox on every molecule of a hard-coded subset that is not yet in the local
    ProTox database (data/protox/database.csv), then rewrite the database with the new
    results appended. Returns the updated database dataframe.
    '''
    protox_path = os.path.join(PROJECT_ROOT, 'data', 'protox')
    database_path = os.path.join(protox_path, 'database.csv')
    if not os.path.exists(database_path):
        print('No database found, creating new one at ', database_path)
        # create database
        database = pd.DataFrame()
    else:
        print('Database found in ', database_path, ', loading it')
        database = pd.read_csv(database_path, sep='\t')
        # Earlier versions saved the index, adding an 'Unnamed: N' column on every cycle.
        database = database.loc[:, ~database.columns.astype(str).str.startswith('Unnamed:')]
        print(f'Database contains {len(database)} molecules')

    # read the data we want to process
    # TODO: get this from a config file or smthg
    subset = 'uspto_190_overlap_0'
    with open(os.path.join(protox_path, subset, 'unique_molecules.txt'), 'r') as f:
        smiles = f.read().splitlines()

    # Set for O(1) membership tests instead of scanning the column per molecule.
    known_smiles = set(database['smiles']) if len(database) > 0 else set()
    smiles_to_process = []
    for smi in smiles:
        if smi in known_smiles:
            print(f'{smi} already in database')
        else:
            smiles_to_process.append(smi)

    print(f'Processing {len(smiles_to_process)} molecules')
    # run_protox_for_smi writes mol<idx>.csv here; make sure the directory exists.
    out_path = os.path.join(PROJECT_ROOT,
                            'data',
                            'protox',
                            'individual_results')
    os.makedirs(out_path, exist_ok=True)

    all_dfs = [database]
    # Output file numbering continues after the molecules already in the database.
    for idx, smi in enumerate(smiles_to_process, start=len(database)):
        # run_protox_for_smi takes (smi, idx); passing models/out_path raised TypeError.
        task_df = run_protox_for_smi(smi, idx)
        if task_df is not None:
            all_dfs.append(task_df)

    all_df = pd.concat(all_dfs, ignore_index=True)
    # update database (without the index, so reloading does not add an 'Unnamed' column)
    all_df.to_csv(database_path, sep='\t', encoding='utf-8', index=False)

    return all_df

def run_protox_for_smi(smi, idx, *, models=None, out_path=None):
    '''
        Run ProTox on a single SMILES string and retrieve its results.

        Args:
            smi (str): SMILES string to submit.
            idx (int): Index used to name the per-molecule result file (``mol{idx}.csv``).
            models (list, optional): Models to request. Defaults to ``[ALL_MODELS]``.
            out_path (str, optional): Directory for per-molecule result files. Defaults to
                ``<PROJECT_ROOT>/data/protox/individual_results``; created if missing.

        Returns:
            The value returned by ``result_retrieval`` (used by callers as a DataFrame),
            or None if ``request_data`` did not return a task id.
    '''
    if models is None:
        models = [ALL_MODELS]
    if out_path is None:
        out_path = os.path.join(PROJECT_ROOT,
                                'data',
                                'protox',
                                'individual_results')

    task_id = request_data(smi, models, input_type='smiles')
    if not task_id:
        return None

    # result_retrieval writes its output inside out_path; make sure the directory exists,
    # since not every caller creates it beforehand.
    os.makedirs(out_path, exist_ok=True)
    log("All queries have been enqueued. Starting result retrieval...")
    # Presumably gives the server time to process the task before results are polled.
    time.sleep(10)
    outfile = os.path.join(out_path, f'mol{idx}.csv')
    return result_retrieval(task_id, models, smi, outfile)
    
# def get_property_score(smi, path, config):
#     all_chars_path = os.path.join(PROJECT_ROOT, 
#                                    'data', 
#                                    config.classifier_guidance.dataset.vocab_file)
#     alphabet_size = get_vocab_size(config)
#     property_model = PropertyPredictor(config, alphabet_size)
#     checkpoint = torch.load(path, map_location=device)
#     property_model.load_state_dict(checkpoint['model_state_dict'])
#     property_model.to(device)
#     property_model.eval()
#     reactants_id = tokenize(smi, all_chars_path)
#     reactants_id = reactants_id.unsqueeze(0).to(device)
#     with torch.no_grad():
#         toxicity_score = property_model(reactants_id)
#     toxicity_score = toxicity_score * checkpoint['target_std'] + checkpoint['target_mean']
#     return toxicity_score.item()

def plot_toxicity_profile(smi, df=None):
    '''
        Plot the ProTox toxicity profile of a molecule, querying ProTox if it is not cached.

        ``smi`` is stripped of atom maps and canonicalized (clear_atom_map), then looked up in
        the ``smiles`` column of ``df``. When ``df`` is None it is read from
        ``<PROJECT_ROOT>/data/protox/database.csv`` (tab-separated); an empty table is used if
        that file does not exist yet. If the molecule is missing, ProTox is queried via
        run_protox_for_smi and, when that returns data, ``df`` plus the new rows is written
        back to the database file. NOTE(review): this assumes a caller-supplied ``df`` is the
        full database; a partial ``df`` would replace the file's contents.

        The figure is a 2x3 grid: one panel per entry of ``toxicity_groups`` (each
        ``<endpoint>_pred`` value coloured by its ``<endpoint>_prob``) and an acute-toxicity
        panel (``LD50_pred`` vs ``LD50_prob`` coloured by ``tox_class_pred``) in the
        bottom-right cell. Presumably ``toxicity_groups`` has at most 5 entries, otherwise
        the last group shares the bottom-right cell with acute toxicity.

        Args:
            smi (str): SMILES string of the molecule (atom maps allowed).
            df (pd.DataFrame, optional): Pre-loaded toxicity database.
    '''
    smi = clear_atom_map(smi)
    protox_path = os.path.join(PROJECT_ROOT, 'data', 'protox', 'database.csv')
    if df is None:
        if os.path.exists(protox_path):
            df = pd.read_csv(protox_path, sep='\t')
        else:
            # No database yet: start from an empty table so the lookup below still works.
            df = pd.DataFrame(columns=['smiles'])

    if smi not in df['smiles'].values:
        task_df = run_protox_for_smi(smi, len(df))
        if task_df is not None:
            df = pd.concat([df, task_df], ignore_index=True)
            # Persist only when new data was fetched instead of rewriting on every call.
            os.makedirs(os.path.dirname(protox_path), exist_ok=True)
            df.to_csv(protox_path, sep='\t', encoding='utf-8', index=False)

    df_smi = df[df['smiles'] == smi]
    if len(df_smi) == 0:
        print("No toxicity data found for this molecule")
        return

    _, axes = plt.subplots(2, 3, figsize=(15, 5))

    for panel_idx, (tox_type, keywords) in enumerate(toxicity_groups.items()):
        ax = axes[panel_idx // 3, panel_idx % 3]
        # Endpoints of this group: '<endpoint>_pred' columns starting with one of the keywords.
        endpoints = [col for col in df.columns if any(col.startswith(keyword) for keyword in keywords)]
        endpoints = [col.split('_pred')[0] for col in endpoints if col.endswith('pred')]
        data = df_smi[[f"{ep}_pred" for ep in endpoints]].values[0]
        probs = df_smi[[f"{ep}_prob" for ep in endpoints]].values[0]

        im = ax.scatter(range(len(endpoints)), data, c=probs, cmap='RdYlBu_r', vmin=0, vmax=1)
        ax.set_xticks(range(len(endpoints)))
        ax.set_xticklabels(endpoints, rotation=45, ha='right')
        ax.set_yticks([])
        ax.set_ylim(-0.1, 1.1)
        plt.colorbar(im, ax=ax, label='Probability')
        ax.set_title(tox_type)

    # Acute toxicity panel in the bottom-right cell.
    ax = axes[-1, -1]
    scatter = ax.scatter(df_smi['LD50_pred'], df_smi['LD50_prob'], c=df_smi['tox_class_pred'],
                         cmap='RdYlGn',  # Red (toxic) to Green (non-toxic)
                         norm=plt.Normalize(vmin=1, vmax=6))
    ax.set_xlabel('Acute Toxicity Prediction')
    ax.set_ylabel('Probability')
    ax.set_title('Acute Toxicity')
    ax.set_ylim(-1, 101)
    cb = plt.colorbar(scatter, ax=ax)
    cb.set_label('Toxicity Class (1=highest, 6=lowest)')
    plt.tight_layout()
    plt.show()

def obtain_partial_sequences(sequences, values, full_lengths):
    '''
        Expand each SMILES into its proper token prefixes.

        For a sequence of n tokens (smi_tokenizer), the prefixes of 1..n-1 tokens are
        produced; the complete sequence itself is not included. Each prefix is paired with
        its source sequence's value and full length.

        Returns:
            tuple[list, list, list]: prefixes, their values, their full lengths.
    '''
    prefixes, prefix_values, prefix_full_lengths = [], [], []
    for seq, value, full_length in zip(sequences, values, full_lengths):
        token_list = smi_tokenizer(seq.strip())[0].split()
        n_prefixes = max(len(token_list) - 1, 0)
        prefixes.extend(''.join(token_list[:k]) for k in range(1, len(token_list)))
        prefix_values.extend([value] * n_prefixes)
        prefix_full_lengths.extend([full_length] * n_prefixes)
    assert len(prefixes) == len(prefix_values)
    return prefixes, prefix_values, prefix_full_lengths


def clear_atom_map(smi):
    """
    Strip atom-map numbers from a SMILES string and canonicalize the result.

    Args:
        smi (str): SMILES string, possibly atom-mapped.

    Returns:
        str: Canonical SMILES with all atom-map numbers set to 0, or ``smi`` returned
        unchanged when RDKit cannot parse it.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is not None:
        for atom in mol.GetAtoms():
            atom.SetAtomMapNum(0)
        return Chem.CanonSmiles(Chem.MolToSmiles(mol))
    # Unparsable input is passed through untouched.
    return smi

def get_product_and_starting_material_smiles(smiles_list):
    '''
        Parse lines containing a "('<product>', '<starting material>')" pair.

        Lines without such a pair are skipped.

        Returns:
            tuple[list[str], list[str]]: product SMILES and starting-material SMILES,
            aligned by position.
    '''
    pair_regex = re.compile(r"\('([^']+)',\s*'([^']+)'\)")
    found = (pair_regex.search(line.strip()) for line in smiles_list)
    pairs = [match.groups() for match in found if match]
    products = [product for product, _ in pairs]
    starting_materials = [starting for _, starting in pairs]
    return products, starting_materials

from collections import deque, defaultdict

def _index_route_reactions(route, starting_material):
    '''
        Index a route's reactions by their (sorted, canonical) product string.

        Returns:
            tuple[dict, dict]: product -> (canonical reactant SMILES list, reaction type), and
            product -> get_tanimoto(starting_material, sorted canonical reactants).
    '''
    product_to_reaction = {}
    product_to_tanimoto = {}
    for reaction_str in route['route']:
        # Route steps are written retro-style: "products>>reactants".
        parts = reaction_str.split('>>')
        products_str, reactants_str = parts[0], parts[1]

        sorted_cano_products = get_sorted_cano_smiles(products_str.split('.'))
        sorted_cano_reactants = get_sorted_cano_smiles(reactants_str.split('.'))
        # reaction_data is keyed in the forward direction: "reactants>>products".
        sorted_cano_reaction = sorted_cano_reactants + '>>' + sorted_cano_products

        reaction_type = route['reaction_data'][sorted_cano_reaction]['reaction_type']
        reactants_list = [Chem.MolToSmiles(Chem.MolFromSmiles(r)) for r in reactants_str.split('.')]

        product_to_reaction[sorted_cano_products] = (reactants_list, reaction_type)
        product_to_tanimoto[sorted_cano_products] = get_tanimoto(starting_material, sorted_cano_reactants)
    return product_to_reaction, product_to_tanimoto


def _collect_ground_truth_bfs(target, product_to_reaction, product_to_tanimoto):
    '''
        Walk a route breadth-first from ``target`` and record per-molecule and per-depth ground truth.

        Molecules that no reaction in the route produces are leaves and contribute nothing.
    '''
    ground_truth_data = {
        'mol_to_rxn_type': {},
        'mol_to_ground_truth_reactants': {},
        'mol_to_tanimoto': {},
        'depth_to_types': defaultdict(list),
        'depth_to_ground_truth_reactants': defaultdict(list),
        'depth_to_tanimotos': defaultdict(list),
    }

    queue = deque([(target, 0)])
    visited = {target}
    while queue:
        mol, depth = queue.popleft()
        if mol not in product_to_reaction:
            continue
        # The search tree interleaves OR and AND nodes and is 0-indexed,
        # hence the factor of 2 between route depth and tree depth.
        depth_in_tree = depth * 2

        reactants, rxn_type = product_to_reaction[mol]
        tanimoto = product_to_tanimoto[mol]
        reactants_smi = '.'.join(reactants)

        ground_truth_data['mol_to_rxn_type'][mol] = rxn_type
        ground_truth_data['mol_to_ground_truth_reactants'][mol] = reactants_smi
        ground_truth_data['mol_to_tanimoto'][mol] = tanimoto
        ground_truth_data['depth_to_types'][depth_in_tree].append(rxn_type)
        ground_truth_data['depth_to_tanimotos'][depth_in_tree].append(tanimoto)
        # NOTE: used mostly for debugging: lets one check whether the ground-truth
        # reactants appear among the precursors found at each search step.
        ground_truth_data['depth_to_ground_truth_reactants'][depth_in_tree].append(reactants_smi)

        for reactant in reactants:
            if reactant not in visited:
                visited.add(reactant)
                queue.append((reactant, depth + 1))
    return ground_truth_data


def get_targets_and_reaction_type_from_routes(config, starting_material_key='route_most_similar_starting_material'):
    '''
        Load reference routes and extract targets, ground-truth data and starting materials.

        Reads the JSON list of routes at ``<PROJECT_ROOT>/data/<config.route_dataset.path>`` and
        keeps routes ``[start_idx:end_idx]`` of ``config.route_dataset``. For each route:
          * the target is the canonical product of the first step in ``route['route']``;
          * the starting material is the canonical ``route[starting_material_key]``;
          * the ground truth maps molecules and tree depths to reaction types, ground-truth
            reactants and Tanimoto scores (see _collect_ground_truth_bfs).

        Returns:
            tuple[list, list, list]: targets, ground-truth dicts, starting materials (aligned).
    '''
    path = os.path.join(PROJECT_ROOT, 'data', config.route_dataset.path)
    with open(path, 'r', encoding='utf-8') as f:
        routes = json.load(f)
    # Routes are processed independently, so slice before doing any RDKit work.
    routes = routes[config.route_dataset.start_idx:config.route_dataset.end_idx]

    targets = []
    all_ground_truth_data = []
    all_starting_materials = []
    for route in routes:
        target = Chem.MolToSmiles(Chem.MolFromSmiles(route['route'][0].split('>>')[0]))
        starting_material = Chem.MolToSmiles(Chem.MolFromSmiles(route[starting_material_key]))
        targets.append(target)
        all_starting_materials.append(starting_material)

        product_to_reaction, product_to_tanimoto = _index_route_reactions(route, starting_material)
        all_ground_truth_data.append(
            _collect_ground_truth_bfs(target, product_to_reaction, product_to_tanimoto))

    return targets, all_ground_truth_data, all_starting_materials

def get_smiles_list(config, with_starting_material=False):
    '''
        Get the list of target SMILES selected by ``config.route_dataset``.

        Supported ``config.route_dataset.type`` values:
          * 'correct_smi' / 'debug_smi': a single hard-coded molecule ('debug_smi' requires
            ``config.search.dummy_inventory``).
          * 'uspto_hard_reaction_type_guided': JSON list of [product_smiles, class] items.
          * 'uspto_hard' / 'pistachio_hard' / 'pistachio_reachable': text file with one
            "('product', 'starting_material')" pair per line.
        All returned lists are sliced with ``[start_idx:end_idx]``.

        Returns:
            (product_smiles, classes), or (product_smiles, classes, starting_material_smiles)
            when ``with_starting_material`` is True. ``classes`` is empty unless the dataset
            provides them.

        Raises:
            ValueError: for an unknown dataset type.
    '''
    dataset_type = config.route_dataset.type
    classes = []
    if dataset_type == 'correct_smi':
        product_smiles = ['OCCN1CCC(c2ccc3c(c2)-n2nc(-c4ncnn4CC(F)(F)F)cc2CCO3)CC1']
        starting_material_smiles = [None]
    elif dataset_type == 'debug_smi':
        assert config.search.dummy_inventory, 'dummy_inventory must be True for debug_smi'
        product_smiles = ['Cc1ccc(-c2ccc(C)cc2)cc1']
        starting_material_smiles = [None]
    elif dataset_type == 'uspto_hard_reaction_type_guided':
        path = os.path.join(PROJECT_ROOT, 'data', config.route_dataset.path)
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Each item is [product_smiles, reaction_class]. product_smiles was previously never
        # assigned in this branch, which made the slicing below raise NameError.
        product_smiles = [item[0] for item in data]
        classes = [item[1] for item in data]
        starting_material_smiles = [None] * len(product_smiles)
    elif dataset_type in ('uspto_hard', 'pistachio_hard', 'pistachio_reachable'):
        path = os.path.join(PROJECT_ROOT, 'data', config.route_dataset.path)
        with open(path, 'r', encoding='utf-8') as f:
            smiles_list = f.readlines()
        if dataset_type == 'uspto_hard':
            print(f'Loaded {len(smiles_list)} from {path}')
        product_smiles, starting_material_smiles = get_product_and_starting_material_smiles(smiles_list)
    else:
        raise ValueError(f'Invalid data type: {config.route_dataset.type}')

    # Slicing an empty list yields an empty list, so no length checks are needed.
    window = slice(config.route_dataset.start_idx, config.route_dataset.end_idx)
    classes = classes[window]
    starting_material_smiles = starting_material_smiles[window]
    product_smiles = product_smiles[window]
    if with_starting_material:
        return product_smiles, classes, starting_material_smiles
    return product_smiles, classes

def get_stratified_samples(df, sampling_rates, min_samples_per_completion_level):
    '''
        Get stratified samples from the dataset.
        Used as one sampling method for training the property predictors.

        Rows are bucketed by ``completion_ratio`` into 5 strata: <=0.2, <=0.4, <=0.6, <=0.8
        and >0.8. For each ``(completion_ratio, rate)`` in ``sampling_rates``, the stratum
        containing that ratio is sampled with ``max(min_samples_per_completion_level,
        int(len(stratum) * rate))`` rows, capped at the stratum size (random_state=42).

        NOTE: adds a ``stratum`` column to ``df`` in place (kept on the returned rows too).

        Returns:
            pd.DataFrame: the concatenated samples (original index kept); empty if nothing
            was sampled.
    '''
    # NOTE: converting to strata avoids checking which interval a completion ratio belongs to.
    def get_stratum(completion_ratio):
        if completion_ratio <= 0.2:
            return 0
        elif completion_ratio <= 0.4:
            return 1
        elif completion_ratio <= 0.6:
            return 2
        elif completion_ratio <= 0.8:
            return 3
        else:
            return 4

    if not sampling_rates:
        return pd.DataFrame()

    # The stratum assignment does not depend on the loop below; compute it once.
    df['stratum'] = df['completion_ratio'].apply(get_stratum)

    sampled_parts = []
    n_sampled = 0
    for completion_ratio, rate in sampling_rates.items():
        stratum = get_stratum(completion_ratio)
        print(f'Sampling {completion_ratio} (stratum {stratum}) with rate {rate}, min samples per stratum {min_samples_per_completion_level}')
        stratum_data = df[df['stratum'] == stratum]
        # Ensure a minimum number of samples per stratum, capped at what is available.
        samples_to_take = max(min_samples_per_completion_level, int(len(stratum_data) * rate))
        samples_to_take = min(samples_to_take, len(stratum_data))
        print(f'=== samples to take: {samples_to_take}')
        if samples_to_take > 0:
            sampled_parts.append(stratum_data.sample(n=samples_to_take, random_state=42))
            n_sampled += samples_to_take
        n_cols = len(df.columns) if sampled_parts else 0
        print(f'sampled_data.shape: {(n_sampled, n_cols)}')

    # Concatenate once instead of growing a DataFrame inside the loop.
    if not sampled_parts:
        return pd.DataFrame()
    return pd.concat(sampled_parts)

def get_weighted_samples(df, sampling_fraction=0.01):
    '''
        Draw rows without replacement, weighted towards high completion ratios.

        Weights are ``completion_ratio ** 3`` normalized to sum to 1 and stored in a
        ``weight`` column added to ``df`` in place. ``int(len(df) * sampling_fraction)`` rows
        (capped at ``len(df)``) are drawn with ``np.random.choice``.
        Used as one sampling method for training the property predictors.

        Returns:
            pd.DataFrame: the sampled rows of ``df``.
    '''
    # The cubic weighting strongly favours (nearly) complete sequences.
    cubed = df['completion_ratio'] ** 3
    df['weight'] = cubed / cubed.sum()

    n_draw = min(int(len(df) * sampling_fraction), len(df))
    chosen_index = np.random.choice(df.index, size=n_draw, replace=False, p=df['weight'])
    return df.loc[chosen_index]
    
def get_samples(df, sampling_method='stratified',
                sampling_fraction=0.01,
                min_samples_per_completion_level=5,
                sampling_rates=None):
    '''
        Get samples from the dataset to train the property predictors.

        Args:
            df (pd.DataFrame): Data with a ``completion_ratio`` column.
            sampling_method (str): 'stratified' (get_stratified_samples) or 'weighted'
                (get_weighted_samples; NOTE: not used in the paper).
            sampling_fraction (float): Fraction of rows drawn by the 'weighted' method.
            min_samples_per_completion_level (int): Minimum rows per stratum ('stratified').
            sampling_rates (dict, optional): ``{completion_ratio: rate}`` for 'stratified'. Each
                key selects the stratum containing it (strata: <=0.2, <=0.4, <=0.6, <=0.8, >0.8),
                which is then sampled at ``rate``. Defaults to
                {0.2: 0.004, 0.4: 0.007, 0.6: 0.04, 0.8: 0.6, 1.0: 0.8}.

        Raises:
            ValueError: for an unknown ``sampling_method``.
    '''
    if sampling_rates is None:
        # Built per call instead of a shared mutable default argument.
        sampling_rates = {
            0.2: 0.004,  # Very incomplete
            0.4: 0.007,  # Somewhat incomplete
            0.6: 0.04,   # Half complete
            0.8: 0.6,    # Mostly complete
            1.0: 0.8,    # Nearly/fully complete
        }

    if sampling_method == 'stratified':
        return get_stratified_samples(df, sampling_rates, min_samples_per_completion_level)
    if sampling_method == 'weighted':  # NOTE: not used in the paper
        return get_weighted_samples(df, sampling_fraction)
    raise ValueError(f'Invalid sampling method: {sampling_method}')

def get_partial_sequences_with_completion_ratio_general(sequences, values, lengths, 
                                                completion_lower_limit=0, 
                                                min_length_limit=1,
                                                max_augmentations=5, 
                                                only_augment_complete=True,
                                                print_every = 1000):
    '''
        Expand space-tokenized sequences into their prefixes, tagged with completion ratios.

        Each sequence is stripped and split on single spaces. Every prefix whose token count
        lies between max(int(completion_lower_limit * n), min_length_limit) and n inclusive
        is emitted (only the full sequence when completion_lower_limit >= 1).

        NOTE: ``lengths``, ``max_augmentations`` and ``only_augment_complete`` are accepted for
        signature compatibility but unused; the full length is recomputed from the tokens.
        NOTE: used for other data types (e.g. math expressions); ultimately should be merged
        with get_partial_sequences_with_completion_ratio.

        Returns:
            pd.DataFrame with columns ``seq``, ``property``, ``full_length``, ``completion_ratio``.
    '''
    columns = {'seq': [], 'property': [], 'full_length': [], 'completion_ratio': []}

    for seq_idx, (raw_seq, value, _unused_length) in enumerate(zip(sequences, values, lengths)):
        if seq_idx % print_every == 0:
            print(f'Processing sequence {seq_idx} of {len(sequences)}')
        tokens = raw_seq.strip().split(' ')
        n_tokens = len(tokens)

        if completion_lower_limit < 1:
            first_len = int(completion_lower_limit * n_tokens)
        else:
            first_len = n_tokens
        first_len = max(first_len, min_length_limit)

        for prefix_len in range(first_len, n_tokens + 1):
            columns['seq'].append(' '.join(tokens[:prefix_len]))
            columns['property'].append(value)
            columns['full_length'].append(n_tokens)
            columns['completion_ratio'].append(prefix_len / n_tokens)

    return pd.DataFrame(columns)

def _random_rooted_smiles(mol, fallback):
    '''
        Return a SMILES string for ``mol`` rooted at a uniformly random atom.

        Returns ``fallback`` unchanged when ``mol`` is None (unparsable) or has no atoms.
    '''
    if mol is None or mol.GetNumAtoms() == 0:
        return fallback
    root_atom = random.randint(0, mol.GetNumAtoms() - 1)
    return Chem.MolToSmiles(mol, rootedAtAtom=root_atom)


def get_partial_sequences_with_completion_ratio(sequences, values, lengths, 
                                                completion_lower_limit=0, 
                                                max_augmentations=5, 
                                                only_augment_complete=True,
                                                print_every = 1000,
                                                starting_material_smiles=None,
                                                starting_material_separator='<unk>',
                                                product_smiles=None):
    '''
        Expand SMILES into token prefixes with completion ratios, plus random-SMILES augmentations.

        For each sequence, every prefix of its smi_tokenizer tokens with length between
        max(int(completion_lower_limit * n), 1) and n inclusive is emitted (only the full
        sequence when completion_lower_limit >= 1). If given, ``starting_material_separator +
        starting_smi`` and then ``'>>' + product_smi`` are appended to every emitted string.

        Augmentation: when ``only_augment_complete`` is True and ``max_augmentations > 1``, the
        complete sequence gets ``max_augmentations - 1`` extra rows in which each '.'-separated
        fragment is re-written from a random root atom (root atoms are not reused per fragment;
        once a fragment runs out of atoms it is kept as-is). Any starting-material / product
        SMILES are likewise re-rooted randomly. NOTE(review): only_augment_complete=False
        disables augmentation entirely rather than augmenting partial sequences — confirm intent.

        Args:
            sequences (list[str]): Reactant SMILES strings.
            values (list): Property value for each sequence.
            lengths (list): Unused (kept for compatibility); lengths are recomputed from tokens.
            completion_lower_limit (float): Fraction of the full length for the shortest prefix.
            max_augmentations (int): Total rows (original + augmentations) for a complete sequence.
            only_augment_complete (bool): Augmentation is performed only when True.
            print_every (int): Progress is printed every ``print_every`` sequences.
            starting_material_smiles (list, optional): Per-sequence starting material or None.
            starting_material_separator (str): Separator placed before the starting material.
            product_smiles (list, optional): Per-sequence product SMILES or None.

        Returns:
            pd.DataFrame with columns ``rxn``, ``property``, ``full_length``, ``completion_ratio``.
    '''
    all_partial_sequences = []
    all_values = []
    full_lengths = []
    completion_ratios = []

    # Missing optional inputs mean "nothing to append" for every sequence
    # (a None product_smiles used to make zip() raise TypeError).
    if starting_material_smiles is None:
        starting_material_smiles = [None] * len(sequences)
    if product_smiles is None:
        product_smiles = [None] * len(sequences)

    for i, (seq, value, _unused_length, starting_smi, product_smi) in enumerate(
            zip(sequences, values, lengths, starting_material_smiles, product_smiles)):
        if i % print_every == 0:
            print(f'Processing sequence {i} of {len(sequences)}')
        seq = seq.strip()

        tokens_str, _ = smi_tokenizer(seq)  # NOTE: returns tokens from left to right
        tokens = tokens_str.split()
        full_length = len(tokens)

        suffix = ''
        if starting_smi is not None:
            suffix += starting_material_separator + starting_smi
        if product_smi is not None:
            suffix += '>>' + product_smi

        min_length = int(completion_lower_limit * full_length) if completion_lower_limit < 1 else full_length
        min_length = max(min_length, 1)

        for j in range(min_length, full_length + 1):
            all_partial_sequences.append(''.join(tokens[:j]) + suffix)
            all_values.append(value)
            full_lengths.append(full_length)
            completion_ratios.append(j / full_length)

        # Only the complete sequence (completion ratio 1.0) is augmented.
        if full_length == 0 or not only_augment_complete or max_augmentations <= 1:
            continue

        fragments = seq.split('.')
        fragment_mols = [Chem.MolFromSmiles(frag) for frag in fragments]
        # Roots used so far per fragment, kept across all augmentations of this sequence.
        roots_tried = [set() for _ in fragments]
        starting_mol = Chem.MolFromSmiles(starting_smi) if starting_smi is not None else None
        product_mol = Chem.MolFromSmiles(product_smi) if product_smi is not None else None

        for aug_count in range(max_augmentations - 1):  # -1: the original is already added
            aug_parts = []
            for frag, frag_mol, tried in zip(fragments, fragment_mols, roots_tried):
                num_atoms = frag_mol.GetNumAtoms() if frag_mol is not None else 0
                if num_atoms == 0:
                    # Unparsable or empty fragment: nothing to re-root.
                    aug_parts.append(frag)
                    continue
                if len(tried) >= num_atoms:
                    print(f'All atoms already tried for {frag}. Stopped at augmentation {aug_count}')
                    aug_parts.append(frag)
                    continue
                root_atom = random.randint(0, num_atoms - 1)
                while root_atom in tried:
                    root_atom = random.randint(0, num_atoms - 1)
                tried.add(root_atom)
                try:
                    aug_parts.append(Chem.MolToSmiles(frag_mol, rootedAtAtom=root_atom))
                except Exception:
                    print(f'Error rooted at atom {root_atom} for {seq}')
                    aug_parts.append(frag)
            aug_seq = '.'.join(aug_parts)

            aug_tokens_str, _ = smi_tokenizer(aug_seq)
            aug_length = len(aug_tokens_str.split())

            if starting_smi is not None:
                aug_seq += starting_material_separator + _random_rooted_smiles(starting_mol, starting_smi)
            if product_smi is not None:
                aug_seq += '>>' + _random_rooted_smiles(product_mol, product_smi)
            all_partial_sequences.append(aug_seq)
            all_values.append(value)
            full_lengths.append(aug_length)
            completion_ratios.append(1.0)

    df = pd.DataFrame({
        'rxn': all_partial_sequences,
        'property': all_values,
        'full_length': full_lengths,
        'completion_ratio': completion_ratios
    })

    return df

def get_isomeric_smiles_from_pubchem(drug_name, timeout=30):
    '''
        Get the isomeric SMILES of a drug from PubChem, canonicalized with RDKit.
        Used to add isomeric info to the SMILES in the Tox dataset.

        Args:
            drug_name (str): Compound name to look up.
            timeout (float): Seconds to wait for PubChem before giving up.

        Returns:
            str: canonical isomeric SMILES; "Not found in PubChem" for a non-200 response;
            "Error in API request" if the request or parsing fails.
    '''
    from urllib.parse import quote

    try:
        # URL-encode the name so spaces, slashes, etc. do not break the REST path.
        search_url = (
            "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/"
            f"{quote(str(drug_name), safe='')}/property/IsomericSMILES/JSON"
        )
        response = requests.get(search_url, timeout=timeout)
        if response.status_code != 200:
            return "Not found in PubChem"
        properties = response.json()['PropertyTable']['Properties'][0]
        # PubChem may report this property under 'SMILES' instead of 'IsomericSMILES'; accept either.
        isomeric_smiles = properties.get('IsomericSMILES', properties.get('SMILES'))
        return Chem.MolToSmiles(Chem.MolFromSmiles(isomeric_smiles))
    except Exception:
        # Network errors, timeouts, unexpected JSON and unparsable SMILES all end up here.
        return "Error in API request"
    finally:
        # Be nice to the PubChem API
        time.sleep(0.2)

def smi_to_fp(mol_smi, radius=2, fp_size=2048, dtype="int32", as_numpy=True):
    """
    Compute a chirality-aware Morgan fingerprint for a SMILES string.

    Args:
        mol_smi (str): SMILES string to fingerprint.
        radius (int): Morgan radius.
        fp_size (int): Number of bits.
        dtype (str): Dtype of the numpy buffer (used only when ``as_numpy``).
        as_numpy (bool): If True, copy the bits into a numpy array via
            DataStructs.ConvertToNumpyArray; otherwise return the RDKit bit vector.

    Returns:
        np.ndarray or the RDKit bit vector from GetMorganFingerprintAsBitVect.
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(mol_smi), radius=radius, nBits=fp_size, useChirality=True
    )
    if not as_numpy:
        return bit_vector
    buffer = np.empty((1, fp_size), dtype=dtype)
    DataStructs.ConvertToNumpyArray(bit_vector, buffer)
    return buffer

# Token pattern for smi_tokenizer_from_rsmiles. Regex alternation takes the FIRST alternative
# that matches, so the SMILES organic subset (Br, Cl, B, C, N, O, P, S, F, I and aromatic
# b, c, n, o, s, p) must come before other element symbols. Otherwise "Cn1ccnc1" would yield
# "Cn" (copernicium) and "CSc1ccccc1" would yield "Sc" (scandium). Element symbols that start
# with an organic-subset letter (Na, Se, Cu, ...) are only valid inside brackets, which the
# first alternative covers as a single token.
_RSMILES_TOKEN_PATTERN = (
    r"(\[[^\]]+]"
    r"|Br|Cl|B|C|N|O|S|P|F|I|b|c|n|o|s|p"
    r"|He|Li|Mg|Al|Ar|Ti|Mn|Zn|Ga|Ge|As|Kr|Rb|Zr|Mo|Tc|Ru|Rh|Ag|Te|Xe|La|Eu|Gd|Tb|Dy|Ho|Er|Tm|Yb|Lu"
    r"|Hf|Ta|Re|Au|Hg|Tl|At|Rn|Ra|Ac|Th|Am|Es|Md|Lr|Rf|Db|Hs|Mt|Ds|Rg|Mc|Lv|Ts|Og"
    r"|K|V|Y|W|U|H|D|T"
    r"|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|<|\*|\$|\%[0-9]{2}|[0-9])"
)
_RSMILES_TOKEN_REGEX = re.compile(_RSMILES_TOKEN_PATTERN)


def smi_tokenizer_from_rsmiles(smi):
    '''
        Split a (reaction) SMILES string into space-separated tokens.

        Bracket atoms, two-digit ring closures (``%NN``), Br and Cl are single tokens; every
        other atom, bond, branch, ring digit or '>'/'<' character is its own token.

        Raises:
            AssertionError: if some characters of ``smi`` are not covered by the pattern.
    '''
    tokens = _RSMILES_TOKEN_REGEX.findall(smi)
    assert smi == ''.join(tokens), f'Could not fully tokenize {smi!r}'
    return ' '.join(tokens)

def smi_tokenizer(smi, atom_map_numbers=None):
    '''
        Tokenize a SMILES string (special tokens <s>, </s>, <unk>, <pad> are kept whole).

        Args:
            smi (str): SMILES string to tokenize.
            atom_map_numbers (list, optional): One value per atom, in the order atoms appear in
                ``smi``. When given, a list aligned with the tokens is built: atom tokens
                (organic-subset atoms, bracket atoms such as ``[C@H]``, and ``*``) take the next
                value, all other tokens get 0.

        Returns:
            tuple[str, list]: space-joined tokens, and the token-aligned atom-map list (empty
            when ``atom_map_numbers`` is None).

        Raises:
            AssertionError: if the tokens do not reassemble into ``smi`` (some characters are
                not covered by the pattern); the offending input is printed first.
    '''
    # Raw string literal: the regex contains escapes such as "\[" that are invalid in a plain
    # string literal. "Br?" / "Cl?" match Br and Cl before falling back to B and C.
    # TODO: check this works for all molecules all the time.
    pattern = r"(\[[^\]]+]|<s>|</s>|<unk>|<pad>|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    atom_pattern = r"Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p"

    tokens = re.findall(pattern, smi)
    joined = ''.join(tokens)
    if smi != joined:
        print(smi)
        print(tokens)
        print(joined)
        # Explicit raise so the check is not stripped under `python -O`.
        raise AssertionError(f"{smi} != {joined} {smi == joined}")

    tokenized_atom_map_numbers = []
    if atom_map_numbers is not None:
        atom_regex = re.compile(atom_pattern)
        atom_idx = 0
        for t in tokens:
            # Bracket atoms and '*' are atoms too; skipping them would misalign the map numbers.
            is_atom = t.startswith('[') or t == '*' or atom_regex.fullmatch(t) is not None
            if is_atom:
                tokenized_atom_map_numbers.append(atom_map_numbers[atom_idx])
                atom_idx += 1
            else:
                tokenized_atom_map_numbers.append(0)

    return ' '.join(tokens), tokenized_atom_map_numbers

def get_full_length(seq):
    '''
        Return the number of SMILES tokens in ``seq``, as split by smi_tokenizer.
    '''
    return len(smi_tokenizer(seq)[0].split())

def tokenize(smiles, all_chars_path):
    '''
        Convert a SMILES string into a tensor of vocabulary indices.

        Args:
            smiles (str): SMILES string to encode (tokenized with smi_tokenizer).
            all_chars_path (str): Vocabulary file with one token per line; a token's index is
                its 0-based line number (for duplicated lines the last occurrence wins).

        Returns:
            torch.Tensor: 1-D ``torch.long`` tensor ``[start, tok_1, ..., tok_n, end]``. Tokens
            missing from the vocabulary map to the ``<unk>`` index. ``<start>``, ``<end>`` and
            ``<unk>`` fall back to indices 1, 2 and 0 when absent from the vocabulary.
    '''
    with open(all_chars_path, 'r', encoding='utf-8') as f:
        char_to_idx = {line.strip(): i for i, line in enumerate(f)}

    start_idx = char_to_idx.get('<start>', 1)
    end_idx = char_to_idx.get('<end>', 2)
    unk_idx = char_to_idx.get('<unk>', 0)

    tokens_str, _ = smi_tokenizer(smiles)
    # split() rather than split(" "): an empty SMILES must give no tokens, not a single ''
    # token that would be encoded as <unk>.
    token_ids = [char_to_idx.get(token, unk_idx) for token in tokens_str.split()]
    return torch.tensor([start_idx, *token_ids, end_idx], dtype=torch.long)

def tokenize_reaction(reaction_smiles):
    """
    Collect the distinct tokens of a reaction SMILES.

    The reaction is canonicalized with atom maps removed (canonicalize_rxn), split on '>>',
    and every non-blank side is tokenized with smi_tokenizer_from_rsmiles.

    Args:
        reaction_smiles: Reaction SMILES string.

    Returns:
        Set of non-blank tokens, or an empty set if any step fails (a warning is printed).
    """
    try:
        sides = canonicalize_rxn(reaction_smiles, remove_atom_map=True).split('>>')
        vocab = set()
        for side in sides:
            if not side.strip():
                continue
            side_tokens = smi_tokenizer_from_rsmiles(side).split(" ")
            vocab.update(tok for tok in side_tokens if tok.strip())
        return vocab
    except Exception as e:
        print(f'Warning: Failed to tokenize reaction "{reaction_smiles}": {str(e)}')
        return set()

def get_vocab_size_from_file(config):
    '''
        Count the entries of the toy-experiment vocabulary file.

        The file is ``<PROJECT_ROOT>/data/toy_experiment/<data_dir>/<vocab_file>`` with both
        names taken from ``config.classifier_guidance.dataset``. Every line (blank ones
        included) counts as one entry.
    '''
    vocab_file = os.path.join(PROJECT_ROOT,
                              'data',
                              "toy_experiment",
                              config.classifier_guidance.dataset.data_dir,
                              config.classifier_guidance.dataset.vocab_file)
    # Context manager so the file handle is closed deterministically.
    with open(vocab_file, 'r') as f:
        return sum(1 for _ in f)

def get_vocab_size(config):
    '''
        Get the vocabulary size from the trained-model checkpoint.

        Loads the source vocabulary stored in the checkpoint at
        config.classifier_guidance.onmt_checkpoint_path (via
        get_vocab_from_trained_model) and returns its number of entries.
        No vocabulary file is read.
    '''
    checkpoint_path = config.classifier_guidance.onmt_checkpoint_path
    source_vocab = get_vocab_from_trained_model(checkpoint_path)
    return len(source_vocab)

def get_vocab_size_from_config(config):
    '''
        Get the vocabulary size from the vocabulary file named in the config.

        Reads PROJECT_ROOT/data/<classifier_guidance.dataset.vocab_file>
        (no toy_experiment/data_dir component) and counts its lines; every
        line, including blank ones, counts as one entry.
    '''
    vocab_file = os.path.join(PROJECT_ROOT,
                              'data',
                              config.classifier_guidance.dataset.vocab_file)
    # Context manager so the file handle is always closed.
    with open(vocab_file, 'r') as f:
        return sum(1 for _ in f)


def generate_simple_expression(max_num: int = 10) -> str:
    """Build a random binary expression of the form ``"a op b"``.

    ``op`` is drawn from ``'+'``/``'-'`` first, then the two operands are drawn
    with ``random.randint(1, max_num)``; tokens are separated by single spaces.
    """
    sign = random.choice(['+', '-'])
    left, right = random.randint(1, max_num), random.randint(1, max_num)
    return " ".join((str(left), sign, str(right)))

def generate_complex_expression(max_num: int = 10, max_depth: int = 2) -> str:
    """Generate a random expression made of two parenthesized simple expressions.

    Args:
        max_num: upper bound (inclusive) for every integer operand.
        max_depth: when ``<= 0`` a single random integer in ``[1, max_num]`` is
            returned; any positive value produces the same shape,
            ``"( a op b ) op ( c op d )"`` (the depth is not used further).

    Returns:
        The expression as a space-separated string.
    """
    if max_depth <= 0:
        return str(random.randint(1, max_num))

    ops = ['+', '-']

    def _two_groups() -> str:
        # Draw order (left group, joining operator, right group) matches the
        # left-to-right evaluation of the original f-string.
        left = generate_simple_expression(max_num)
        joiner = random.choice(ops)
        right = generate_simple_expression(max_num)
        return f"( {left} ) {joiner} ( {right} )"

    # Only one template is enabled at the moment. random.choice is still
    # called so the RNG stream (and hence seeded datasets) stays the same
    # when more templates are added back.
    templates = [_two_groups]
    chosen = random.choice(templates)
    return chosen()

def evaluate_parentheses(expr: str) -> str:
    """Collapse every parenthesized group in ``expr`` into its value.

    Repeatedly takes the last ``(`` and the first ``)`` after it (an innermost
    group), evaluates the text between them and splices ``str(result)`` back
    in, until no ``(`` remains. Stops early, returning the partially
    simplified text, when a ``(`` has no ``)`` after it or when a group's
    contents cannot be evaluated.

    Args:
        expr: arithmetic expression, e.g. ``"( 1 + 2 ) - ( 3 + 4 )"``.

    Returns:
        The simplified expression with surrounding whitespace stripped,
        e.g. ``"3 - 7"`` for the example above.
    """
    expr = expr.strip()

    while '(' in expr:
        # The loop condition guarantees a '(' exists, so rfind never returns -1.
        start = expr.rfind('(')
        end = expr.find(')', start)
        if end == -1:
            # Unbalanced '(' with nothing to close it: stop simplifying.
            break

        inner_expr = expr[start + 1:end].strip()
        try:
            # NOTE(review): eval() executes arbitrary Python. It is only meant
            # for the internally generated arithmetic strings; never pass
            # untrusted input here.
            result = eval(inner_expr)
        except Exception:
            # Not evaluable (e.g. empty group or syntax error): keep what we have.
            # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
            break
        expr = expr[:start] + str(result) + expr[end + 1:]

    return expr.strip()

def simplify_expression(expr: str, degree: int = 1) -> List[str]:
    """
    Simplify a mathematical expression step by step, up to ``degree`` steps.

    Steps, each kept only if it differs from every earlier form:
      1. collapse only the first ``( ... )`` group (whitespace normalized);
      2. starting again from the original expression, repeatedly collapse all
         groups at once with ``evaluate_parentheses``;
      3. evaluate the whole expression to a single number.

    Args:
        expr: the expression to simplify, e.g. ``"( 1 + 2 ) - ( 3 + 4 )"``.
        degree: maximum number of simplified forms to return; ``<= 0``
            yields an empty list.

    Returns:
        Simplified forms (the original expression excluded), from least to
        most simplified; at most ``degree`` entries.
    """
    stripped = expr.strip()
    # results[0] is the unsimplified input: used for de-duplication only and
    # dropped from the return value.
    results = [stripped] if degree >= 0 else []
    # Bound up front so the steps below also work when expr contains no
    # parentheses (previously this raised UnboundLocalError).
    current_expr = stripped

    # Step 1: collapse only the first parenthesized group.
    first_group = re.search(r'\([^()]+\)', stripped)
    if first_group is not None and len(results) <= degree:
        inner_expr = first_group.group()[1:-1].strip()
        try:
            # NOTE(review): eval() on internally generated arithmetic only.
            value = eval(inner_expr)
        except Exception:
            pass  # group not evaluable; later steps may still simplify
        else:
            simplified = stripped[:first_group.start()] + str(value) + stripped[first_group.end():]
            simplified = ' '.join(simplified.split())  # normalize whitespace
            if simplified not in results:
                results.append(simplified)

    # Step 2: note current_expr is still the *original* expression here, so
    # this collapses all groups of the input at once.
    while '(' in current_expr and len(results) <= degree:
        new_expr = evaluate_parentheses(current_expr)
        if new_expr == current_expr or new_expr in results:
            break
        results.append(new_expr)
        current_expr = new_expr

    # Step 3: final numeric value, if budget remains.
    if len(results) <= degree:
        try:
            final_result = str(eval(current_expr))
        except Exception:
            pass  # e.g. leftover unbalanced parentheses
        else:
            if final_result not in results:
                results.append(final_result)

    # Every append above requires len(results) <= degree, so results never
    # exceeds degree + 1 entries; dropping the original leaves at most
    # `degree` forms (and [] when degree < 0, as results is then empty).
    return results[1:]

def generate_mathematical_expressions(max_num_expressions: int = 10,
                                      max_depth: int = 2,
                                      config: dict = None) -> None:
    """Generate the toy arithmetic-simplification dataset and write it to disk.

    For each of ``max_num_expressions`` random expressions, every
    simplification step returned by ``simplify_expression(expr, max_depth)``
    becomes one (source, target) pair. The pairs are split into train/test
    with ``train_test_split`` and written, one item per line, to
    PROJECT_ROOT/data/toy_experiment/<dataset.data_dir>/ as
    data.{src,tgt} (full set, used to build the vocabulary),
    train.{src,tgt} and test.{src,tgt}.

    Args:
        max_num_expressions: number of random expressions to generate.
        max_depth: passed to ``generate_complex_expression`` and used as the
            simplification degree.
        config: attribute-style config (despite the ``dict`` annotation);
            reads ``classifier_guidance.dataset.{max_num,test_ratio,data_dir}``.
            Required.

    Returns:
        None; the function only writes files.
    """
    dataset_cfg = config.classifier_guidance.dataset

    expressions = []
    simplifications = []
    for _ in range(max_num_expressions):
        expr = generate_complex_expression(max_num=dataset_cfg.max_num, max_depth=max_depth)
        simp = simplify_expression(expr, degree=max_depth)
        # one (source, target) pair per simplification step
        expressions.extend([expr] * len(simp))
        simplifications.extend(simp)

    train_expressions, test_expressions, train_simplifications, test_simplifications = train_test_split(
        expressions, simplifications, test_size=dataset_cfg.test_ratio)

    out_dir = os.path.join(PROJECT_ROOT, "data", "toy_experiment", dataset_cfg.data_dir)

    def _write_lines(file_name, lines):
        """Write each string in ``lines`` to ``out_dir/file_name``, one per line."""
        with open(os.path.join(out_dir, file_name), "w") as f:
            f.writelines(line + '\n' for line in lines)

    # The full data set is saved as well so the vocabulary can be built from it.
    _write_lines("data.src", expressions)
    _write_lines("data.tgt", simplifications)
    _write_lines("train.src", train_expressions)
    _write_lines("train.tgt", train_simplifications)
    _write_lines("test.src", test_expressions)
    _write_lines("test.tgt", test_simplifications)

def get_vocab_from_trained_model(model_checkpoint: str):
    """Return the source-side vocabulary list stored in a model checkpoint.

    Loads the checkpoint on CPU and returns
    ``checkpoint['vocab']['src'].base_field.vocab.itos`` (index -> token).

    NOTE(review): ``weights_only=False`` performs full unpickling, which can
    execute arbitrary code; only load checkpoints from trusted sources.
    """
    state = torch.load(model_checkpoint, map_location='cpu', weights_only=False)
    src_field = state['vocab']['src']
    return src_field.base_field.vocab.itos

def tokenize_smiles(smiles):
    """Return the set of distinct tokens ``smi_tokenizer`` produces for ``smiles``.

    ``smi_tokenizer`` returns a pair whose first item is the space-joined
    token string; the second item is ignored.
    """
    joined_tokens, _unused = smi_tokenizer(smiles)
    return set(joined_tokens.split(' '))

def classifier_data_to_int(src_file: str, lengths_file: str, onmt_checkpoint_path: str,
                           vocab=None):
    """Encode classifier sequences as padded integer tensors.

    Args:
        src_file: one space-separated token sequence per line.
        lengths_file: one integer per line; in the toy experiment this value
            is used both as the classification target and as the full length.
        onmt_checkpoint_path: checkpoint whose source vocabulary is loaded
            with ``get_vocab_from_trained_model``.
        vocab: optional, already-loaded vocabulary list (index -> token). When
            given, the checkpoint is not loaded.

    Returns:
        List of 4-tuples ``(token_ids, target, src_length, full_length)``:
        ``token_ids`` is a 1-D long tensor ``<s> tokens... </s>`` right-padded
        with ``<blank>`` to the longest sequence; ``src_length`` is the token
        count without ``<s>``/``</s>``; the other two are scalar tensors read
        from ``lengths_file``. Lines are paired with ``zip`` (shorter file wins).
        Returns ``[]`` when there are no lines.

    Raises:
        ValueError: if a token (or a needed special token) is not in the vocab.
    """
    with open(src_file, "r") as f:
        src = f.readlines()
    # Target and full length come from the same file, so one read serves both.
    with open(lengths_file, "r") as f:
        lengths = f.readlines()

    if vocab is None:
        vocab = get_vocab_from_trained_model(onmt_checkpoint_path)
    # token -> first index (what list.index returns), but O(1) per lookup.
    token_to_idx = {}
    for idx, token in enumerate(vocab):
        token_to_idx.setdefault(token, idx)

    def _index(token):
        # Mirror list.index semantics: unknown tokens raise ValueError.
        try:
            return token_to_idx[token]
        except KeyError:
            raise ValueError(f"{token!r} is not in list") from None

    pairs = list(zip(src, lengths))
    if not pairs:
        return []

    bos = torch.tensor([_index('<s>')], dtype=torch.long)
    eos = torch.tensor([_index('</s>')], dtype=torch.long)

    data_classifier = []
    for line, length_line in pairs:
        encoded = torch.tensor([_index(tok.strip()) for tok in line.split(' ')], dtype=torch.long)
        length = int(length_line.strip())
        data_classifier.append((
            torch.cat([bos, encoded, eos], dim=0),
            torch.tensor(length),            # classification target
            torch.tensor(encoded.shape[0]),  # source length without <s>/</s>
            torch.tensor(length),            # full length
        ))

    # Right-pad every sequence to the same length.
    max_length = max(seq.shape[0] for seq, _, _, _ in data_classifier)
    pad_idx = None  # looked up lazily: '<blank>' is only required if padding happens
    for i, (seq, target, src_length, full_length) in enumerate(data_classifier):
        missing = max_length - seq.shape[0]
        if missing <= 0:
            continue
        if pad_idx is None:
            pad_idx = _index('<blank>')
        # torch.full keeps the integer dtype. The former ones()*idx padding was
        # a float tensor, so concatenating it either promoted padded sequences
        # to float or failed (depending on the PyTorch version).
        padding = torch.full((missing,), pad_idx, dtype=seq.dtype)
        data_classifier[i] = (torch.cat([seq, padding], dim=0), target, src_length, full_length)
    return data_classifier

# Class-balancing helpers: `df` is a pandas DataFrame and `target_col` (default 'length') names the column whose classes are balanced.
def balance_by_undersampling(df, target_col='length'):
    """Undersample every class of ``target_col`` down to the rarest class's size.

    Classes are visited in ``value_counts()`` order and each is sampled without
    replacement with ``random_state=42``, so the result is reproducible.

    Returns:
        A new DataFrame with a fresh 0..N-1 index.
    """
    counts = df[target_col].value_counts()
    n_per_class = counts.min()
    per_class = [
        df[df[target_col] == cls].sample(n=n_per_class, random_state=42)
        for cls in counts.index
    ]
    return pd.concat(per_class, ignore_index=True)

def balance_with_ratio(df, target_col='length', ratio_dict=None):
    """Subsample classes of ``target_col`` to follow the given relative ratios.

    Each class ``c`` in ``ratio_dict`` receives ``base_count * ratio_dict[c]``
    rows (sampled with ``random_state=42``), where ``base_count`` is the size
    of the smallest *selected* class integer-divided by the largest ratio.
    Classes not listed in ``ratio_dict`` are ignored entirely. Before, they
    still counted towards the minimum, so an unrelated rare class could force
    ``base_count`` to 0 and empty the result. A class with fewer rows than
    its target keeps all of its rows.

    Args:
        df: input DataFrame.
        target_col: column holding the class labels.
        ratio_dict: mapping class -> relative weight; defaults to weight 1 for
            every class present in ``df``.

    Returns:
        A new DataFrame with a fresh 0..N-1 index.
    """
    if ratio_dict is None:
        # Default to equal representation
        unique_classes = df[target_col].unique()
        ratio_dict = {cls: 1 for cls in unique_classes}

    counts = df[target_col].value_counts()
    # Smallest size among the classes we are actually going to sample.
    selected_sizes = [counts[cls] for cls in ratio_dict if cls in counts.index]
    min_samples = min(selected_sizes) if selected_sizes else 0
    base_count = min_samples // max(ratio_dict.values())

    balanced_dfs = []
    for length, ratio in ratio_dict.items():
        class_data = df[df[target_col] == length]
        target_count = base_count * ratio

        if len(class_data) >= target_count:
            sampled = class_data.sample(n=int(target_count), random_state=42)
        else:
            sampled = class_data  # use all available rows if not enough

        balanced_dfs.append(sampled)

    return pd.concat(balanced_dfs, ignore_index=True)

def generate_vocab(vocab_file: str, src_file: str, tgt_file: str):
    """Build a token vocabulary from source/target files and write it out.

    Every line of ``src_file`` and ``tgt_file`` is split on single spaces and
    each token is stripped. The vocabulary is the four special tokens followed
    by the distinct data tokens in sorted order. Sorting keeps the file (and
    the indices derived from it) identical across runs; iterating a ``set`` of
    strings does not, because string hashing is randomized per process.
    Empty tokens and tokens equal to a special token are skipped, so no entry
    appears twice or as a blank line.

    Args:
        vocab_file: output path; one token per line.
        src_file: source sequences, one per line.
        tgt_file: target sequences, one per line.

    Returns:
        The vocabulary list in the order written to ``vocab_file``.
    """
    with open(src_file, "r") as f:
        src = f.readlines()
    with open(tgt_file, "r") as f:
        tgt = f.readlines()

    special_tokens = ['<unk>', '<blank>', '<s>', '</s>']
    reserved = set(special_tokens)
    data_tokens = {tok.strip() for line in src + tgt for tok in line.split(' ')}
    vocab = special_tokens + sorted(tok for tok in data_tokens if tok and tok not in reserved)

    with open(vocab_file, "w") as f:
        f.writelines(token + '\n' for token in vocab)
    return vocab

def generate_classifier_data(tgt_file: str,
                             max_augmentations: int = 5,
                             completion_lower_limit: int = 0,
                             min_length_limit: int = 1,
                             config: dict = None):
    """Create the length-classifier training files from a target file.

    Each line of ``tgt_file`` is expanded into partial sequences with
    ``get_partial_sequences_with_completion_ratio_general``. The result is
    class-balanced on ``full_length`` with ``balance_by_undersampling`` and
    written to PROJECT_ROOT/data/toy_experiment/<dataset.data_dir>/ as
    classifier_train.src (column 'seq') and classifier_train.tgt
    (column 'full_length'), one item per line.

    Args:
        tgt_file: target sequences, one per line (space-separated tokens).
        max_augmentations, completion_lower_limit, min_length_limit: forwarded
            to the partial-sequence generator.
        config: attribute-style config (despite the ``dict`` annotation);
            ``classifier_guidance.dataset.data_dir`` is read. Required.
    """
    with open(tgt_file, "r") as handle:
        target_lines = handle.readlines()
    token_counts = [len(line.split(' ')) for line in target_lines]

    # Partially completed sequences; the returned DataFrame must provide the
    # 'seq' and 'full_length' columns read below.
    partial_df = get_partial_sequences_with_completion_ratio_general(
        sequences=target_lines,
        values=target_lines,
        lengths=token_counts,
        completion_lower_limit=completion_lower_limit,
        min_length_limit=min_length_limit,
        max_augmentations=max_augmentations,
        only_augment_complete=True,
        print_every=1,
    )

    # Equal number of samples per full length. Alternative (not used):
    # balance_with_ratio(partial_df, target_col='full_length', ratio_dict={3: 1, 7: 3})
    partial_df = balance_by_undersampling(partial_df, target_col='full_length')

    # Pull the columns out before any output file is opened/truncated.
    seqs = partial_df['seq'].values
    full_lengths = partial_df['full_length'].values

    out_dir = os.path.join(PROJECT_ROOT, "data", "toy_experiment",
                           config.classifier_guidance.dataset.data_dir)
    with open(os.path.join(out_dir, "classifier_train.src"), "w") as handle:
        for seq in seqs:
            handle.write(seq + '\n')
    with open(os.path.join(out_dir, "classifier_train.tgt"), "w") as handle:
        for full_length in full_lengths:
            handle.write(str(full_length) + '\n')


# def prepare_data_for_classifier(raw_tgt):
#     '''
#         prepare data for classifier.

#             raw_src: list of strings
#             raw_tgt: list of strings
#     '''
#     vocab = get_vocab()
#     num_tokens = [len(t.split(' ')) for t in raw_tgt]
#     df = get_partial_sequences_with_completion_ratio_general(sequences=raw_tgt, 
#                                                                 values=raw_tgt, 
#                                                                 lengths=num_tokens, 
#                                                                 completion_lower_limit=0, 
#                                                                 max_augmentations=0, 
#                                                                 only_augment_complete=True,
#                                                                 print_every = 1)
#     src = df['seq'].values
#     tgt = df['full_length'].values
#     data_classifier = []
#     for t, n in zip(src, tgt):
#         ep_encoded = torch.tensor([vocab.index(c) for c in t.split(' ')])
#         ep_padded = torch.cat([torch.tensor(vocab.index('<s>')).unsqueeze(0), ep_encoded,
#                                torch.tensor(vocab.index('</s>')).unsqueeze(0)], axis=0)
#         data_classifier.append((ep_padded, torch.tensor(n)))
#     # padd the sequences to the same length
#     max_length = max(seq[0].shape[0] for seq in data_classifier)
#     for i in range(len(data_classifier)):
#         if len(data_classifier[i][0]) < max_length:
#             data_classifier[i] = (torch.cat([data_classifier[i][0], 
#                                             torch.ones(max_length - data_classifier[i][0].shape[0])*vocab.index('<pad>')], 
#                                             axis=0), 
#                                   data_classifier[i][1])
#     return data_classifier
