import json
import asyncio
import json_repair

from tqdm import tqdm
from textwrap import dedent
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple, Optional

from _prompts import PROMPTS

# Our RECON builds on the publicly available LightRAG library.
from lightrag.base import (
    BaseKVStorage,
    TextChunkSchema,
)
from lightrag.utils import (
    clean_str,
    is_float_regex,
    normalize_extracted_info,
    update_chunk_cache_list,
)
from lightrag.utils import update_chunk_cache_list, use_llm_func_with_cache

async def _handle_single_entity_extraction(
    record_attributes: dict,
    chunk_key: str,
    file_path: str = "unknown_source",
):
    """Validate one extracted entity record and convert it to a node dict.

    Args:
        record_attributes: Parsed record; must contain the keys ``'en'``
            (name), ``'et'`` (type), ``'ed'`` (description) and ``'es'``
            (source), otherwise the record is ignored.
        chunk_key: Identifier of the chunk the record came from; stored as
            ``source_id`` in the result.
        file_path: Origin file, stored verbatim in the result.

    Returns:
        A dict with ``entity_name``, ``entity_type``, ``entity_source``,
        ``description``, ``source_id`` and ``file_path``, or ``None`` when the
        record is incomplete or any field is empty/invalid after cleaning.
    """
    required_keys = ('en', 'et', 'ed', 'es')
    if any(key not in record_attributes for key in required_keys):
        return None

    raw_name = record_attributes['en']
    raw_type = record_attributes['et']
    raw_description = record_attributes['ed']

    # LLM output may carry null/number/list values; the string operations
    # below would raise on them, so reject such records up front.
    if not all(isinstance(v, str) for v in (raw_name, raw_type, raw_description)):
        print(
            f"Entity extraction error: non-string field in: {record_attributes}"
        )
        return None

    entity_name = clean_str(raw_name).strip()
    if not entity_name:
        print(
            f"Entity extraction error: empty entity name in: {record_attributes}"
        )
        return None

    # Entity names are stored upper-cased after LightRAG normalization.
    entity_name = normalize_extracted_info(entity_name, is_entity=True).upper()
    if not entity_name or not entity_name.strip():
        # The original name lives under 'en'; indexing the dict with 1 would
        # raise KeyError instead of reporting the problem.
        print(
            f"Entity extraction error: entity name became empty after normalization. Original: '{raw_name}'"
        )
        return None

    entity_type = clean_str(raw_type).strip('"').upper()
    # A type starting with '("' indicates a malformed record leaked into
    # the type field.
    if not entity_type.strip() or entity_type.startswith('("'):
        print(
            f"Entity extraction error: invalid entity type in: {record_attributes}"
        )
        return None

    entity_description = normalize_extracted_info(clean_str(raw_description))
    entity_source = record_attributes['es']

    if not entity_description.strip():
        print(
            f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'"
        )
        return None

    return dict(
        entity_name=entity_name,
        entity_type=entity_type,
        entity_source=entity_source,
        description=entity_description,
        source_id=chunk_key,
        file_path=file_path,
    )

async def _handle_single_relationship_extraction(
    record_attributes: dict,
    chunk_key: str,
    file_path: str = "unknown_source",
):
    """Validate one extracted relationship record and convert it to an edge dict.

    Args:
        record_attributes: Parsed record; must contain the keys ``'se'``
            (source entity), ``'te'`` (target entity), ``'rd'`` (description),
            ``'rk'`` (keywords, a list of strings or a single string) and
            ``'rs'`` (strength), otherwise the record is ignored.
        chunk_key: Identifier of the chunk the record came from; stored as
            ``source_id`` in the result.
        file_path: Origin file, stored verbatim in the result.

    Returns:
        A dict with ``src_id``, ``tgt_id``, ``weight``, ``description``,
        ``keywords``, ``source_id`` and ``file_path``, or ``None`` when the
        record is incomplete, an endpoint is empty after normalization, or the
        relation is a self-loop.
    """
    required_keys = ('se', 'te', 'rd', 'rk', 'rs')
    if any(key not in record_attributes for key in required_keys):
        return None

    raw_source = record_attributes['se']
    raw_target = record_attributes['te']
    raw_description = record_attributes['rd']

    # LLM output may carry null/number/list values; the string operations
    # below would raise on them, so reject such records up front.
    if not all(isinstance(v, str) for v in (raw_source, raw_target, raw_description)):
        print(
            f"Relationship extraction error: non-string field in: {record_attributes}"
        )
        return None

    # Endpoints are stored upper-cased so they match entity names.
    source = normalize_extracted_info(clean_str(raw_source), is_entity=True).upper()
    target = normalize_extracted_info(clean_str(raw_target), is_entity=True).upper()

    if not source or not source.strip():
        print(
            f"Relationship extraction error: source entity became empty after normalization. Original: '{raw_source}'"
        )
        return None

    if not target or not target.strip():
        print(
            f"Relationship extraction error: target entity became empty after normalization. Original: '{raw_target}'"
        )
        return None

    if source == target:
        print(
            f"Relationship source and target are the same in: {record_attributes}"
        )
        return None

    edge_description = normalize_extracted_info(clean_str(raw_description))

    # Keywords normally arrive as a list, but a bare string must not be
    # joined character by character ("a,b" would become "a,,,b").
    raw_keywords = record_attributes['rk']
    if isinstance(raw_keywords, str):
        edge_keywords = raw_keywords
    elif isinstance(raw_keywords, (list, tuple, set)):
        edge_keywords = ",".join(str(keyword) for keyword in raw_keywords)
    elif raw_keywords is None:
        edge_keywords = ""
    else:
        edge_keywords = str(raw_keywords)
    edge_keywords = normalize_extracted_info(
        clean_str(edge_keywords), is_entity=True
    )
    # Normalize full-width (CJK) commas to ASCII commas.
    edge_keywords = edge_keywords.replace("，", ",")

    # Strength may come quoted ("7"); fall back to 1.0 when it is not numeric.
    raw_weight = str(record_attributes['rs']).strip('"').strip("'")
    weight = float(raw_weight) if is_float_regex(raw_weight) else 1.0

    return dict(
        src_id=source,
        tgt_id=target,
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=chunk_key,
        file_path=file_path,
    )

async def _handle_single_reflector_output(
    record_attributes: str,
    reflect_missing_nodes: list,
    extract_result_safe: dict, 
    max_entities: int = 10, 
    max_relations: int = 3,
):
    if isinstance(record_attributes, str):
        record_attributes = json_repair.repair_json(record_attributes, return_objects=True)
    
    if "potential_missing_entities" not in record_attributes:
        record_attributes["potential_missing_entities"] = []
    if "potential_missing_relations" not in record_attributes:
        record_attributes["potential_missing_relations"] = []
    if "done_with_this_page" not in record_attributes:
        record_attributes["done_with_this_page"] = {
            "reason": "n/a",
            "done_extract": False,
        }

    _record_attributes = {
        "potential_missing_entities": [],
        "potential_missing_relations": [],
        "done_with_this_page": {
            "reason": "n/a",
            "done_extract": False,
        }
    }
    for entity in record_attributes["potential_missing_entities"]:
        if isinstance(entity, str):
            _record_attributes["potential_missing_entities"].append(entity)
    _record_attributes["potential_missing_entities"] = list(
        set(_record_attributes["potential_missing_entities"])
    )[:max_entities]
    _record_attributes["potential_missing_entities"] += reflect_missing_nodes

    for relation in record_attributes["potential_missing_relations"]:
        if any([
            "entity" not in relation,
            "missing_relations" not in relation,
        ]):
            continue
        if not isinstance(relation["entity"], str):
            continue
        _missing_relations = []
        for missing_relation in relation["missing_relations"]:
            if isinstance(missing_relation, str):
                _missing_relations.append(missing_relation)
        _record_attributes["potential_missing_relations"].append(
            {
                "entity": relation["entity"],
                "missing_relations": list(set(_missing_relations))[:max_relations]
            }
        )

    if any([
        "reason" not in record_attributes["done_with_this_page"],
        "done_extract" not in record_attributes["done_with_this_page"],
    ]):
        ...
    else:
        reason = str(record_attributes["done_with_this_page"]["reason"])
        done_extract = record_attributes["done_with_this_page"]["done_extract"]
        _record_attributes["done_with_this_page"]["reason"] = reason
        _record_attributes["done_with_this_page"]["done_extract"] = done_extract

    previous_extracted_entities  = "\n".join([f'\'{res["entity_name"]}\'' for res in extract_result_safe["nodes"]])
    previous_extracted_relations = "\n".join([f'- {res["src"]} -> {res["tgt"]}' for res in extract_result_safe["edges"]])
    
    potential_missing_entities   = "\n".join([f'\'{res}\'' for res in list(set(_record_attributes["potential_missing_entities"]))]) if len(_record_attributes["potential_missing_entities"]) > 0 else "(None)"
    potential_missing_relations  = "\n".join([
        '- {ent}\n    {rel}'.format(
            ent=d['entity'],
            rel=d['missing_relations'],
        ) for d in _record_attributes['potential_missing_relations']
    ]) if 'potential_missing_relations' in _record_attributes else "(None)"

    return (
        _record_attributes, 
        previous_extracted_entities, 
        previous_extracted_relations, 
        potential_missing_entities,
        potential_missing_relations
    )

def results_to_jsonable(results):
    """Convert ``(nodes, edges)`` extraction pairs into plain JSON-friendly dicts.

    Each pair holds a mapping ``entity name -> list of entity dicts`` and a
    mapping ``(src, tgt) -> list of relation dicts``. The result contains one
    ``{"nodes": [...], "edges": [...]}`` entry per pair, where nodes are the
    flattened entity dicts and every edge dict gains explicit ``"src"`` and
    ``"tgt"`` fields (keys already present in the relation dict win).
    """
    def _flatten_nodes(node_map):
        # Grouping by name is dropped; only the entity dicts are kept.
        return [
            entity
            for group in dict(node_map).values()
            for entity in group
        ]

    def _flatten_edges(edge_map):
        # Tuple keys are not JSON-serializable, so spell them out as fields.
        return [
            {"src": src, "tgt": tgt, **relation}
            for (src, tgt), relations in dict(edge_map).items()
            for relation in relations
        ]

    return [
        {"nodes": _flatten_nodes(node_map), "edges": _flatten_edges(edge_map)}
        for node_map, edge_map in results
    ]

def construct_reflector_inputs(
    nodes: List[Dict[str, Any]],
    edges: List[Dict[str, Any]],
    # summary: str,
    entity_types: str,
    page_id: int,
    total_pages: int,
    *,
    max_relations_per_node: Optional[int] = None,
) -> Tuple[str, List[str]]:
    """Build the reflector prompt for one page and report dangling edge endpoints.

    Args:
        nodes: Extracted entities; each must carry a non-empty string
            ``"entity_name"``.
        edges: Extracted relations with string ``"src"``, ``"tgt"`` and
            ``"description"``; edges where any of these is not a string are
            skipped.
        entity_types: Value interpolated verbatim into the prompt's
            "Entity types" section.
        page_id: Page number shown in the prompt as "Now on Page".
        total_pages: Page count shown in the prompt.
        max_relations_per_node: Optional cap on relations listed per entity.

    Returns:
        ``(prompt, missing_nodes)`` where ``missing_nodes`` is the sorted list
        of edge endpoints that do not appear among ``nodes``. Sorting keeps
        the list, and any prompt derived from it, stable across runs.

    Raises:
        ValueError: If any node lacks a non-empty string ``"entity_name"``.
    """
    node_names: List[str] = []
    malformed_nodes: List[int] = []
    for index, node in enumerate(nodes):
        name = node.get("entity_name")
        if isinstance(name, str) and name.strip():
            node_names.append(name.strip())
        else:
            malformed_nodes.append(index)

    if malformed_nodes:
        raise ValueError(f"Malformed nodes (missing/empty 'entity_name') at indices: {malformed_nodes}")

    node_names_set = set(node_names)
    node_names_sorted = sorted(node_names_set)

    relations_by_node: Dict[str, List[Tuple[str, str, str]]] = defaultdict(list)
    missing_nodes: set[str] = set()

    for edge in edges:
        src = edge.get("src")
        tgt = edge.get("tgt")
        desc = edge.get("description")
        if not (isinstance(src, str) and isinstance(tgt, str) and isinstance(desc, str)):
            continue
        src, tgt, desc = src.strip(), tgt.strip(), desc.strip()

        src_known = src in node_names_set
        tgt_known = tgt in node_names_set
        # Blank endpoints carry no usable name, so they are not reported.
        if not src_known and src:
            missing_nodes.add(src)
        if not tgt_known and tgt:
            missing_nodes.add(tgt)

        # Only edges between two known entities are listed in the prompt.
        if src_known and tgt_known:
            triple = (src, tgt, desc)
            relations_by_node[src].append(triple)
            relations_by_node[tgt].append(triple)

    node_relations_lines: List[str] = []
    for node_name in node_names_sorted:
        node_relations_lines.append(f"Entity: {node_name}")
        unique_sorted = sorted(set(relations_by_node.get(node_name, [])))
        if max_relations_per_node is not None:
            unique_sorted = unique_sorted[:max_relations_per_node]
        if unique_sorted:
            node_relations_lines.extend(
                f"-> Src: {s} | Tgt: {t} | Rel: {d}" for s, t, d in unique_sorted
            )
        else:
            node_relations_lines.append("(missing relations)")

    node_relations_block = "\n".join(node_relations_lines) if node_relations_lines else "(No relations observed)"

    # NOTE: the prompt text below is kept byte-for-byte; it is part of the
    # LLM request and therefore of any prompt-based cache key.
    instruction_block = dedent(
        """
        You are an entity–relation reflector.
        Given a single primary page image, along with the entities, and relations extracted before,
        analyze entities and their existing relations, then propose missing entities and relation types. 

        Return your answer in **JSON only** (no extra text). Follow the JSON schema exactly:
        {
            "potential_missing_entities": string[],                
            "potential_missing_relations": [                       
                {
                    "entity": string,
                    "missing_relations": string[]                  
                }
            ],
            "done_with_this_page": {
                "reason": string,
                "done_extract": boolean
            }
        }

        PAGE-GROUNDED EVIDENCE ONLY:
        • Suggestions must be supported by this page image: text, caption, label, clearly visible figure/diagram/table/icon.
        • Cross-page or speculative links are not allowed.

        POTENTIAL MISSING ENTITIES:
        • Propose an entity only if it (a) visibly/textually appears on THIS page.

        VISUAL–TEXT LINKS:
        • Propose visual-to-text relations when a specific on-page visual exists
        (e.g., "Illustrated By Figure", "Described In Caption", "Summarized In Table", "Depicted In Diagram").

        DONE WITH THIS PAGE:
        • Provide a short reason.
        • Set done_extract = true if remaining critical links (if any) likely require other pages; otherwise false.

        Example:
        {
            "potential_missing_entities": [
                "Workflow Overview"
            ],
            "potential_missing_relations": [
                {
                    "entity": "Method A",
                    "missing_relations": [
                        "Evaluated On Dataset",
                        "Illustrated By Figure"
                    ]
                },
                {
                    "entity": "Equation (1) - Training Loss",
                    "missing_relations": [
                        "Optimizes Method"
                    ]
                }
            ],
            "done_with_this_page": {
                "reason": "The page shows Method A and Equation (1). The workflow figure is on this page, but the evaluation dataset is not; linking Method A to that dataset requires other pages.",
                "done_extract": true
            }
        }
        """
    ).strip()

    prompt_template = dedent(
        """
        {instructions}

        MetaData:
        • Entity types:
          {entity_types}

        • Total Document Pages: {total_page}, Now on Page: {now_page}
        
        Context:
        • Entities extracted in the last iteration:
          {entity_list}

        Relations observed in the last iteration (by entity):
        {node_relations}
        """
    ).strip()

    prompt = prompt_template.format(
        instructions=instruction_block,
        entity_types=entity_types,
        total_page=total_pages,
        now_page=page_id,
        entity_list=", ".join(node_names_sorted) if node_names_sorted else "(none)",
        node_relations=node_relations_block,
    )
    return prompt, sorted(missing_nodes)

# main intra-page entity-relation extraction and reflection function
async def intra_page_extraction_reflection(
    chunks: dict[str, TextChunkSchema],
    global_config: dict[str, str],
    pipeline_status: dict = None,
    pipeline_status_lock=None,
    llm_response_cache: BaseKVStorage | None = None,
    text_chunks_storage: BaseKVStorage | None = None,
    use_llm_func: callable = None,
    entity_extract_max_reflecting: int = 5,
    chunk_max_async: int = 12,
) -> list:
    """Extract entities and relations from each page chunk, refined by reflection.

    For every chunk an initial extraction prompt is sent to the LLM; then, for
    up to ``entity_extract_max_reflecting`` rounds, a reflector prompt proposes
    missing entities/relations and a follow-up extraction prompt tries to
    recover them. A round's results are merged only for entity names and
    ``(src, tgt)`` pairs not seen before. Reflection stops early when the
    reflector reports it is done or a round adds nothing new.

    Args:
        chunks: Mapping of chunk key to chunk data. Each chunk must provide
            ``"chunk_order_index"`` and ``"total_chunks"``; ``"page_img"`` and
            ``"file_path"`` are optional.
        global_config: Configuration; ``global_config["addon_params"]`` may
            supply ``"language"``, ``"entity_types"`` and ``"example_number"``.
        pipeline_status: Optional shared dict receiving progress messages
            (``"latest_message"``, ``"history_messages"``) and chunk timings.
        pipeline_status_lock: Optional async lock guarding ``pipeline_status``.
            Progress messages are written without locking when it is ``None``.
        llm_response_cache: Cache passed through to ``use_llm_func_with_cache``.
        text_chunks_storage: When given, the cache keys used for a chunk are
            recorded through ``update_chunk_cache_list``.
        use_llm_func: LLM callable passed through to ``use_llm_func_with_cache``.
        entity_extract_max_reflecting: Maximum reflection rounds per chunk.
        chunk_max_async: Maximum number of chunks processed concurrently.

    Returns:
        A list of ``(maybe_nodes, maybe_edges)`` tuples in completion order,
        one per successfully processed chunk. Chunks whose processing raised
        are reported via ``print`` and omitted from the list.
    """
    extraction_prompt = PROMPTS["multimodal_entity_extraction_init"]
    example_templates = PROMPTS["multimodal_entity_extraction_examples"]

    ordered_chunks = list(chunks.items())
    addon_params = global_config["addon_params"]
    language = addon_params.get("language", PROMPTS["DEFAULT_LANGUAGE"])
    entity_types = addon_params.get("entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"])

    example_number = addon_params.get("example_number", None)
    if example_number and example_number < len(example_templates):
        examples = "\n".join(example_templates[: int(example_number)])
    else:
        examples = "\n".join(example_templates)

    examples = examples.format(
        entity_types=", ".join(entity_types),
        language=language,
    )

    context_base = dict(
        entity_types=",".join(entity_types),
        examples=examples,
        language=language,
    )

    processed_chunks = 0
    total_chunks = len(ordered_chunks)

    async def _process_extraction_result(
        result: str,
        chunk_key: str,
        file_path: str = "unknown_source"
    ):
        """Parse one LLM extraction answer into node and edge collections."""
        maybe_nodes: dict[str, list] = defaultdict(list)
        maybe_edges: dict[tuple[str, str], list] = defaultdict(list)

        try:
            json_payload = json.loads(result)
        except json.JSONDecodeError:
            # Fall back to lenient parsing for slightly malformed JSON.
            json_payload = json_repair.repair_json(result, return_objects=True)

        if json_payload is None:
            return maybe_nodes, maybe_edges
        if not isinstance(json_payload, list):
            json_payload = [json_payload]

        for item in json_payload:
            if not isinstance(item, dict):
                continue

            record_type = item.get("type")
            if record_type == "ent":
                entity = await _handle_single_entity_extraction(
                    item, chunk_key, file_path
                )
                if entity:
                    maybe_nodes[entity["entity_name"]].append(entity)
            elif record_type == "rel":
                rel = await _handle_single_relationship_extraction(
                    item, chunk_key, file_path
                )
                if rel:
                    maybe_edges[(rel["src_id"], rel["tgt_id"])].append(rel)
        return maybe_nodes, maybe_edges

    async def _publish_status(message: str) -> None:
        """Record a progress message in ``pipeline_status`` if one was given."""
        if pipeline_status is None:
            return

        def _write():
            pipeline_status["latest_message"] = message
            pipeline_status["history_messages"].append(message)

        # A missing lock must not fail the chunk after all LLM work is done.
        if pipeline_status_lock is not None:
            async with pipeline_status_lock:
                _write()
        else:
            _write()

    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
        """Run initial extraction plus reflection rounds for one chunk."""
        nonlocal processed_chunks

        chunk_key, chunk_dp = chunk_key_dp[0], chunk_key_dp[1]
        page_id = chunk_dp["chunk_order_index"] + 1
        total_pages = chunk_dp["total_chunks"]

        images = [chunk_dp["page_img"]] if "page_img" in chunk_dp else []
        file_path = chunk_dp.get("file_path", "unknown_source")
        cache_keys_collector = []

        page_context = {
            **context_base,
            "total_page": total_pages,
            "now_page": page_id,
        }

        hint_prompt = extraction_prompt.format(**page_context)
        final_result = await use_llm_func_with_cache(
            hint_prompt,
            use_llm_func,
            input_images=images,
            llm_response_cache=llm_response_cache,
            cache_type="extract",
            chunk_id=chunk_key,
            cache_keys_collector=cache_keys_collector,
        )
        maybe_nodes, maybe_edges = await _process_extraction_result(
            final_result, chunk_key, file_path
        )
        extract_result_safe = results_to_jsonable([(maybe_nodes, maybe_edges)])[0]

        for _ in range(entity_extract_max_reflecting):
            reflect_prompt, reflect_missing_nodes = construct_reflector_inputs(
                nodes=extract_result_safe["nodes"],
                edges=extract_result_safe["edges"],
                entity_types=entity_types,
                page_id=page_id,
                total_pages=total_pages,
            )
            reflector_answer = await use_llm_func_with_cache(
                reflect_prompt,
                use_llm_func,
                input_images=images,
                llm_response_cache=llm_response_cache,
                cache_type="reflect",
                chunk_id=chunk_key,
                cache_keys_collector=cache_keys_collector,
            )

            (
                reflection,
                previous_extracted_entities,
                previous_extracted_relations,
                potential_missing_entities,
                potential_missing_relations,
            ) = await _handle_single_reflector_output(
                reflector_answer,
                reflect_missing_nodes,
                extract_result_safe,
            )
            nodes_before = len(extract_result_safe["nodes"])
            edges_before = len(extract_result_safe["edges"])

            continue_prompt = PROMPTS["multimodal_entity_extraction_continue"].format(
                **{
                    **page_context,
                    "previous_extracted_entities": previous_extracted_entities,
                    "previous_extracted_relations": previous_extracted_relations,
                    "potential_missing_entities": potential_missing_entities,
                    "potential_missing_relations": potential_missing_relations,
                }
            )

            continue_answer = await use_llm_func_with_cache(
                continue_prompt,
                use_llm_func,
                input_images=images,
                llm_response_cache=llm_response_cache,
                cache_type="extract",
                chunk_id=chunk_key,
                cache_keys_collector=cache_keys_collector,
            )

            reflect_nodes, reflect_edges = await _process_extraction_result(
                continue_answer, chunk_key, file_path
            )

            # Merge results: only entity names / edge keys not seen before
            # are accepted in the reflecting stage.
            for entity_name, entities in reflect_nodes.items():
                if entity_name not in maybe_nodes:
                    maybe_nodes[entity_name].extend(entities)
            for edge_key, new_edges in reflect_edges.items():
                if edge_key not in maybe_edges:
                    maybe_edges[edge_key].extend(new_edges)

            extract_result_safe = results_to_jsonable([(maybe_nodes, maybe_edges)])[0]
            node_len_delta = len(extract_result_safe["nodes"]) - nodes_before
            edge_len_delta = len(extract_result_safe["edges"]) - edges_before
            extract_done = reflection['done_with_this_page']['done_extract']

            # Stop when the reflector is satisfied or the round added nothing.
            if extract_done or (node_len_delta == 0 and edge_len_delta == 0):
                break

        if cache_keys_collector and text_chunks_storage:
            await update_chunk_cache_list(
                chunk_key,
                text_chunks_storage,
                cache_keys_collector,
                "entity_extraction",
            )

        processed_chunks += 1
        entities_count = len(maybe_nodes)
        relations_count = len(maybe_edges)
        log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel"
        await _publish_status(log_message)

        return maybe_nodes, maybe_edges

    # Bound the number of chunks processed concurrently.
    semaphore = asyncio.Semaphore(chunk_max_async)
    chunk_timings: dict[str, dict] = {}
    chunk_results = []

    def _now():
        return asyncio.get_running_loop().time()

    async def _process_with_semaphore(chunk):
        """Process one chunk under the semaphore, recording wait/run timings."""
        queued_at = _now()  # when this task starts waiting for a slot
        # chunk is ("chunk_key", TextChunkSchema)
        _chunk_key = chunk[0] if isinstance(chunk, tuple) else str(chunk)
        async with semaphore:
            acquired_at = _now()  # when the concurrency slot was acquired
            try:
                return await _process_single_content(chunk)
            except Exception as e:
                # Best effort: one failing chunk must not abort the others.
                print(f"Chunk {_chunk_key} failed: {e}")
                return None
            finally:
                finished_at = _now()
                chunk_timings[_chunk_key] = {
                    "queued_at": queued_at,
                    "acquired_at": acquired_at,
                    "finished_at": finished_at,
                    "wait_ms": (acquired_at - queued_at) * 1000.0,
                    "run_ms": (finished_at - acquired_at) * 1000.0,
                    "total_ms": (finished_at - queued_at) * 1000.0,
                }

    tasks = [asyncio.create_task(_process_with_semaphore(c)) for c in ordered_chunks]
    bar = tqdm(total=len(tasks), desc="Processing chunks", unit="chunk")

    try:
        for completed in asyncio.as_completed(tasks):
            try:
                outcome = await completed
                # Failed chunks yield None; keeping it would break callers
                # that unpack every entry as (nodes, edges).
                if outcome is not None:
                    chunk_results.append(outcome)
            except Exception as exc:
                # Per-chunk errors are already handled inside the task, so
                # this is unexpected: report it and stop the remaining work.
                print(f"Chunk processing aborted: {exc}")
                for t in tasks:
                    t.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
            finally:
                bar.update(1)
    finally:
        bar.close()

    if chunk_timings:
        try:
            worst = sorted(chunk_timings.items(), key=lambda kv: kv[1]["run_ms"], reverse=True)[:5]
            for ck, t in worst:
                print(
                    f"[timing] {ck}: wait={t['wait_ms']:.1f}ms run={t['run_ms']:.1f}ms total={t['total_ms']:.1f}ms"
                )
        except Exception:
            # Timing output is purely diagnostic; never let it fail the run.
            pass
        if pipeline_status is not None and pipeline_status_lock is not None:
            async with pipeline_status_lock:
                pipeline_status.setdefault("chunk_timings", {}).update(chunk_timings)

    return chunk_results

