import json
import os
import threading
import traceback
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import datasets
import wikipediaapi
from SPARQLWrapper import JSON, SPARQLWrapper
from tqdm import tqdm


def _count_words(text):
    """Return the number of whitespace-separated words in *text* (0 for empty/None)."""
    # str.split() with no argument splits on any run of whitespace (spaces,
    # newlines, tabs) and drops empty strings. split(" ") would merge words
    # separated by newlines and count empty strings between repeated spaces.
    return len((text or "").split())


def _unique_titled_bindings(bindings):
    """Return ``(binding, wiki_title)`` pairs, keeping the first binding per title.

    Bindings without a ``wikiTitle`` value are reported and skipped. Deduping
    sequentially (rather than inside worker threads) makes the choice of which
    duplicate is kept deterministic: always the earliest one in *bindings*.
    """
    seen_titles = set()
    selected = []
    for result in bindings:
        wiki_title = result.get("wikiTitle", {}).get("value")
        if not wiki_title:
            item_label = result.get("itemLabel", {}).get("value", "")
            print(f"No Wikipedia sitelink found for: {item_label}")
            continue
        if wiki_title in seen_titles:
            continue
        seen_titles.add(wiki_title)
        selected.append((result, wiki_title))
    return selected


def _fetch_sample(result, wiki_title, wiki_wiki):
    """Fetch the Wikipedia summary for one SPARQL binding.

    Returns a dict with keys ``title``, ``content``, ``gender``,
    ``birth_date`` and ``occupations``, or ``None`` when the page does not
    exist or fetching it raised.
    """
    try:
        page_py = wiki_wiki.page(wiki_title)
        if not page_py.exists():
            print(f"Page not found for Wikipedia title: {wiki_title}")
            return None
        return {
            "title": result.get("itemLabel", {}).get("value", ""),
            "content": page_py.summary or "",
            "gender": result.get("genderLabel", {}).get("value", None),
            "birth_date": result.get("birthDate", {}).get("value", None),
            "occupations": result.get("occupations", {}).get("value", None),
        }
    except Exception:
        # Best effort: a single failing page must not abort the whole build.
        print(
            f"Exception processing result {result.get('itemLabel', {}).get('value', '<no label>')}:"
        )
        traceback.print_exc()
        return None


def build_wikipedia_bio_dataset(
    results,
    wiki_wiki,
    data_dir="data_dir",
    base_name="wikipedia_bio",
    min_content_length=50,
    max_content_length=1000,
    max_workers=15,
):
    """
    Parallel fetch of Wikipedia summaries, creation and filtering of HF dataset.

    Saves the intermediate raw dataset to ``{data_dir}/{base_name}_raw`` and the
    filtered one to ``{data_dir}/{base_name}``. Both are reused on later calls:
    if the filtered dataset exists it is loaded and returned unchanged; if only
    the raw one exists it is loaded and re-filtered.

    Args:
        results: Decoded SPARQL JSON response (``{"results": {"bindings": [...]}}``);
            ``None`` is treated as having no bindings.
        wiki_wiki: Wikipedia client exposing ``page(title)`` whose result has
            ``exists()`` and ``summary``.
        data_dir: Directory holding the saved datasets.
        base_name: Base directory name for the saved datasets.
        min_content_length: Minimum number of whitespace-separated words in the
            summary for a row to be kept (inclusive).
        max_content_length: Maximum number of words (inclusive).
        max_workers: Number of threads used to fetch pages.

    Returns:
        The filtered ``datasets.Dataset``. Rows keep the order of the SPARQL
        bindings.
    """
    data_path = Path(data_dir)
    raw_save_path = data_path / f"{base_name}_raw"
    final_save_path = data_path / base_name

    # If final filtered dataset exists, just load and return it.
    if final_save_path.exists():
        print(f"Filtered dataset already exists at {final_save_path}, loading.")
        return datasets.load_from_disk(str(final_save_path))

    # If raw exists but final doesn't, load raw and apply filtering.
    if raw_save_path.exists():
        print(f"Raw dataset found at {raw_save_path}, loading and filtering.")
        hf_dataset = datasets.load_from_disk(str(raw_save_path))
    else:
        print("Building raw dataset in parallel...")
        bindings = (results or {}).get("results", {}).get("bindings", [])
        work_items = _unique_titled_bindings(bindings)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(_fetch_sample, result, wiki_title, wiki_wiki)
                for result, wiki_title in work_items
            ]
            for _ in tqdm(
                as_completed(futures), total=len(futures), desc="Processing pages"
            ):
                pass  # drive the progress bar

        # Collect in submission order (not completion order) so the saved
        # dataset's row order is reproducible across runs.
        samples = [f.result() for f in futures]
        samples = [s for s in samples if s is not None]
        columns = ("title", "content", "gender", "birth_date", "occupations")
        dataset_samples = {col: [s[col] for s in samples] for col in columns}

        hf_dataset = datasets.Dataset.from_dict(dataset_samples)
        raw_save_path.parent.mkdir(parents=True, exist_ok=True)
        hf_dataset.save_to_disk(str(raw_save_path))
        print(f"Raw HuggingFace dataset created and saved to {raw_save_path}.")

    print("Applying content length filter...")

    def length_filter(example):
        word_count = _count_words(example["content"])
        return min_content_length <= word_count <= max_content_length

    hf_dataset = hf_dataset.filter(length_filter)
    hf_dataset.save_to_disk(str(final_save_path))
    print(f"Filtered dataset saved to {final_save_path}.")
    return hf_dataset


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--min_content_length", default=80, type=int)
    parser.add_argument("--max_content_length", default=350, type=int)
    # Previously hard-coded; defaults keep the original behavior.
    parser.add_argument("--max_workers", default=15, type=int)
    parser.add_argument("--data_dir", default="data_dir")
    parser.add_argument("--results_file", default="wikidata_results.json")
    args = parser.parse_args()

    # SPARQL query to get the list of human individuals with Wikipedia sitelinks
    sparql_query = """
    PREFIX schema: <http://schema.org/>
    SELECT DISTINCT ?item ?itemLabel ?birthDate ?gender ?genderLabel 
                (GROUP_CONCAT(DISTINCT ?occupationLabel; separator=", ") AS ?occupations) 
                ?wikiTitle
    WHERE {
    ?featuredArticle schema:about ?item.
    ?featuredArticle schema:inLanguage "en".
    ?featuredArticle wikibase:badge ?badge.
    ?item wdt:P31 wd:Q5;  # Instance of human
            wdt:P569 ?birthDate;  # Birth date
            wdt:P21 ?gender.  # Gender
    VALUES (?badge) {(wd:Q17437796)(wd:Q17437798)}
    OPTIONAL {?featuredArticle schema:about ?item;
                                schema:inLanguage "en";
                                schema:name ?wikiTitle.}
    OPTIONAL {?item wdt:P106 ?occupation.
                ?occupation rdfs:label ?occupationLabel.
                FILTER(LANG(?occupationLabel) = "en")}
    SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }
    GROUP BY ?item ?itemLabel ?birthDate ?gender ?genderLabel ?wikiTitle
    ORDER BY ?itemLabel
    """

    sparql_endpoint = "https://query.wikidata.org/sparql"
    # Identify this client on both the Wikidata and Wikipedia requests; the
    # same agent string is passed to SPARQLWrapper and wikipediaapi below.
    user_agent = "BiographyDataExtraction/1.0"

    def execute_sparql_query(query):
        """Run *query* against the Wikidata endpoint and return the decoded JSON."""
        sparql = SPARQLWrapper(sparql_endpoint, agent=user_agent)
        sparql.setQuery(query)
        sparql.setReturnFormat(JSON)
        return sparql.query().convert()

    def save_results_to_json(results, filename="wikidata_results.json"):
        """Write the SPARQL response to *filename* as JSON."""
        with open(filename, "w", encoding="utf-8") as json_file:
            json.dump(results, json_file)

    def load_results_from_json(filename="wikidata_results.json"):
        """Return the cached SPARQL response from *filename*, or None if absent."""
        if not os.path.exists(filename):
            return None
        with open(filename, "r", encoding="utf-8") as json_file:
            return json.load(json_file)

    # Initialize Wikipedia API with a proper user agent
    wiki_wiki = wikipediaapi.Wikipedia(user_agent=user_agent)

    # Reuse cached SPARQL results when available; otherwise query and cache them.
    results = load_results_from_json(args.results_file)
    if results is None:
        results = execute_sparql_query(sparql_query)
        save_results_to_json(results, args.results_file)

    # Fetch Wikipedia summaries, build the dataset and apply the length filter.
    hf_dataset = build_wikipedia_bio_dataset(
        results,
        wiki_wiki,
        data_dir=args.data_dir,
        base_name="wikipedia_bio",
        min_content_length=args.min_content_length,
        max_content_length=args.max_content_length,
        max_workers=args.max_workers,
    )
    print(f"Final dataset contains {len(hf_dataset)} examples.")
