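'''
Sample reactant predictions for one batch of products with a single-step model,
save the samples to CSV, then evaluate them and save the evaluated results to CSV.
'''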
import os
import gc
from unittest.mock import MagicMock
import sys
import time
from pathlib import Path
import hydra
import pandas as pd
import torch
import random
import numpy as np

# Put the directory two levels above this file at the front of the import path
# (presumably so the local `multiguide` package is importable when this file is
# run as a script -- confirm against the repository layout).
sys.path.insert(0, str(Path(__file__).parent.parent))
# Mock the problematic RDKit drawing modules.
# The mocks are registered in sys.modules before the `multiguide` imports that
# follow, so any later `import rdkit.Chem.Draw...` resolves to a MagicMock
# instead of loading the real drawing modules.
sys.modules['rdkit.Chem.Draw.rdMolDraw2D'] = MagicMock()
sys.modules['rdkit.Chem.Draw'] = MagicMock()

from multiguide.helpers import PROJECT_ROOT, set_seed
from multiguide.dataset.helpers import get_batch
from multiguide.dataset.helpers import parse_batch_to_reaction_data
from multiguide.evaluation.helpers import get_results_for_one_batch, define_single_step_model, evaluate_results_for_one_batch

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the functions visible in this file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_rows_per_batch(reaction_batch, results):
    '''
        Flatten sampled results into one dict per (product, sample) pair.

        Args:
            reaction_batch: indexable collection of reaction objects; the entry
                at position i supplies the ground-truth fields for results[i].
            results: iterable with one entry per product, each entry being an
                iterable of that product's sampled predictions.

        Returns:
            A list of dicts. Each dict holds the reaction's ground-truth
            attributes, the sampled prediction ('reactant_predictions'), and
            the positions it came from ('product_idx', 'sample_index').
    '''
    # (output column, attribute read from the reaction object), in column order.
    column_to_attr = (
        ('product_smi', 'product'),
        ('true_reactants', 'reactants'),
        ('true_class', 'class_idx'),
        ('true_most_similar_reactants_similarity', 'most_similar_reactants_similarity'),
        ('true_least_similar_reactants_similarity', 'least_similar_reactants_similarity'),
        ('true_most_similar_reactants', 'most_similar_reactants'),
        ('true_least_similar_reactants', 'least_similar_reactants'),
        # NOTE: only available in routes data
        ('true_similarity_to_target', 'similarity_to_target'),
        ('conditional_starting_material', 'conditional_starting_material'),
        ('conditional_target', 'conditional_target'),
        ('original_target', 'original_target'),
        ('original_starting_material', 'original_starting_material'),
    )

    flattened = []
    for product_idx, samples in enumerate(results):
        for sample_idx, prediction in enumerate(samples):
            # Looked up per sample so a product without samples never indexes
            # into reaction_batch.
            reaction = reaction_batch[product_idx]
            record = {
                column: getattr(reaction, attr)
                for column, attr in column_to_attr
            }
            record['reactant_predictions'] = prediction
            record['product_idx'] = product_idx
            record['sample_index'] = sample_idx
            flattened.append(record)
    return flattened

def save_df(df, config, out_file_name):
    '''
        Write `df` as a CSV file inside the experiment's output directory.

        The directory is
        PROJECT_ROOT/experiments/<experiment_group>/<experiment_params>/<experiment_name>,
        with the three names read from config.general; it is created if missing.

        Args:
            df: object exposing `to_csv` (a pandas DataFrame in this script).
            config: config object providing the config.general.experiment_* fields.
            out_file_name: file name of the CSV inside the output directory.

        Returns:
            The path of the written CSV file, as a string.
    '''
    path_parts = (
        PROJECT_ROOT,
        'experiments',
        config.general.experiment_group,
        config.general.experiment_params,
        config.general.experiment_name,
    )
    experiment_dir = os.path.join(*path_parts)
    os.makedirs(experiment_dir, exist_ok=True)

    csv_path = os.path.join(experiment_dir, out_file_name)
    # The index is not written, so the CSV contains only the data columns.
    df.to_csv(csv_path, index=False)
    return csv_path

@hydra.main(config_path='../configs', config_name='config.yaml')
def sample_in_batch(config):
    '''
        Sample predictions for one batch of reactions, then evaluate them.

        Steps, in order:
          1. Seed the run with set_seed(config.general.seed).
          2. Load a batch with get_batch(config) and convert it with
             parse_batch_to_reaction_data(batch, start_idx).
          3. Build the model with define_single_step_model(config) and sample
             with get_results_for_one_batch(model, config, reactions).
          4. Flatten the samples to one row per (product, sample) and write
             'sampled_start{start_idx}_end{end_idx}.csv' via save_df.
          5. Run evaluate_results_for_one_batch(df, config) and write
             'evaluated_start{start_idx}_end{end_idx}.csv' via save_df.

        Args:
            config: Hydra config. This function reads config.general.seed and
                config.single_step_evaluation.start_idx / end_idx directly and
                passes the whole config on to the helpers.

        Returns:
            None. Results are written to CSV files and progress is printed.

        NOTE(review): end_idx is only used here for logging and for the output
        file names; which reactions are actually loaded is decided inside
        get_batch / parse_batch_to_reaction_data (not visible here).
    '''
    set_seed(config.general.seed)
    print(f'======== Seed: {config.general.seed}')
    print(f'======== Start idx: {config.single_step_evaluation.start_idx}')
    print(f'======== End idx: {config.single_step_evaluation.end_idx}')
    start_idx = config.single_step_evaluation.start_idx
    end_idx = config.single_step_evaluation.end_idx

    batch = get_batch(config)
    reactions = parse_batch_to_reaction_data(
        batch,
        start_idx
    )
    model = define_single_step_model(config)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way time.time() differences can.
    start_time = time.perf_counter()
    results = get_results_for_one_batch(
        model,
        config,
        reactions
    )
    print(f'======== Sampling time: {time.perf_counter() - start_time} seconds')

    # `results` is iterated as: one entry per product, each holding that
    # product's samples.
    all_rows = get_rows_per_batch(reactions, results)
    df = pd.DataFrame(all_rows)
    out_path = save_df(df, config, f'sampled_start{start_idx}_end{end_idx}.csv')
    print(f'======== Sampled df saved to {out_path}')

    start_time = time.perf_counter()
    df = evaluate_results_for_one_batch(df, config)
    print(f'======== Evaluating time: {time.perf_counter() - start_time} seconds')
    out_path = save_df(df, config, f'evaluated_start{start_idx}_end{end_idx}.csv')
    print(f'======== Evaluated df saved to {out_path}')

    # Collect garbage first: tensors released by the collector return their
    # blocks to PyTorch's caching allocator, and only then can empty_cache()
    # hand those cached blocks back to the GPU driver.
    gc.collect()
    torch.cuda.empty_cache()

# Script entry point. `sample_in_batch` is called without arguments because the
# @hydra.main decorator supplies the `config` parameter.
if __name__ == '__main__':
    sample_in_batch()