import copy
import os
import json
import pandas as pd
from tqdm import tqdm
import pandas as pd


base_path = ''

def extract_subset(ratio, valid_doc_urls, split='train', random_state=None, base_dir=None):
    """Filter one split of the query table to valid documents and subsample it.

    Reads ``<base_dir>/<split>_section.csv``, drops rows whose ``wikipedia_url``
    is not in ``valid_doc_urls`` and writes ``<base_dir>/<split>_clean.csv``.
    The ``'test'`` split keeps every remaining row. Any other split keeps, for
    each URL with ``n`` rows, a random ``max(int(ratio * n), 1)`` of them (all
    ``n`` rows if that number exceeds ``n``), so every valid URL present in the
    split is still represented at least once.

    Args:
        ratio: Fraction of rows to keep per ``wikipedia_url`` (non-test splits).
        valid_doc_urls: List-like of Wikipedia URLs to keep.
        split: Split name; selects the input and output file names.
        random_state: Seed forwarded to ``DataFrame.sample``. ``None`` (the
            default) draws a different sample on every run.
        base_dir: Directory containing the CSV files. Defaults to the
            module-level ``base_path``.

    Returns:
        The DataFrame that was written to ``<split>_clean.csv``.
    """
    if base_dir is None:
        base_dir = base_path

    query_df = pd.read_csv(os.path.join(base_dir, f'{split}_section.csv'))

    # Remove Wikipedia URLs that are not valid.
    filtered_query_df = query_df[query_df['wikipedia_url'].isin(valid_doc_urls)]

    if split == 'test':
        result = filtered_query_df
    else:
        # Ensure all 'wikipedia_url' items are selected at least once.
        # Shuffle once, then keep the first `quota` rows of each URL group: a
        # uniform per-group sample without groupby().apply(), which since
        # pandas 2.2 warns that it operates on the grouping column (a future
        # release would drop 'wikipedia_url' from the output).
        shuffled = filtered_query_df.sample(frac=1, random_state=random_state)
        urls = shuffled['wikipedia_url']
        position = shuffled.groupby('wikipedia_url', sort=False).cumcount()
        group_size = urls.map(urls.value_counts())
        # position < group_size always holds, so a quota above n keeps all n rows.
        quota = (group_size * ratio).astype(int).clip(lower=1)
        result = (
            shuffled[position < quota]
            .sort_values('wikipedia_url', kind='stable')  # rows grouped by URL, like groupby output
            .reset_index(drop=True)
        )

    result.to_csv(os.path.join(base_dir, f'{split}_clean.csv'), index=False)
    return result


# Extract a subset of the InfoSeek dataset, because the full dataset is too large for our resource budget.
# Note that only about 6k documents are actually referenced by the queries.
# Shrink the dataset while keeping every one of those valid documents (about 6k).
if __name__ == '__main__':

    entities_path = os.path.join(base_path, 'infoseek_wikipedia', 'infoseek_valid_entities.json')
    # Use a context manager so the file is closed deterministically, and read
    # it as UTF-8 (the JSON standard encoding) rather than the platform default.
    with open(entities_path, 'r', encoding='utf-8') as f:
        valid_entities = json.load(f)
    # Entity titles contain spaces; the matching Wikipedia URLs use underscores.
    valid_doc_urls = [f"https://en.wikipedia.org/wiki/{entity.replace(' ', '_')}" for entity in valid_entities]

    ratio = 0.25

    for split in ('train', 'val', 'test'):
        extract_subset(ratio, valid_doc_urls, split)

    print("Done!")