import os
import pandas as pd
import json
from tqdm import tqdm

# Root directory containing infoseek_data/ and infoseek_images/; the output CSVs and
# dataset_id_to_path.json are written here as well. '' resolves every path relative
# to the current working directory.
base_path = ''


def _attach_wikipedia_info(query_df, query_kb_df, split):
    """Return query_df with wikipedia_url / wikipedia_title columns joined on data_id.

    Rows are matched by data_id rather than by position, so the query file and the
    withkb file do not need to share the same row order.

    Raises:
        ValueError: if data_id is duplicated in the withkb file, or a query's data_id
            has no entity_text there.
    """
    entity_by_id = query_kb_df.set_index('data_id')['entity_text']
    if not entity_by_id.index.is_unique:
        raise ValueError(f"infoseek_{split}_withkb.jsonl contains duplicate data_id values.")

    titles = query_df['data_id'].map(entity_by_id)
    missing = titles.isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum())} queries in infoseek_{split}.jsonl have no "
            f"entity_text in infoseek_{split}_withkb.jsonl."
        )

    urls = 'https://en.wikipedia.org/wiki/' + titles.str.replace(' ', '_', regex=False)
    return query_df.assign(wikipedia_url=urls, wikipedia_title=titles)


def _map_image_paths(image_ids, split):
    """Map each image id to its image path, probing the supported file extensions.

    The returned paths are prefixed with 'infoseek/' (not base_path).

    Raises:
        ValueError: if no file with a supported extension exists for an image id.
    """
    extensions = ('jpg', 'JPEG', 'png')
    image_dir = f'infoseek_images/infoseek_{split}_images'
    mapping = {}
    for image_id in tqdm(image_ids, total=len(image_ids)):
        # Several queries can share one image; probe the filesystem once per id.
        if image_id in mapping:
            continue
        for ext in extensions:
            if os.path.exists(os.path.join(base_path, f'{image_dir}/{image_id}.{ext}')):
                mapping[image_id] = f'infoseek/{image_dir}/{image_id}.{ext}'
                break
        else:
            raise ValueError(f"infoseek/{image_dir}/{image_id} does not have valid extension.")
    return mapping


def _write_sections(query_df, split, seed=None, train_fraction=0.9):
    """Write the CSV section file(s) for `split` under base_path.

    Infoseek does not provide a test set, so the official 'val' split is written as
    test_section.csv, and the official 'train' split is shuffled and divided into
    train_section.csv and val_section.csv. Other splits write nothing.
    """
    if split == 'train':
        shuffled = query_df.sample(frac=1, random_state=seed).reset_index(drop=True)
        num_train_queries = int(len(shuffled) * train_fraction)

        # .copy() so the new column is added to an independent frame, not a slice
        # of `shuffled` (avoids SettingWithCopyWarning / chained assignment).
        train_query_df = shuffled.iloc[:num_train_queries].copy()
        train_query_df['infoseek_split'] = 'train'

        val_query_df = shuffled.iloc[num_train_queries:].copy()
        val_query_df['infoseek_split'] = 'val'

        train_query_df.to_csv(os.path.join(base_path, 'train_section.csv'), index=False)
        val_query_df.to_csv(os.path.join(base_path, 'val_section.csv'), index=False)
    elif split == 'val':
        test_query_df = query_df.assign(infoseek_split='test')
        test_query_df.to_csv(os.path.join(base_path, 'test_section.csv'), index=False)


def process_dataset(split='train', seed=None, train_fraction=0.9):
    """Build the Infoseek query CSV(s) for one split and map image ids to image paths.

    Reads infoseek_data/infoseek_{split}.jsonl and infoseek_data/infoseek_{split}_withkb.jsonl
    under base_path and joins them on data_id to add wikipedia_url and wikipedia_title.
    It also adds dataset_name='Infoseek' and renames image_id to dataset_image_ids.

    split='train' is shuffled and written as train_section.csv / val_section.csv.
    split='val' is written as test_section.csv, because Infoseek has no public test set.

    Args:
        split: Infoseek split name used in the input file and image directory names.
        seed: random_state for the train shuffle. None keeps the split non-deterministic.
        train_fraction: fraction of the shuffled train queries kept in train_section.csv.

    Returns:
        dict mapping each image id to 'infoseek/infoseek_images/infoseek_{split}_images/<id>.<ext>'.

    Raises:
        ValueError: if a query has no (or an ambiguous) withkb entry, or an image file
            is missing for every supported extension.
    """
    data_dir = os.path.join(base_path, 'infoseek_data')
    query_df = pd.read_json(os.path.join(data_dir, f'infoseek_{split}.jsonl'), lines=True)
    query_kb_df = pd.read_json(os.path.join(data_dir, f'infoseek_{split}_withkb.jsonl'), lines=True)

    query_df = _attach_wikipedia_info(query_df, query_kb_df, split)
    query_df['dataset_name'] = 'Infoseek'
    query_df = query_df.rename(columns={'image_id': 'dataset_image_ids'})

    # Resolve images before writing anything so a missing image aborts without partial output.
    mapping = _map_image_paths(query_df['dataset_image_ids'], split)
    _write_sections(query_df, split, seed=seed, train_fraction=train_fraction)
    return mapping

# We merge infoseek_{split}.jsonl with infoseek_{split}_withkb.jsonl so that the output CSVs
# (train_section.csv, val_section.csv, test_section.csv) contain a wikipedia_url column,
# resembling the encyclopedic_vqa query dataset.
if __name__ == '__main__':
    # Build the official train split (divided into train/val sections) first, then the
    # official val split (used as the test section). When both splits contain the same
    # image id, the val entry overrides the train entry.
    dataset_id_to_path = {
        **process_dataset(split='train'),
        **process_dataset(split='val'),
    }

    output_path = os.path.join(base_path, 'dataset_id_to_path.json')
    with open(output_path, 'w') as out_file:
        json.dump(dataset_id_to_path, out_file)

    print("Done!")