import os
import sys
import re
import hydra
import pandas as pd

from rdkit import Chem
from setup_path import *
from multiguide.dataset.helpers import class_to_idx, get_rxn_insight_info
from multiguide.helpers import PROJECT_ROOT
from multiguide.dataset.helpers import compare_reactant_smiles, assign_similarity

def deduplicate_exact_matches(df):
    """Drop redundant exact-match rows.

    A row is removed when ``is_exact_match == 1`` AND a later row shares the
    same ``product`` and ``true_reactants`` values. The later row may itself be
    a non-exact match. Rows whose ``is_exact_match`` is not 1 are always kept.

    Args:
        df: DataFrame with ``is_exact_match``, ``product`` and
            ``true_reactants`` columns.

    Returns:
        A filtered view of ``df``. The original index labels are preserved.
    """
    flagged_exact = df['is_exact_match'].eq(1)
    # keep='last' marks every occurrence except the final one of each group.
    has_later_copy = df.duplicated(subset=['product', 'true_reactants'], keep='last')
    return df[~(flagged_exact & has_later_copy)]

def extract_start_idx_from_old_eval_file(old_eval_file):
    """Return the shard number encoded as ``s<N>`` in a legacy eval file name.

    Example: ``eval_epoch720_steps100_resorted_0.9_s0_new`` -> ``0``.

    Args:
        old_eval_file: A bare file name or a path. Only the final path
            component is inspected.

    Returns:
        The integer ``N``, or ``None`` if the name has no ``s<N>`` token.
    """
    # Only the file name carries the shard tag. Directory names can contain
    # look-alike substrings (e.g. "...results_100_candidates72...").
    name = os.path.basename(old_eval_file)
    # The 's' must not be glued to a preceding letter/digit, so tokens such as
    # "steps100" or "candidates72" are not mistaken for the shard tag.
    match = re.search(r'(?<![A-Za-z0-9])s(\d+)', name)
    if match is None:
        return None
    return int(match.group(1))

@hydra.main(config_path='../configs', config_name='config.yaml')
def translate_old_to_new_eval_files(config):
    """Translate every legacy ``eval_epoch*resorted*`` file in one experiment dir.

    Each matching file is handed to ``translate_file``. That function writes a
    ``sampled_start<S>_end<E>.csv`` next to it.

    Args:
        config: Hydra config composed from ``../configs/config.yaml``. It is
            forwarded unchanged to ``translate_file``.
    """
    # NOTE(review): the experiment directory is hard-coded. Edit it (or move it
    # into the Hydra config) to translate a different run.
    experiment_dir = os.path.join(
        PROJECT_ROOT,
        'experiments',
        'single_step_50k',
        'no_guidance',
        '50k_seedrandom_modeldiffalign_steeredfalse_guidance0_length0_results_100_candidates72_time20251022_154711'
    )
    # Names containing 'new' are excluded. Presumably these are outputs of an
    # earlier conversion scheme that used a '_new' suffix.
    old_files = sorted(
        name for name in os.listdir(experiment_dir)
        if name.startswith('eval_epoch') and 'resorted' in name and 'new' not in name
    )
    for name in old_files:
        # Pass the bare file name. translate_file joins it with experiment_dir
        # itself, and the shard index must come from the file name alone: the
        # directory name contains digit-bearing tokens such as "candidates72".
        translate_file(config, experiment_dir, name)

# Legacy text eval layout: a header line "(cond <N>) <reactants>>><product>:"
# followed by one indented "('<predicted rxn>', [<metrics>])" line per prediction.
_OLD_EVAL_BLOCK_RE = re.compile(r'\(cond\s+(\d+)\)\s+(.+?>>.+?):((?:\n\s+\(.+?,\s*\[.+?\]\))+)')
_OLD_EVAL_LINE_RE = re.compile(r"\('(.+?)',\s*\[(.+?)\]\)")
# Order of the numbers inside each prediction's metric list.
_OLD_EVAL_METRIC_NAMES = ('elbo', 'loss_t', 'loss_0', 'count', 'weighted_prob')


def _canonical_smiles(smiles):
    """Return RDKit's canonical SMILES for ``smiles``, or None if it does not parse."""
    mol = Chem.MolFromSmiles(smiles)
    return None if mol is None else Chem.MolToSmiles(mol)


def translate_file(config, experiment_dir, old_eval_file, num_conditions_per_file=50):
    """Convert one legacy text eval file into the new per-shard CSV format.

    Output name: ``sampled_start<S>_end<E>.csv`` in ``experiment_dir``, where
    ``S = shard * num_conditions_per_file`` and ``E = S + num_conditions_per_file``.
    The shard number is parsed from ``old_eval_file`` (the ``s<N>`` token).

    Args:
        config: Forwarded to ``assign_similarity``.
        experiment_dir: Directory containing the old file. Output is written here too.
        old_eval_file: Old file name relative to ``experiment_dir`` (an absolute
            path also works, since ``os.path.join`` then returns it unchanged).
        num_conditions_per_file: Number of conditions per legacy shard.
            NOTE(review): the example new name ``sampled_start4600_end4800``
            spans 200 conditions; confirm 50 matches the legacy sharding.

    Raises:
        ValueError: if the shard number cannot be extracted, a ground-truth
            reaction does not parse, a prediction's product differs from the
            true product, a metric list is too short, or nothing was parsed.
    """
    shard = extract_start_idx_from_old_eval_file(old_eval_file)
    if shard is None:
        raise ValueError(f"Could not extract start index from old eval file: {old_eval_file}")
    start_idx = shard * num_conditions_per_file
    end_idx = start_idx + num_conditions_per_file
    old_eval_path = os.path.join(experiment_dir, old_eval_file)
    new_eval_path = os.path.join(experiment_dir, f'sampled_start{start_idx}_end{end_idx}.csv')

    with open(old_eval_path, 'r') as handle:
        old_eval = handle.read()

    records = []
    for cond_num, true_rxn, data_lines in _OLD_EVAL_BLOCK_RE.findall(old_eval):
        true_parts = true_rxn.split('>>')
        true_reactants, true_product = true_parts[0], true_parts[1]
        data_matches = _OLD_EVAL_LINE_RE.findall(data_lines)
        print(f'len(data_matches): {len(data_matches)}, for cond_num: {cond_num}')

        true_product_cano = _canonical_smiles(true_product)
        true_reactants_cano = _canonical_smiles(true_reactants)
        if true_product_cano is None or true_reactants_cano is None:
            raise ValueError(f"RDKit could not parse the true reaction for cond {cond_num}: {true_rxn}")
        rxn_insight_info = get_rxn_insight_info(rxn_smi=true_reactants_cano + '>>' + true_product_cano)
        true_class = class_to_idx[rxn_insight_info['CLASS']] if rxn_insight_info is not None else -1

        for predicted_rxn, metrics_str in data_matches:
            metrics = [float(x.strip()) for x in metrics_str.split(',')]
            predicted_parts = predicted_rxn.split('>>')
            predicted_reactants, predicted_product = predicted_parts[0], predicted_parts[1]
            predicted_reactants_cano = _canonical_smiles(predicted_reactants)
            # Predictions whose reactants RDKit cannot parse are skipped.
            if predicted_reactants_cano is None:
                continue
            # Compared as raw strings: every prediction for a condition should
            # carry the condition's product verbatim.
            if predicted_product != true_product:
                raise ValueError(
                    f"Predicted product {predicted_product} does not match true product {true_product}"
                )
            if len(metrics) < len(_OLD_EVAL_METRIC_NAMES):
                raise ValueError(
                    f"Expected {len(_OLD_EVAL_METRIC_NAMES)} metrics for cond {cond_num}, got {metrics_str!r}"
                )
            is_exact_match = compare_reactant_smiles(true_reactants, predicted_reactants)
            # NOTE: no deduplication is done here. The legacy files are not fully
            # canonicalized, so duplicate predictions can appear in the output.
            records.append({
                'condition': int(cond_num),
                # NOTE(review): this column holds the raw (not canonicalized)
                # true product despite its name; confirm downstream expects that.
                "sorted_cano_products": true_product,
                "sorted_cano_reactants": true_reactants_cano,
                "true_reactants_original": true_reactants,
                "true_class": true_class,
                "original_target": true_product,
                "reactant_predictions_original": predicted_reactants,
                "reactant_predictions": predicted_reactants_cano,
                # 'topk' carries the exact-match result from compare_reactant_smiles.
                "topk": is_exact_match,
                'elbo': metrics[0],
                'loss_t': metrics[1],
                'loss_0': metrics[2],
                'count': metrics[3],
                'weighted_prob': metrics[4],
            })

    if not records:
        raise ValueError(f"No predictions parsed from old eval file: {old_eval_path}")

    df = pd.DataFrame(records)
    df = assign_similarity(df, config)
    df['original_starting_material'] = df['most_similar_reactants']
    df['product_smi'] = df['sorted_cano_products']
    df['true_reactants'] = df['sorted_cano_reactants']
    df.to_csv(new_eval_path, index=False)
    print(f"Saved to {new_eval_path}")

if __name__ == "__main__":
    # The @hydra.main decorator supplies `config` (composed from
    # ../configs/config.yaml), so no argument is passed here.
    translate_old_to_new_eval_files()