from datasets import load_dataset, DatasetDict
import json
import os
import aiohttp
# Load the KILT Natural Questions and HotpotQA tasks, capped at the first
# 20,000 training and first 1,000 validation examples of each task.
# The KILT test splits are not loaded here.
def _load_kilt_task(task_name):
    """Return a DatasetDict with truncated train/validation splits of a KILT task."""
    split_slices = {
        "train": "train[:20000]",
        "validation": "validation[:1000]",
    }
    return DatasetDict(
        {
            split_name: load_dataset("kilt_tasks", task_name, split=slice_spec)
            for split_name, slice_spec in split_slices.items()
        }
    )


nq = _load_kilt_task("nq")
hotpot = _load_kilt_task("hotpotqa")

# Extract unique Wikipedia IDs from each dataset
def collect_wikipedia_ids(dataset_dict):
    """Return the sorted, de-duplicated Wikipedia IDs cited as provenance.

    Args:
        dataset_dict: Mapping of split name -> iterable of KILT-style example
            dicts (e.g. a ``DatasetDict``). An example may carry an
            ``"output"`` list whose entries may carry a ``"provenance"`` list
            of dicts holding a ``"wikipedia_id"``.

    Returns:
        A list of the distinct IDs, each converted to ``str``, sorted
        lexicographically (so ``"10"`` sorts before ``"9"``).

    Missing or ``None`` ``output``/``provenance`` fields are treated as empty,
    and falsy IDs (``None``, ``""``) are skipped.
    """
    wiki_ids = set()
    for split in dataset_dict.values():
        for example in split:
            # `or []` also covers fields that are present but set to None,
            # which `.get(key, [])` alone would not.
            for out in example.get("output") or []:
                for prov in out.get("provenance") or []:
                    wid = prov.get("wikipedia_id")
                    if wid:
                        # Normalise to str so int and str IDs de-duplicate.
                        wiki_ids.add(str(wid))
    return sorted(wiki_ids)

nq_ids = collect_wikipedia_ids(nq)
hotpot_ids = collect_wikipedia_ids(hotpot)

# Persist each sorted ID list as indented JSON under wiki_id_output/ so the
# Wikipedia filtering step can reload the IDs without re-scanning the QA data.
os.makedirs("wiki_id_output", exist_ok=True)
for json_path, id_list in (
    ("wiki_id_output/nq_wikipedia_ids.json", nq_ids),
    ("wiki_id_output/hotpot_wikipedia_ids.json", hotpot_ids),
):
    with open(json_path, "w") as f:
        json.dump(id_list, f, indent=2)

# Location of the KILT Wikipedia dataset passed to `load_dataset`
# (NOTE(review): presumably a local loading-script/data directory — confirm).
local_path = "./kilt_wikipedia"

# `storage_options` is forwarded to the fsspec file-system backend; the
# one-hour aiohttp session timeout only matters if loading fetches remote
# files (e.g. downloading a large dump).
wiki_storage_options = {
    "client_kwargs": {"timeout": aiohttp.ClientTimeout(total=3600)},
}
wiki = load_dataset(local_path, split="full", storage_options=wiki_storage_options)

# Reload the Wikipedia ID lists from wiki_id_output/ as sets for O(1)
# membership tests; reading from disk lets this step reuse ID files
# produced by an earlier run.
with open("wiki_id_output/nq_wikipedia_ids.json", "r", encoding="utf-8") as f:
    nq_ids = set(json.load(f))
with open("wiki_id_output/hotpot_wikipedia_ids.json", "r", encoding="utf-8") as f:
    hotpot_ids = set(json.load(f))

# Keep only the Wikipedia articles cited by each QA dataset.
# `input_columns="wikipedia_id"` makes `filter` decode and pass just that
# column instead of the full row (article text included), and `batched=True`
# evaluates the predicate on lists of IDs (default batch size 1,000) rather
# than once per row. The kept rows are identical to a row-wise
# `str(x["wikipedia_id"]) in ids` filter.
nq_filtered = wiki.filter(
    lambda wids: [str(wid) in nq_ids for wid in wids],
    input_columns="wikipedia_id",
    batched=True,
)
hotpot_filtered = wiki.filter(
    lambda wids: [str(wid) in hotpot_ids for wid in wids],
    input_columns="wikipedia_id",
    batched=True,
)

# Save filtered versions to disk
nq_filtered.save_to_disk("wiki_nq_subset")
hotpot_filtered.save_to_disk("wiki_hotpot_subset")
