import os
# 1️⃣ Pick an absolute path that has enough space.
# The location is resolved to an absolute path up front so the cache
# directories do not move if the working directory changes later on.
BASE = os.path.abspath("./")

# 2️⃣ Point both caches there ─ before any HF import
os.environ["HF_HOME"] = BASE  # makes <BASE>/hub and <BASE>/datasets
os.environ["HF_HUB_CACHE"] = os.path.join(BASE, "hub")  # optional, explicit
os.environ["HF_DATASETS_CACHE"] = os.path.join(BASE, "datasets")

import json
import requests
import unicodedata
import time
import random
import numpy as np

import math
from p_tqdm import p_map
from tqdm import trange




# SPARQL endpoint URL that every query in this file is sent to.
WDQS_ENDPOINT = "https://query.wikidata.org/sparql"

# -------------------------------
# HTTP helpers (robust + polite)
# -------------------------------

def make_session(user_agent="YourAppName/1.0 (you@example.com)"):
    """
    Create a `requests.Session` whose default headers identify the client
    (`user_agent`), request SPARQL JSON results, and accept gzip/deflate
    compressed responses.
    """
    default_headers = {
        "User-Agent": user_agent,
        "Accept": "application/sparql-results+json",
        "Accept-Encoding": "gzip, deflate",
    }
    session = requests.Session()
    session.headers.update(default_headers)
    return session

def request_json_with_backoff(session, url, *, params=None, max_retries=6, base_sleep=0.5):
    """
    GET `url` through `session` and return the decoded JSON body, retrying
    throttled / transient responses with backoff.

    Behavior:
      - 429/502/503/504 are retried up to `max_retries` attempts in total.
      - A numeric, non-negative, finite Retry-After header is honored;
        otherwise the wait is `base_sleep * 2**(attempt-1)`. A small random
        jitter is always added.
      - No sleep happens after the final attempt (nothing would follow it).
      - Other non-2xx statuses raise via `resp.raise_for_status()`.
      - The Content-Type is checked before calling `resp.json()`.

    Raises:
      RuntimeError: if the response is not JSON, or if every attempt hit a
        retryable status.
      Whatever `session.get` / `raise_for_status` raise is propagated as-is.
    """
    retryable_statuses = (429, 502, 503, 504)
    last_resp = None

    for attempt in range(1, max_retries + 1):
        resp = session.get(url, params=params, timeout=60)
        last_resp = resp

        if resp.status_code in retryable_statuses:
            if attempt == max_retries:
                # Out of attempts: fail right away instead of sleeping first.
                break

            wait = base_sleep * (2 ** (attempt - 1))
            retry_after = resp.headers.get("Retry-After")
            if retry_after:
                try:
                    requested = float(retry_after)
                except ValueError:
                    # e.g. the HTTP-date form; keep the exponential default.
                    requested = None
                # The chained comparison also rejects NaN and infinity, which
                # time.sleep() cannot handle (as it cannot handle negatives).
                if requested is not None and 0.0 <= requested < float("inf"):
                    wait = requested
            wait += random.uniform(0.0, 0.5)  # jitter
            time.sleep(wait)
            continue

        # Non-2xx
        resp.raise_for_status()

        # Sanity check on content-type
        ctype = resp.headers.get("Content-Type", "")
        if "json" not in ctype.lower():
            snippet = resp.text[:200]
            raise RuntimeError(f"Expected JSON but got Content-Type={ctype}. Snippet: {snippet!r}")

        return resp.json()

    # Exhausted retries
    snippet = ""
    try:
        snippet = last_resp.text[:200] if last_resp is not None else ""
    except Exception:
        # Reading the body is best-effort; the error below is raised anyway.
        pass
    raise RuntimeError(f"Failed after {max_retries} attempts. Last status={getattr(last_resp,'status_code',None)}. Snippet: {snippet!r}")


# -------------------------------
# Query builders + callers
# -------------------------------

def _build_query(qid: str, include_subclasses: bool, limit: int, offset: int, language: str) -> str:
    """
    Build the paginated SPARQL query listing items that are instances of
    wd:<qid>. With `include_subclasses`, instances of any (transitive)
    subclass are matched as well. `limit` and `offset` are coerced with int().
    """
    # P31 = instance of; P279 = subclass of
    if include_subclasses:
        path = "wdt:P31/wdt:P279*"
    else:
        path = "wdt:P31"
    page_size = int(limit)
    page_start = int(offset)
    return f"""
    SELECT ?item ?itemLabel ?qid WHERE {{
      ?item {path} wd:{qid} .
      BIND(STRAFTER(STR(?item), "entity/") AS ?qid)
      SERVICE wikibase:label {{ bd:serviceParam wikibase:language "{language},[AUTO_LANGUAGE]". }}
    }}
    LIMIT {page_size}
    OFFSET {page_start}
    """

def get_category_items(
    qid: str,
    *,
    step: int = 300,
    include_subclasses: bool = False,
    language: str = "en",
    user_agent: str = "YourAppName/1.0 (contact@example.com)",
    max_retries: int = 5,
    min_delay: float = 0.25,
    max_delay: float = 1.0,
):
    """
    Collect all items that are `instance of` wd:<qid>, requesting pages of
    size `step` until a page comes back empty.

    Each page is attempted up to `max_retries` times with exponential backoff
    (plus jitter) between attempts; a random pause between `min_delay` and
    `max_delay` seconds follows every successful, non-empty page.

    Returns:
      A list of dicts with keys "Q_number", "label" and "uri".

    Raises:
      ValueError: if `max_retries` is smaller than 1 (no request could ever
        be sent, so paging could never terminate).
      The last request/parsing error once a page has failed `max_retries`
      times.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    all_rows = []
    offset = 0

    # The session is used as a context manager so its pooled connections are
    # released when paging finishes or an error escapes.
    with make_session(user_agent=user_agent) as session:
        while True:
            query = _build_query(qid, include_subclasses, step, offset, language)

            # retry loop (wrapped around our backoff helper)
            for attempt in range(1, max_retries + 1):
                try:
                    data = request_json_with_backoff(
                        session,
                        WDQS_ENDPOINT,
                        params={"query": query},
                        max_retries=6,
                        base_sleep=0.5,
                    )
                    rows = [
                        {
                            "Q_number": b["qid"]["value"],
                            "label": b["itemLabel"]["value"],
                            "uri": b["item"]["value"],
                        }
                        for b in data["results"]["bindings"]
                    ]
                    break  # success; exit retry loop
                except (requests.exceptions.RequestException, ValueError, RuntimeError):
                    if attempt == max_retries:
                        raise
                    # exponential backoff with jitter
                    time.sleep((2 ** (attempt - 1)) + random.uniform(0.0, 0.5))

            # If this page is empty, we are done
            if not rows:
                return all_rows

            all_rows.extend(rows)
            offset += step

            # brief polite delay between pages
            time.sleep(random.uniform(min_delay, max_delay))

def _binding_to_value(binding):
    """Convert one SPARQL result row into the value dict stored under a claim."""
    # ?valueType comes from a projected expression and may be unbound if that
    # expression fails to evaluate; treat such rows as untyped literals
    # instead of failing the whole item.
    value_type = binding.get("valueType", {}).get("value", "")
    if value_type == "wikibase-item":
        return {
            "type": "wikibase-item",
            "Q_number": binding.get("valueQid", {}).get("value", ""),
            "label": binding.get("valueLabel", {}).get("value", ""),
            "uri": binding["value"]["value"],
        }
    # Literal (strings, numbers, times, coords, etc.)
    return {
        "type": "literal",
        "datatype": value_type,
        "value": binding["value"]["value"],
    }


def _accumulate_claims(claims, bindings, label_key, only_entity_values):
    """Group the rows in `bindings` by property id and add them to `claims` in place."""
    for binding in bindings:
        pid = binding["pid"]["value"]  # e.g., "P571"
        value_obj = _binding_to_value(binding)
        if only_entity_values and value_obj["type"] != "wikibase-item":
            continue
        entry = claims.get(pid)
        if entry is None:
            # The property label is taken from the first row seen for a pid.
            prop_label = binding.get(label_key, {}).get("value", pid)
            claims[pid] = {"property_label": prop_label, "values": [value_obj]}
        else:
            entry["values"].append(value_obj)


def get_item_properties(
    qid: str,
    *,
    language: str = "en",
    only_entity_values: bool = False,
    user_agent: str = "YourAppName/1.0 (you@example.com)"
):
    """
    Return all property values for a Wikidata item using the statement graph
    (p:/ps:), falling back to direct (wdt:) properties when the statement
    query returns no rows.

    Args:
      qid: item id such as "Q42".
      language: preferred label language passed to the label service.
      only_entity_values: when True, keep only values typed "wikibase-item"
        and drop literals. The default (False) keeps everything.
      user_agent: User-Agent header for the requests.

    Returns:
      {"QID": qid, "label": <item label or qid>, "claims": {pid: {
          "property_label": str,
          "values": [
              {"type": "wikibase-item", "Q_number", "label", "uri"} |
              {"type": "literal", "datatype", "value"}
          ]}}}

    Raises:
      Whatever `request_json_with_backoff` raises; KeyError if the response
      lacks the expected result structure.
    """
    query = f"""
    SELECT ?item ?itemLabel
           ?wdProp ?wdPropLabel
           (STRAFTER(STR(?wdProp), "/entity/") AS ?pid)
           ?value ?valueLabel
           (IF(isIRI(?value), "wikibase-item", DATATYPE(?value)) AS ?valueType)
           (IF(isIRI(?value), STRAFTER(STR(?value), "/entity/"), "") AS ?valueQid)
           ( ?wdPropLabel AS ?propertyLabel )
    WHERE {{
      VALUES ?item {{ wd:{qid} }}
      ?item ?p ?statement .
      ?wdProp wikibase:claim ?p .
      ?wdProp wikibase:statementProperty ?ps .
      ?statement ?ps ?value .
      SERVICE wikibase:label {{
        bd:serviceParam wikibase:language "{language},[AUTO_LANGUAGE]" .
      }}
    }}
    """

    # Only sent when the statement-graph query yields no rows.
    fallback_query = f'''
        SELECT ?prop ?propLabel ?value ?valueLabel
            (STRAFTER(STR(?prop), "/prop/direct/") AS ?pid)
            (IF(isIRI(?value), "wikibase-item", DATATYPE(?value)) AS ?valueType)
            (IF(isIRI(?value), STRAFTER(STR(?value), "/entity/"), "") AS ?valueQid)
        WHERE {{
        VALUES ?item {{ wd:{qid} }}
        ?item ?prop ?value .
        FILTER(STRSTARTS(STR(?prop), "http://www.wikidata.org/prop/direct/"))
        SERVICE wikibase:label {{ bd:serviceParam wikibase:language "{language},[AUTO_LANGUAGE]" . }}
        }}
        '''

    claims = {}

    # Context manager so the session's connections are released on exit.
    with make_session(user_agent=user_agent) as session:
        data = request_json_with_backoff(
            session,
            WDQS_ENDPOINT,
            params={"query": query},
            max_retries=6,
            base_sleep=0.5,
        )
        bindings = data["results"]["bindings"]

        # Item label (same for all rows): take it from the first row that has one.
        item_label = next(
            (b["itemLabel"]["value"] for b in bindings if "itemLabel" in b),
            None,
        )
        _accumulate_claims(claims, bindings, "propertyLabel", only_entity_values)

        # Decide on the raw rows rather than on `claims`, so that filtering
        # with only_entity_values does not trigger a pointless second query.
        if not bindings:
            fallback_data = request_json_with_backoff(
                session, WDQS_ENDPOINT,
                params={"query": fallback_query},
                max_retries=6, base_sleep=0.5,
            )
            _accumulate_claims(
                claims,
                fallback_data["results"]["bindings"],
                "propLabel",
                only_entity_values,
            )

    return {"QID": qid, "label": item_label or qid, "claims": claims}


# pip install p_tqdm if you haven't already
from p_tqdm import p_map
import random, time

# --- Tunables (be polite to WDQS!) ---
# NOTE(review): the guidance on the next line says 2–4 workers is polite, yet
# the value is 8 — lower it if requests start failing with HTTP 429.
PARALLELISM = 8        # 2–4 is a good citizen; raise carefully
PER_CALL_SLEEP = (0.15, 0.45)  # small jitter per process between calls
# NOTE: The new wbgetentities-batched path is preferred and ignores these.
all_qs = []
total_all_items = []
all_num = 0 
M = 20

# Labels containing any of these (case-insensitively) are project-internal
# pages rather than real-world entities and are skipped.
_EXCLUDED_LABEL_TERMS = ("wikidata", "wikipedia", "wikimedia")

# Load candidate items, keeping only those with a usable human-readable label.
# Malformed lines are skipped (like the other JSONL loaders in this file) so a
# single bad record cannot abort the whole load; `all_num` still counts them.
with open(f"./1_Data_Gathering/temp_data/1_wikidata_random_labels_{M}M.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        all_num += 1
        try:
            item = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(item, dict) or not isinstance(item.get("label"), str):
            continue
        label = item["label"].strip()
        lowered = label.lower()
        if (
            label
            # a bare QID such as "Q12345" means the item has no real label
            and not (label.startswith("Q") and label[1:].isdigit())
            and not label.isdigit()
            and not any(term in lowered for term in _EXCLUDED_LABEL_TERMS)
        ):
            total_all_items.append(item)
def _compute_related_tags_for_item(item,
                                   language="en",
                                   user_agent="YourAppName/1.0 (you@example.com)"):
    """
    Parallel worker: look up the item's properties and return a copy of
    `item` extended with 'related_tags' (labels of its entity-valued claims).

    The input dict is never mutated. On failure the copy carries an empty
    'related_tags' list and an '_error' string instead of raising.
    """
    # Items reach this function with differently spelled id keys; take the
    # first non-empty one.
    qid = None
    for id_key in ("Q_number", "qid", "QID"):
        qid = item.get(id_key)
        if qid:
            break
    if not qid:
        return {**item, "related_tags": [], "_error": "missing_qid"}

    # Jittered pause so workers do not hit the endpoint in lockstep
    time.sleep(random.uniform(*PER_CALL_SLEEP))

    try:
        props = get_item_properties(
            qid,
            language=language,
            only_entity_values=False,
            user_agent=user_agent
        )

        own_label = item.get("label", "")
        related = []
        for prop_data in props.get("claims", {}).values():
            for value in prop_data.get("values", []):
                # Literals carry no entity label worth tagging
                if value.get("type") != "wikibase-item":
                    continue
                has_qid = bool(value.get("Q_number", "").strip())
                value_label = value.get("label", "")
                if not has_qid:
                    continue
                # Drop values whose label contains the item's own label
                if own_label and own_label in value_label:
                    continue
                related.append(value_label)

        # fresh dict; the caller's item stays untouched
        return {**item, "related_tags": related}

    except Exception as exc:
        # best-effort: one failing item must not stop the batch
        return {**item, "related_tags": [], "_error": f"{type(exc).__name__}: {exc}"}



import glob
from tqdm import tqdm

def _canon_label(s: str) -> str:
    """
    Return the canonical form of a label used for membership checks:
    NFKC-normalized, case-folded, with surrounding whitespace removed.
    Anything that is not a string maps to "".
    """
    if isinstance(s, str):
        normalized = unicodedata.normalize("NFKC", s)
        folded = normalized.casefold()
        return folded.strip()
    return ""


def merge_items_with_tags(all_items, base_path):
    """
    Build a set of previously-seen labels (canonicalized) from the JSONL files
    under `base_path`, write the unique previous items to
    `<base_path>/all_prev_items.jsonl`, and return the entries of `all_items`
    whose canonical label has not been seen.

    The combined file is itself one of the scanned inputs, so it is written to
    a temporary sibling and swapped in at the end. This keeps items archived by
    an earlier run (instead of truncating them before they can be read) and
    avoids reading a file while it is being written.

    Uses a set (O(1) membership) and streams the combined file to avoid holding
    all previous items in memory.
    """
    os.makedirs(base_path, exist_ok=True)

    save_path = f"{base_path}/all_prev_items.jsonl"
    tmp_path = f"{save_path}.tmp"  # does not match "*.jsonl", so never scanned

    # Sorted so that which duplicate gets kept does not depend on directory order.
    all_files = sorted(glob.glob(os.path.join(base_path, "*.jsonl")))

    seen_labels = set()
    with open(tmp_path, "w", encoding="utf-8") as out_f:
        for file in tqdm(all_files, desc="Merging previous items"):
            try:
                with open(file, "r", encoding="utf-8") as f:
                    for line in f:
                        try:
                            obj = json.loads(line)
                        except ValueError:
                            # malformed / blank line
                            continue
                        if not isinstance(obj, dict):
                            continue
                        lbl = _canon_label(obj.get("label", ""))
                        if not lbl or lbl in seen_labels:
                            continue
                        seen_labels.add(lbl)
                        out_f.write(json.dumps(obj) + "\n")
            except FileNotFoundError:
                # If a file got moved/deleted during the scan, skip it
                continue

    # Swap the finished file into place in one step.
    os.replace(tmp_path, save_path)

    print(f"Saved {len(seen_labels)} unique previous items -> {save_path}")
    print(f"Total previous unique labels: {len(seen_labels)}")

    # Fast O(1) membership using the canonicalized label
    filtered = [it for it in tqdm(all_items, desc="Filtering new items")
                if _canon_label(it.get("label", "")) not in seen_labels]
    return filtered

# Compute splits *after* filtering/merging and use the filtered list
# NOTE(review): this runs at import time and starts worker processes; with a
# spawn-style start method workers would re-run module top-level code —
# consider an `if __name__ == "__main__":` guard. TODO confirm for p_tqdm.
base_path = "./1_Data_Gathering/temp_data/2_tag_only"
print(f"Merging items with tags from {base_path}")
print(f"Total items: {len(total_all_items)}")
all_items = merge_items_with_tags(total_all_items, base_path)
print(f"Total items after merging: {len(all_items)}")
print(f"Total items to process: {len(all_items)}")

split_size = 100000
# Ceiling division: the final, partially filled split must be processed too
# (floor division skipped it, and skipped everything below one full split).
num_splits = (len(all_items) + split_size - 1) // split_size

os.makedirs(base_path, exist_ok=True)

for i in trange(num_splits, desc="Processing splits"):

    start_idx = i * split_size
    end_idx = min(start_idx + split_size, len(all_items))
    # NOTE(review): indices restart at 0 on every run because `all_items` is
    # re-filtered, so a later run overwrites split files with the same name.
    out_put_path = f"{base_path}/items_with_tags_{start_idx}_{end_idx}.jsonl"

    print(f"[2_GET_TAGS] Processing items from index {start_idx} to {end_idx}")
    chunk = all_items[start_idx:end_idx]
    print(len(chunk))

    all_items_with_tags = p_map(
        _compute_related_tags_for_item,
        chunk,
        num_cpus=PARALLELISM,
        desc="Fetching item properties (WDQS)"
    )

    failures = [it for it in all_items_with_tags if it.get("_error")]
    if failures:
        print(f"{len(failures)} items failed; example:", failures[0].get("_error"))

    # Every processed item (failed ones included) goes into the split file...
    with open(out_put_path, "w", encoding="utf-8") as f:
        for item in all_items_with_tags:
            f.write(json.dumps(item) + "\n")

    # ...and failures are additionally listed on their own for inspection.
    if failures:
        with open(f"{base_path}/failed_items_{start_idx}_{end_idx}.jsonl", "w", encoding="utf-8") as f:
            for item in failures:
                f.write(json.dumps(item) + "\n")


# ---------------------------------------------
# Process landmarks_low_freq and landmarks_high_freq
# ---------------------------------------------

# Characters that may not appear raw inside a double-quoted SPARQL string
# literal, mapped to their escape sequences.
_SPARQL_STRING_ESCAPES = str.maketrans({
    "\\": "\\\\",
    '"': '\\"',
    "\n": "\\n",
    "\r": "\\r",
})


def _escape_sparql_literal(value: str) -> str:
    """
    Escape `value` for embedding between double quotes in a SPARQL query.

    Backslashes, double quotes, line feeds and carriage returns are escaped.
    The translation is a single pass, so inserted backslashes are never
    escaped a second time.
    """
    return value.translate(_SPARQL_STRING_ESCAPES)


def resolve_qid_by_exact_label(label: str,
                               language: str = "en",
                               user_agent: str = "YourAppName/1.0 (you@example.com)"):
    """
    Resolve a Wikidata QID by exact label match in the given language.

    The query asks for a single row (LIMIT 1), so when several items share the
    label, whichever one the endpoint returns is used.

    Returns:
      The QID string of the matched item, or None if there is no match.

    Raises:
      Whatever `request_json_with_backoff` raises.
    """
    safe_label = _escape_sparql_literal(label)
    query = f"""
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT (STRAFTER(STR(?item), "/entity/") AS ?qid)
    WHERE {{
      ?item rdfs:label "{safe_label}"@{language} .
    }}
    LIMIT 1
    """

    # Context manager so the one-off session's connections are released.
    with make_session(user_agent=user_agent) as session:
        data = request_json_with_backoff(
            session,
            WDQS_ENDPOINT,
            params={"query": query},
            max_retries=6,
            base_sleep=1.0,
        )

    rows = data.get("results", {}).get("bindings", [])
    if not rows:
        return None
    return rows[0].get("qid", {}).get("value")


# ---------- landmarks_low_freq ----------
# Records in this file carry their QID ("itemQ"), so they go straight to the
# tag lookup. The file is then rewritten with 'related_tags' added; records
# without a label or QID are left out of the rewritten file.
low_freq_path = "./1_Data_Gathering/temp_data/2_landmarks_low_freq.jsonl"
if os.path.exists(low_freq_path):
    prepared_low = []
    with open(low_freq_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except Exception:
                continue
            if not isinstance(obj, dict):
                continue
            # `or ""` also covers explicit nulls, not just missing keys
            label = (obj.get("itemLabel") or "").strip()
            qid = (obj.get("itemQ") or "").strip()
            if not label or not qid:
                continue
            prepared_low.append({**obj, "label": label, "Q_number": qid})

    # Skip the pool entirely when there is nothing to do (as the high-freq
    # section does).
    low_with_tags = []
    if prepared_low:
        low_with_tags = p_map(
            _compute_related_tags_for_item,
            prepared_low,
            num_cpus=PARALLELISM,
            desc="Low freq: fetching item properties (WDQS)"
        )

    low_failures = [it for it in low_with_tags if it.get("_error")]
    if low_with_tags:
        # The input file is replaced by its enriched version. Write to a
        # temporary sibling first so an interrupted run cannot leave the
        # input truncated.
        low_tmp_path = low_freq_path + ".tmp"
        with open(low_tmp_path, "w", encoding="utf-8") as f:
            for item in low_with_tags:
                f.write(json.dumps(item) + "\n")
        os.replace(low_tmp_path, low_freq_path)
    if low_failures:
        with open("./1_Data_Gathering/temp_data/2_failed_landmarks_low_freq.jsonl", "w", encoding="utf-8") as f:
            for item in low_failures:
                f.write(json.dumps(item) + "\n")


# ---------- landmarks_high_freq ----------
# Records in this file only have a label, so each one is first resolved to a
# QID by exact label match and then sent to the tag lookup. The file is
# rewritten with the enriched records; unresolved ones go to the failure file.
high_freq_path = "./1_Data_Gathering/temp_data/2_landmarks_high_freq.jsonl"
if os.path.exists(high_freq_path):
    items_high = []
    with open(high_freq_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except Exception:
                continue
            if not isinstance(obj, dict):
                continue
            # `or ""` also covers explicit nulls, not just missing keys
            if not (obj.get("itemLabel") or "").strip():
                continue
            items_high.append(obj)

    resolved_high = []
    resolve_failures = []
    for obj in items_high:
        label = (obj.get("itemLabel") or "").strip()
        # small polite delay per lookup
        time.sleep(random.uniform(*PER_CALL_SLEEP))
        try:
            qid = resolve_qid_by_exact_label(label, language="en")
        except Exception as e:
            # Record why the lookup failed, so a network/endpoint error is
            # not mistaken for a label with no match.
            resolve_failures.append(
                {**obj, "_error": f"qid_lookup_failed: {type(e).__name__}: {e}"}
            )
            continue
        if qid:
            resolved_high.append({**obj, "label": label, "Q_number": qid})
        else:
            resolve_failures.append({**obj, "_error": "qid_not_found"})

    high_with_tags = []
    if resolved_high:
        high_with_tags = p_map(
            _compute_related_tags_for_item,
            resolved_high,
            num_cpus=PARALLELISM,
            desc="High freq: fetching item properties (WDQS)"
        )

    high_failures = resolve_failures + [it for it in high_with_tags if it.get("_error")]
    if high_with_tags:
        # Write to a temporary sibling first so an interrupted run cannot
        # leave the input file truncated.
        high_tmp_path = high_freq_path + ".tmp"
        with open(high_tmp_path, "w", encoding="utf-8") as f:
            for item in high_with_tags:
                f.write(json.dumps(item) + "\n")
        os.replace(high_tmp_path, high_freq_path)
    if high_failures:
        with open("./1_Data_Gathering/temp_data/2_failed_landmarks_high_freq.jsonl", "w", encoding="utf-8") as f:
            for item in high_failures:
                f.write(json.dumps(item) + "\n")

