import copy
import os
import json
import pandas as pd

# Directory that holds the input CSVs (train.csv / val.csv / test.csv) and
# receives the *_clean.csv outputs. The empty string makes every path
# relative to the current working directory.
base_path = ''

def process_dataset(data_df):
    """Drop unsupported question types and keep one image ID per query.

    Rows whose ``question_type`` is ``'2_hop'`` or ``'templated'`` are
    removed. In the remaining rows, ``dataset_image_ids`` (a ``'|'``-separated
    string of IDs) is reduced to its first ID.

    Args:
        data_df: DataFrame with at least the columns ``question_type`` and
            ``dataset_image_ids``. It is not modified.

    Returns:
        A new DataFrame with the same columns and the original index labels
        of the rows that were kept.
    """
    # 2-hop queries are removed; templated queries are removed because they
    # do not have question_original.
    excluded_types = ['2_hop', 'templated']
    keep_mask = ~data_df['question_type'].isin(excluded_types)

    # Copy so that the column assignment below never writes into the
    # caller's frame (and never triggers SettingWithCopyWarning).
    new_data_df = data_df.loc[keep_mask].copy()

    # Take the first image only. Non-string cells (e.g. NaN from an empty CSV
    # field) are passed through unchanged instead of raising AttributeError.
    new_data_df['dataset_image_ids'] = new_data_df['dataset_image_ids'].map(
        lambda ids: ids.split('|')[0] if isinstance(ids, str) else ids
    )

    return new_data_df


# Each text query is paired with up to 5 different images, which yields ~5x more queries.
# Due to limited computational resources (and for a fair comparison with the text-only baseline), we keep only the unique text queries.
# Also, 2-hop queries are ambiguous to handle with the contrastive loss, so we remove them in our training.
if __name__ == '__main__':
    # (input file, output file) for each split, in processing order.
    split_files = (
        ('train.csv', 'train_clean.csv'),
        ('val.csv', 'val_clean.csv'),
        ('test.csv', 'test_clean.csv'),
    )

    # Phase 1: load every split before doing anything else, so a missing
    # input file aborts the run before any output is written.
    raw_frames = [
        pd.read_csv(os.path.join(base_path, source_name))
        for source_name, _ in split_files
    ]

    # Phase 2: clean every split.
    cleaned_frames = [process_dataset(frame) for frame in raw_frames]

    # Phase 3: write the cleaned splits next to the inputs, without the index.
    for (_, target_name), cleaned in zip(split_files, cleaned_frames):
        cleaned.to_csv(os.path.join(base_path, target_name), index=False)