import asyncio
import json
import logging
import math
import os
from collections import defaultdict
from textwrap import dedent
from typing import Any, Callable, Dict, List, Optional, Set

import json_repair
import networkx as nx
import numpy as np
from tqdm.asyncio import tqdm_asyncio

# Our RECON builds on publicly available LightRAG library
from lightrag.operate import _handle_single_relationship_extraction
from lightrag.utils import use_llm_func_with_cache

async def _retry(
    coro_factory: Callable[[], "asyncio.Future[Any]"],
    *,
    retries: int = 3,
    base_delay: float = 0.5,
    retry_on: "tuple[type[BaseException], ...]" = (TimeoutError, asyncio.TimeoutError),
) -> Any:
    """Await ``coro_factory()``, retrying on timeouts with exponential backoff.

    Args:
        coro_factory: Zero-argument callable that returns a *fresh* awaitable on
            every call (an awaitable cannot be awaited twice, hence a factory).
        retries: Number of additional attempts after the first one. Negative
            values are treated as 0, so the factory is always called at least once.
        base_delay: Seconds to sleep before the first retry; the delay doubles
            after every failed attempt (``base_delay * 2**attempt``).
        retry_on: Exception types that trigger a retry. The default covers both
            the builtin ``TimeoutError`` and ``asyncio.TimeoutError``, which are
            distinct classes before Python 3.11. Any other exception propagates
            immediately without retrying.

    Returns:
        The awaited result of ``coro_factory()``.

    Raises:
        The last ``retry_on`` exception once all attempts are exhausted, or any
        other exception raised by the awaitable.
    """
    attempts = max(0, retries) + 1
    for attempt in range(attempts):
        try:
            return await coro_factory()
        except retry_on:
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(base_delay * (2 ** attempt))

def _neighbors_set(graph, node: str) -> Set[str]:
    """Return the neighbours of ``node`` in ``graph`` as a set.

    Any exception raised while calling or iterating ``graph.neighbors(node)``
    (for example, an unknown node) yields an empty set instead of propagating.
    """
    collected: Set[str] = set()
    try:
        collected.update(graph.neighbors(node))
    except Exception:
        # Missing node or failing neighbour lookup: degrade to "no neighbours".
        return set()
    return collected

def _get_node_attr(graph, node: str, key: str, default: Optional[str] = None) -> Optional[str]:
    """Read attribute ``key`` of ``node`` from ``graph.nodes``, or ``default``.

    ``default`` is returned both when the attribute is absent and when looking
    up ``graph.nodes[node]`` (or calling ``.get`` on it) raises for any reason.
    """
    try:
        attrs = graph.nodes[node]
        value = attrs.get(key, default)
    except Exception:
        value = default
    return value

def filter_nodes(graph):
    """Select the nodes whose degree lies strictly below the knee of the degree curve.

    Returns:
        ``(threshold, nodes)`` — the knee value computed by
        ``knee_threshold_with_index`` over all node degrees, and the nodes
        (in ``graph.degree()`` iteration order) with degree ``< threshold``.
        When the knee helper falls back to the minimum degree (fewer than three
        nodes or all degrees equal), no node is strictly below it and the list
        is empty.
    """
    degree_map, _all_nodes = degree_table_unweighted_unique(graph)
    knee, _knee_idx, _sorted_degrees = knee_threshold_with_index(list(degree_map.values()))
    low_degree = []
    for name, degree in degree_map.items():
        if degree < knee:
            low_degree.append(name)
    return knee, low_degree

def degree_table_unweighted_unique(H: nx.Graph):
    """Return ``(degree_by_node, node_list)`` for graph ``H``.

    Degrees come from ``H.degree()`` called without a weight argument;
    ``node_list`` is ``H.nodes()`` materialised as a list.
    """
    node_list = list(H.nodes())
    degree_by_node = dict(H.degree())
    return degree_by_node, node_list

def knee_threshold_with_index(values):
    """Find the knee of the ascending-sorted value curve.

    The values are sorted ascending and min-max normalised to [0, 1] against an
    evenly spaced x-grid on [0, 1]. The knee is the sample with the largest
    perpendicular distance from the straight chord joining the first and last
    samples.

    Returns:
        ``(knee_value, knee_index, sorted_values)`` where ``sorted_values`` is a
        float ndarray. Degenerate inputs: empty -> ``(0.0, -1, empty array)``;
        fewer than three samples or an (approximately) constant curve ->
        ``(smallest value, 0, sorted array)``.
    """
    ys = np.sort(np.asarray(values, dtype=float))
    count = len(ys)
    if count == 0:
        return 0.0, -1, ys
    if count < 3 or np.allclose(ys.min(), ys.max()):
        return float(ys[0]), 0, ys

    lo = ys.min()
    hi = ys.max()
    grid = np.linspace(0.0, 1.0, count)
    scaled = (ys - lo) / (hi - lo + 1e-12)

    # Unit direction vector of the chord from the first to the last sample.
    start = np.array([grid[0], scaled[0]])
    end = np.array([grid[-1], scaled[-1]])
    direction = end - start
    direction /= (np.linalg.norm(direction) + 1e-12)

    # Distance to the chord = norm of each offset minus its projection onto it.
    offsets = np.stack([grid - start[0], scaled - start[1]], axis=1)
    along = (offsets @ direction)[:, None] * direction[None, :]
    distances = np.linalg.norm(offsets - along, axis=1)

    knee_idx = int(np.argmax(distances))
    return float(ys[knee_idx]), knee_idx, ys

def render_relation_prompt(query, source_entity, source_description, candidates):
    """Build the LLM prompt that asks which candidates are explicitly related to a source entity.

    Args:
        query: Free-text retrieval query, inserted verbatim under ``[QUERY]``.
        source_entity: Name of the SOURCE entity.
        source_description: Description of the SOURCE entity; ``None`` is
            rendered as an empty string.
        candidates: Iterable of dicts with optional keys ``"ent"`` (entity name)
            and ``"des"`` (description). Missing or ``None`` values are rendered
            as empty strings; other values are converted with ``str`` and stripped.

    Returns:
        The fully formatted prompt string.
    """
    def build_candidates_block(candidates):
        parts = []
        for c in candidates:
            # ``or ""`` also covers keys that are present but None (e.g. a
            # candidate node without a "description" attribute); a plain
            # ``.get(key, "")`` would pass None to ``.strip()`` and crash.
            ent = str(c.get("ent") or "").strip()
            des = str(c.get("des") or "").strip()
            parts.append(f"entity: \"{ent}\"\ndescription: {des}")
        return "\n\n".join(parts)

    # The literal JSON braces in the schema are doubled ({{ }}) because the
    # template is filled with str.format below.
    RELATION_PROMPT_TEMPLATE = dedent("""\
        You are a multimodal relation linker. Decide whether the CANDIDATE entities are actually connected to the SOURCE entity.
        Use only the provided text; do not assume facts not stated here. If uncertain, do not link.

        DECISION PRINCIPLES (STRICT)
        - Evidence must be explicit in the provided descriptions. Accept only if there is a clear, stated relation (e.g., "X acquired Y", "X is a subsidiary of Y", "X collaborates with Y", "X authored Y", "image shows X with Y label").
        - Do NOT infer based on topical similarity, shared categories, overlapping domains, or generic roles (e.g., both are "tech companies").
        - Name overlap alone is insufficient (e.g., common words, acronyms, partial matches, aliases) unless the description explicitly asserts they are the same or related entities.
        - Geographic or temporal co-occurrence without an explicit link is insufficient.
        - When visual/textual grounding is relevant, require a direct textual reference to the visual element (e.g., a caption or label). Otherwise, treat as ungrounded.

        MINIMUM EVIDENCE TO LINK (must satisfy at least one)
        - Direct statement of the relationship between SOURCE and candidate.
        - Unambiguous identifier match (ID, URL, ticker, handle) explicitly tying them.
        - Explicit cross-mention (SOURCE mentioned in candidate description or vice versa) describing the relation.

        REJECTION TRIGGERS (any one → reject)
        - Only thematic similarity or category overlap is present.
        - Only partial or fuzzy name match, unresolved acronym, or likely homonym.
        - Relationship requires outside knowledge or unstated assumptions.
        - The text is ambiguous, speculative ("may", "might", "possibly"), or lacks a concrete verb indicating a relation.

        OUTPUT FILTER
        - Compute a confidence score (rs) from 0–10 based solely on the provided text.
        - Include a candidate in the output ONLY if rs ≥ 7 and the rationale can quote or paraphrase the specific evidence phrase(s).
        - If NO candidate fits, output an empty list: [].

        RATIONALE REQUIREMENTS
        - 1–2 sentences, referencing the exact supporting phrase(s) from the given descriptions using short quotes where possible.
        - If applicable, state whether the link is textual↔visual (e.g., "caption shows", "label reads").

        OUTPUT FORMAT (exact)
        - Output a JSON list. Each element is a dict with exactly:
        {{
            "type": "rel",
            "se": "<source entity>", // must exactly match the provided source entity name
            "te": "<target candidate entity>", // must exactly match the provided candidate entity name
            "rd": "<1–2 sentence rationale referencing provided descriptions; mention textual↔visual if relevant>",
            "rk": ["<2–5 short keywords>", "..."],
            "rs": <integer 0–10 confidence score>
        }}
        - No extra text, no markdown, no trailing commas, no additional fields.

        INPUT
        [QUERY]
        {QUERY}

        [SOURCE]
        entity: "{SOURCE_ENTITY}"
        description: {SOURCE_DESCRIPTION}

        [CANDIDATES]
        {CANDIDATES}
    """).strip()

    return RELATION_PROMPT_TEMPLATE.format(
        QUERY=query,
        SOURCE_ENTITY=source_entity,
        SOURCE_DESCRIPTION="" if source_description is None else source_description,
        CANDIDATES=build_candidates_block(candidates),
    )

def _normalize_json_list(obj: Any) -> List[Dict[str, Any]]:
    """
    Coerce whatever json_repair produced into a ``list[dict]``.

    - a single dict -> ``[dict]``
    - a list -> only its dict elements, order preserved
    - a str -> decoded with ``json.loads`` and normalised recursively;
      any failure while doing so yields ``[]``
    - ``None`` or any other type -> ``[]``
    """
    if isinstance(obj, dict):
        return [obj]
    if isinstance(obj, list):
        return list(filter(lambda item: isinstance(item, dict), obj))
    if not isinstance(obj, str):
        return []
    try:
        return _normalize_json_list(json.loads(obj))
    except Exception:
        return []

# main inter-page connection function
async def inter_page_connection(
    global_config: Dict[str, Any],
    chunk_entity_relation_graph,
    entities_vdb,
    llm_response_cache: Optional[Dict[str, Any]] = None,
    *,
    max_concurrency: int = 12,
    llm_retries: int = 3,
    top_k: int = 10
) -> "tuple[Dict[tuple, List[Dict[str, Any]]], List[Dict[str, Any]]]":
    """Propose new cross-page relationships for low-degree entities via vector search + LLM.

    Pipeline:
      1. ``filter_nodes`` selects nodes whose degree is below the knee threshold.
      2. For each such node, a text query (name, description, current
         neighbours) is sent to ``entities_vdb.query_many``.
      3. Hits that are the node itself or an existing neighbour are dropped and
         the rest trimmed to ``top_k`` candidate names.
      4. An LLM judges the candidates using ``render_relation_prompt``; its JSON
         answer is repaired, normalised and post-processed by
         ``_handle_single_relationship_extraction``.

    Args:
        global_config: Must contain ``"llm_model_func"``.
        chunk_entity_relation_graph: Storage object exposing ``await _get_graph()``.
        entities_vdb: Vector store exposing ``await query_many(queries=..., top_k=...)``;
            its results are zipped with the queried nodes, so presumably one list
            of hit dicts (with ``"entity_name"``) per query.
        llm_response_cache: Optional cache forwarded to ``use_llm_func_with_cache``.
        max_concurrency: Maximum simultaneous LLM calls (values < 1 are treated as 1).
        llm_retries: Extra attempts on LLM timeouts (passed to ``_retry``).
        top_k: Maximum number of candidates per node sent to the LLM.

    Returns:
        ``(maybe_edges, relations)``: ``relations`` is the flat list of processed
        relation dicts and ``maybe_edges`` groups them by ``(src_id, tgt_id)``.
        A node whose processing raises is logged and skipped instead of
        aborting the whole run.
    """
    logger = logging.getLogger(__name__)

    graph = await chunk_entity_relation_graph._get_graph()
    threshold, filtered_nodes = filter_nodes(graph)

    maybe_edges: dict[tuple[str, str], list] = defaultdict(list)
    if not filtered_nodes:
        # Nothing below the degree knee: skip the vector search and the LLM.
        return maybe_edges, []

    # 1) One retrieval query per low-degree node. The exact text matters: it is
    #    embedded for the vector search and also becomes part of the LLM prompt.
    queries: List[str] = []
    for node in filtered_nodes:
        description = _get_node_attr(graph, node, "description", "") or ""
        neighbors = list(_neighbors_set(graph, node))
        queries.append(
            f"Find entities associated with \"{node}\" (any relevant relation). "
            f"Context: {description} "
            f"Exclude the already linked entities: {neighbors}. "
        )

    # 2) Vector search. Filtered nodes have degree < threshold, so roughly that
    #    many hits can be existing neighbours; over-fetch by the threshold
    #    (rounded up to a multiple of 10) so about top_k novel names survive.
    top_kk = top_k + math.ceil(threshold / 10) * 10
    vdb_results: List[List[Dict[str, Any]]] = await entities_vdb.query_many(
        queries=queries, top_k=top_kk
    )

    # 3) Keep only novel entity names (not the node, not a current neighbour).
    filtered_vdb_names: List[List[str]] = []
    for node, datas in zip(filtered_nodes, vdb_results):
        known: Set[str] = _neighbors_set(graph, node) | {node}
        names: List[str] = []
        for d in datas or []:
            name = d.get("entity_name")
            if name and name not in known:
                names.append(name)
        filtered_vdb_names.append(names[:top_k])

    llm_model_func = global_config["llm_model_func"]
    # Semaphore(0) would block every task forever.
    sem = asyncio.Semaphore(max(1, max_concurrency))

    async def _process_one(node: str, candidate_names: List[str], query: str) -> List[Dict[str, Any]]:
        """Ask the LLM which candidates relate to ``node`` and post-process its answer."""
        description = _get_node_attr(graph, node, "description", "") or ""
        source_id_raw = _get_node_attr(graph, node, "source_id", "") or ""
        # Several source chunks are joined with "<SEP>"; attribute new edges to the first.
        source_id = source_id_raw.split("<SEP>")[0]
        file_path = _get_node_attr(graph, node, "file_path", None)

        # A None description means the candidate is missing from the graph or
        # has no "description" attribute. Skip it: the prompt requires textual
        # evidence, and there would be nothing to judge.
        candidates = []
        for tgt in candidate_names:
            des = _get_node_attr(graph, tgt, "description", None)
            if des is None:
                continue
            candidates.append({"ent": tgt, "des": des})

        # No usable candidates: skip the LLM call to save tokens.
        if not candidates:
            return []

        prompt = render_relation_prompt(
            query=query,
            source_entity=node,
            source_description=description,
            candidates=candidates,
        )

        async def _call():
            return await use_llm_func_with_cache(
                input_text=prompt,
                use_llm_func=llm_model_func,
                llm_response_cache=llm_response_cache,
                cache_type="cross_page_linking",
            )

        # The semaphore bounds concurrent LLM calls; retries (including their
        # backoff sleeps) happen while the slot is held.
        async with sem:
            raw = await _retry(_call, retries=llm_retries)

        # Repair/normalise the LLM's JSON; fall back to strict parsing.
        try:
            repaired = json_repair.repair_json(raw, return_objects=True)
        except Exception:
            try:
                repaired = json.loads(raw)
            except Exception:
                return []

        rel_objs = _normalize_json_list(repaired)
        if not rel_objs:
            return []

        # NOTE(review): each record is a dict in the prompt's schema
        # ("type", "se", "te", "rd", "rk", "rs"). This assumes the (RECON-adapted)
        # LightRAG helper accepts that shape — TODO confirm. Records that raise
        # or yield None are dropped below.
        async def _post(rr: Dict[str, Any]) -> Optional[Dict[str, Any]]:
            return await _handle_single_relationship_extraction(
                record_attributes=rr,
                chunk_key=source_id,
                file_path=file_path,
            )

        processed = await asyncio.gather(
            *[_post(rr) for rr in rel_objs],
            return_exceptions=True,
        )
        return [
            item
            for item in processed
            if item is not None and not isinstance(item, BaseException)
        ]

    async def _process_one_safe(node: str, candidate_names: List[str], query: str) -> List[Dict[str, Any]]:
        """Run ``_process_one``, turning a failure into an empty result.

        ``tqdm_asyncio.gather`` below is called without ``return_exceptions``,
        so an exception escaping one node (e.g. exhausted LLM retries) would
        otherwise abort the whole batch. CancelledError is not caught.
        """
        try:
            return await _process_one(node, candidate_names, query)
        except Exception as exc:
            logger.warning("inter_page_connection: skipping node %r: %s", node, exc)
            return []

    tasks = [
        asyncio.create_task(_process_one_safe(node, names, q))
        for node, names, q in zip(filtered_nodes, filtered_vdb_names, queries)
    ]
    per_node_results = await tqdm_asyncio.gather(*tasks, total=len(tasks))

    results_accum: List[Dict[str, Any]] = [
        rel for group in per_node_results for rel in group
    ]
    for rel in results_accum:
        maybe_edges[(rel["src_id"], rel["tgt_id"])].append(rel)
    return maybe_edges, results_accum
