import pandas as pd

# QED cut-off compared against the 'QED' column when the data is split below.
qed_threshold = 0.8

# Load the dataset that the rest of the script formats and splits.
combined_data = pd.read_parquet("PATH/tdc_dataset.parquet")

# Print the first rows as a quick check of what was loaded.
print(combined_data.head())

def format_input_txt(examples):
    """Build a prompt string for every row and store it in ``input_txt``.

    The format depends on the row's ``dataset_id``:

    * ``'oled'``: wavelength and f_osc tags, ``[SEP]``, then the SMILES string
      wrapped in ``[START_SMILES]`` / ``[END_SMILES]``.
    * ``'tdc'``: QED and LOGP tags (each value rounded to 2 decimals) followed
      by an unclosed ``[START_SMILES]`` tag.

    Args:
        examples: DataFrame with a ``dataset_id`` column plus the columns read
            by the dataset ids that actually occur (``wavelength``, ``f_osc``
            and ``smiles`` for ``'oled'``; ``QED`` and ``LOGP`` for ``'tdc'``).

    Returns:
        The same ``examples`` object; the ``input_txt`` column is assigned on
        it in place.

    Raises:
        ValueError: If a row's ``dataset_id`` is neither ``'oled'`` nor
            ``'tdc'``.
    """
    column_cache = {}

    def cell(column, pos):
        # Resolve each column once, and only when some row needs it, so a
        # frame holding a single dataset need not carry the other's columns.
        if column not in column_cache:
            column_cache[column] = examples[column]
        return column_cache[column].iloc[pos]

    input_txt = []
    for pos, dataset_id in enumerate(examples['dataset_id']):
        if dataset_id == 'oled':
            # NOTE(review): this branch starts with "<s>" while the 'tdc'
            # branch starts with "</s>"; kept as-is, confirm it is intended.
            input_txt.append(
                f"<s>[WAVELENGTH]{cell('wavelength', pos)}[/WAVELENGTH]"
                f"[F_OSC]{cell('f_osc', pos)}[/F_OSC][SEP]"
                f"[START_SMILES]{cell('smiles', pos)}[END_SMILES]"
            )
        elif dataset_id == 'tdc':
            qed = round(cell('QED', pos), 2)
            logp = round(cell('LOGP', pos), 2)
            input_txt.append(
                f"</s>[QED]{qed}[/QED][LOGP]{logp}[/LOGP][START_SMILES]"
            )
        else:
            raise ValueError(f"Unexpected Dataset ID: {dataset_id}")

    examples['input_txt'] = input_txt
    return examples

        
# Attach the formatted prompt column before splitting.
combined_data = format_input_txt(combined_data)

# The two comparisons are evaluated separately on purpose: a row whose QED is
# NaN satisfies neither, so it ends up in neither subset.
low_qed_mask = combined_data['QED'] <= qed_threshold
high_qed_mask = combined_data['QED'] > qed_threshold
low_qed_stimuli = combined_data.loc[low_qed_mask]
high_qed_stimuli = combined_data.loc[high_qed_mask]

# Write each subset to its own parquet file; the threshold is part of the name.
low_qed_stimuli.to_parquet(
    f"PATH/concept_representation_alignment/stimuli_dataset/qed/low_qed_{qed_threshold}_tdc_data_new.parquet"
)
high_qed_stimuli.to_parquet(
    f"PATH/concept_representation_alignment/stimuli_dataset/qed/high_qed_{qed_threshold}_tdc_data_new.parquet"
)