import pandas as pd

# LogP cut-off used further down to split the rows into a "low" (<=) and a
# "high" (>) subset; the value is also interpolated into the output file names.
logp_threshold = 2.5

# Full table that gets formatted and split below.
combined_data = pd.read_parquet("PATH/tdc_dataset.parquet")


def format_input_txt(examples):
    """Attach a prompt string to every row of ``examples``.

    The prompt layout depends on the row's ``dataset_id``:

    * ``'oled'``: ``[WAVELENGTH]..[/WAVELENGTH][F_OSC]..[/F_OSC][SEP]``
      followed by the row's SMILES wrapped in ``[START_SMILES]`` and
      ``[END_SMILES]``.
    * ``'tdc'``: ``</s>[QED]..[/QED][LOGP]..[/LOGP][SEP][START_SMILES]`` with
      both values rounded to two decimals; the SMILES string itself is not
      appended.

    Args:
        examples: DataFrame with a ``dataset_id`` column plus the columns
            needed by the dataset ids that actually occur (``wavelength``,
            ``f_osc`` and ``smiles`` for ``'oled'`` rows; ``QED`` and
            ``LOGP`` for ``'tdc'`` rows).

    Returns:
        The same DataFrame object, modified in place: an ``input_txt``
        column is added (or overwritten) with one prompt per row.

    Raises:
        ValueError: If a row's ``dataset_id`` is neither ``'oled'`` nor
            ``'tdc'``.
    """
    fetched = {}

    def column(name):
        # Resolve each column once instead of once per row, and only when a
        # row actually needs it, so a frame holding a single dataset does not
        # have to carry the other dataset's columns.
        if name not in fetched:
            # Positional access on ``.array`` yields the same scalars as
            # ``Series.iloc`` (NumPy scalars are not converted to Python
            # numbers), so rounding and string formatting are unchanged.
            fetched[name] = examples[name].array
        return fetched[name]

    input_txt = []
    # An empty frame is passed through without touching any column.
    dataset_ids = column('dataset_id') if len(examples) else ()
    for i, dataset_id in enumerate(dataset_ids):
        if dataset_id == 'oled':
            input_txt.append(
                f"[WAVELENGTH]{column('wavelength')[i]}[/WAVELENGTH]"
                f"[F_OSC]{column('f_osc')[i]}[/F_OSC][SEP]"
                f"[START_SMILES]{column('smiles')[i]}[END_SMILES]"
            )
        elif dataset_id == 'tdc':
            input_txt.append(
                f"</s>[QED]{round(column('QED')[i], 2)}[/QED]"
                f"[LOGP]{round(column('LOGP')[i], 2)}[/LOGP][SEP][START_SMILES]"
            )
        else:
            raise ValueError(f"Unexpected Dataset ID: {dataset_id}")

    examples['input_txt'] = input_txt
    return examples

        
# Build the ``input_txt`` column first so both subsets written below carry it.
combined_data = format_input_txt(combined_data)

# Split around the threshold: "low" is inclusive (<=), "high" is strictly
# greater (>).
# NOTE(review): a row whose LOGP is missing matches neither comparison and so
# lands in neither output file - confirm this is intended.
logp_values = combined_data['LOGP']
low_logp_stimuli = combined_data.loc[logp_values <= logp_threshold]
high_logp_stimuli = combined_data.loc[logp_values > logp_threshold]

# Write each subset to its own parquet file; the threshold is part of the name.
low_logp_stimuli.to_parquet(f"PATH/concept_representation_alignment/stimuli_dataset/logp/low_logp_{logp_threshold}_tdc_data_new.parquet")
high_logp_stimuli.to_parquet(f"PATH/concept_representation_alignment/stimuli_dataset/logp/high_logp_{logp_threshold}_tdc_data_new.parquet")