# -*- coding: utf-8 -*-
import json, random, concurrent.futures as cf
from typing import List, Dict, Any, Tuple, Optional
from tqdm import tqdm

from config import (
    ST_MODEL_NAME, NLI_MODEL_NAME, NLI_BATCH_SIZE,
    SIM_THRESHOLD, FALLBACK_TOP_M, 
)
from io_utils import normalize_text, _jsonl_path, init_jsonl, append_jsonl, select_by_sim_threshold
from embed import sim_query_to_cands
from wiki import fetch_sentences_for_anchor
from facts import gpt_extract_facts_from_snippets
from graph import filter_conflict_vs_truth, largest_neutral_clique
from fuse import fuse_layer_premise_from_facts, fuse_final_under_hypothesis
from nli import init_nli, NLICache

def _parallel_map(func, jobs, max_workers: int, desc: str):
    """Run ``func(*args, **kwargs)`` for every job in a thread pool, preserving order.

    Best-effort: a job that raises is recorded as ``None`` in the result list,
    but the exception is now logged (with traceback) instead of being silently
    discarded, so systematic failures are visible.

    Args:
        func: Callable invoked once per job.
        jobs: Sequence of ``(args_tuple, kwargs_dict)`` pairs.
        max_workers: Thread-pool size. ``None`` lets the executor choose;
            values below 1 are clamped to 1 instead of raising.
        desc: Label for the tqdm progress bar and for log messages.

    Returns:
        A list with one entry per job, in job order: the job's return value,
        or ``None`` if the job raised.
    """
    import logging

    results = [None] * len(jobs)
    if not jobs:
        return results
    workers = None if max_workers is None else max(1, int(max_workers))
    log = logging.getLogger(__name__)
    with cf.ThreadPoolExecutor(max_workers=workers) as ex:
        # Map each future back to its job index so results keep job order.
        futs = {ex.submit(func, *args, **kwargs): i for i, (args, kwargs) in enumerate(jobs)}
        for fut in tqdm(cf.as_completed(futs), total=len(futs), desc=desc):
            i = futs[fut]
            try:
                results[i] = fut.result()
            except Exception:
                # Keep best-effort semantics (None result) but do not hide the cause.
                log.warning("%s: job %d failed", desc, i, exc_info=True)
                results[i] = None
    return results

def _extract_triplet(rec: Dict[str, Any]) -> Optional[Dict[str, str]]:
    if not isinstance(rec, dict): return None
    rs = rec.get("root_sample")
    if isinstance(rs, dict):
        p = (rs.get("premise") or "").strip()
        h = (rs.get("hypothesis") or "").strip()
        y = (str(rs.get("label") or "")).strip().lower()
        if p and h and y: return {"premise": p, "hypothesis": h, "label": y}
    p = (rec.get("premise") or "").strip()
    h = (rec.get("hypothesis") or "").strip()
    y = (str(rec.get("label") or "")).strip().lower()
    if p and h and y: return {"premise": p, "hypothesis": h, "label": y}
    return None

def load_roots_from_json(path: str, limit: int, *, offset: int = 0, shuffle: bool = False) -> List[Dict[str, str]]:
    """Load root (premise, hypothesis, label) triplets from a JSON file.

    The file must contain either a JSON list of records or an object with a
    ``"data"`` list. Each record is parsed with ``_extract_triplet``; records
    that do not yield a complete triplet are skipped.

    Args:
        path: Path to a UTF-8 JSON file.
        limit: Maximum number of triplets to return. ``0``/``None`` or a
            negative value means "no limit".
        offset: Number of (possibly shuffled) triplets to skip first; negative
            values are treated as ``0``.
        shuffle: Shuffle the parsed triplets with the global ``random``
            generator before applying ``offset``/``limit``.

    Returns:
        The selected triplets as ``{"premise", "hypothesis", "label"}`` dicts.

    Raises:
        RuntimeError: If the JSON has the wrong top-level shape or contains no
            usable triplet.
    """
    # `with` closes the handle deterministically; json.load(open(...)) leaked it.
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    if isinstance(data, dict) and "data" in data:
        data = data["data"]
    if not isinstance(data, list):
        raise RuntimeError("json format error: need a list or {'data': [...] }")
    rows: List[Dict[str, str]] = [item for item in map(_extract_triplet, data) if item]
    if not rows:
        raise RuntimeError("no available (premise, hypothesis, label).")
    if shuffle:
        random.shuffle(rows)
    start = max(0, int(offset))
    n = int(limit) if limit else 0
    # A non-positive limit means "all"; a negative limit used to slice from the
    # end of the list (e.g. rows[0:-5]).
    end = start + n if n > 0 else None
    return rows[start:end]

def _dedupe_wiki_pool(pool_raw: Optional[List[str]], anchor: str) -> List[str]:
    """Return ``pool_raw`` without empty, duplicate or anchor-equal sentences.

    Equality is judged on ``normalize_text``; the first occurrence of each
    sentence is kept in its original form and order. ``None`` yields ``[]``.
    """
    anchor_norm = normalize_text(anchor)
    pool: List[str] = []
    seen = set()
    for sent in pool_raw or []:
        n = normalize_text(sent)
        if not n or n in seen or n == anchor_norm:
            continue
        pool.append(sent)
        seen.add(n)
    return pool


def _select_hyp_facts(hypothesis: str, hyp_pool: List[str]) -> List[str]:
    """Return the facts of ``hyp_pool`` most similar to ``hypothesis``.

    Keeps ``int(0.4 * len(hyp_pool)) + 1`` facts (the count is taken over
    ``hyp_pool`` including duplicates). Ranking is over the de-duplicated pool
    by descending ``sim_query_to_cands`` score; ties keep pool order because
    ``sorted`` is stable.
    """
    k_b = int(len(hyp_pool) * 0.4) + 1
    uniq_pool = list(dict.fromkeys(hyp_pool))
    sims_b = sim_query_to_cands(hypothesis, uniq_pool)
    ranked = [x for _, x in sorted(zip(sims_b, uniq_pool), key=lambda t: -t[0])]
    return ranked[:k_b]


def run_pipeline(
    input_roots_json: str,
    roots: int = 1000,
    *,
    offset: int = 0,
    shuffle: bool = False,
    depth: int = 2,
    wiki_lang: str = "en",
    wiki_k_pages: int = 8,
    wiki_sent_max: int = 2,
    output_json: str = "gen_detail_fromjson.json",
    output_new: str = "gen_samples_fromjson.json",
    nli_model_name: str = NLI_MODEL_NAME,
    nli_batch_size: int = NLI_BATCH_SIZE,
    sim_threshold: float = SIM_THRESHOLD,
    max_snippets: int = 0,
    fallback_top_m: int = FALLBACK_TOP_M,
    wiki_workers: int = 8,
    facts_k: int = 6,
    facts_k_fuse_premise: int = 4,
    max_facts_extract: int = 10,
    api_workers: int = 16,
    write_incremental: bool = True,
    truncate_existing_jsonl: bool = False,
    acc_rows: Optional[List[Dict[str, Any]]] = None,
    acc_flat: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, int]]:
    """Generate multi-level (premise', hypothesis', label) samples from root triplets.

    Roots are loaded with ``load_roots_from_json``. For ``depth`` levels, each
    still-active root goes through these steps (API-bound steps run through
    ``_parallel_map``):

    1. ``fetch_sentences_for_anchor`` on the current anchor premise, then
       de-duplicate and drop sentences equal to the anchor.
    2. Keep snippets selected by ``select_by_sim_threshold``.
    3. Extract up to ``max_facts_extract`` facts with
       ``gpt_extract_facts_from_snippets``.
    4. Drop facts conflicting with the anchor or with facts kept at earlier
       levels (``filter_conflict_vs_truth``), keep ``largest_neutral_clique``,
       and truncate it to ``facts_k``.
    5. Fuse the first k' = ``facts_k_fuse_premise`` (at least 1) of those facts
       into a new premise (``fuse_layer_premise_from_facts``). That premise
       becomes the next level's anchor.
    6. Fuse the pooled facts most similar to the root hypothesis into a new
       hypothesis (``fuse_final_under_hypothesis``).

    A root failing any step is disqualified and skipped at later levels.

    Args:
        input_roots_json: Path passed to ``load_roots_from_json``.
        roots: Max number of roots to load (``limit``).
        offset: Roots to skip after optional shuffling.
        shuffle: Whether ``load_roots_from_json`` shuffles the roots.
        depth: Number of levels.
        wiki_lang: Retrieval parameter forwarded to ``fetch_sentences_for_anchor``.
        wiki_k_pages: Retrieval parameter forwarded to ``fetch_sentences_for_anchor``.
        wiki_sent_max: Retrieval parameter forwarded to ``fetch_sentences_for_anchor``.
        output_json: Base path for the detail JSONL (via ``_jsonl_path``).
        output_new: Base path for the flat JSONL (via ``_jsonl_path``).
        nli_model_name: Model handed to ``init_nli``.
        nli_batch_size: Batch size handed to ``NLICache``.
        sim_threshold: Forwarded to ``select_by_sim_threshold``.
        max_snippets: Forwarded to ``select_by_sim_threshold``.
        fallback_top_m: Forwarded to ``select_by_sim_threshold``.
        wiki_workers: ``workers`` kwarg of ``fetch_sentences_for_anchor``.
            Note this nests inside the ``api_workers`` pool, so up to
            ``api_workers * wiki_workers`` threads may exist.
        facts_k: Max facts kept per level (step 4).
        facts_k_fuse_premise: k', the facts fused into the premise (step 5).
        max_facts_extract: ``max_facts`` for fact extraction.
        api_workers: Thread-pool size for each parallel stage.
        write_incremental: Append rows to the JSONL files as they are produced.
        truncate_existing_jsonl: Forwarded to ``init_jsonl``.
        acc_rows: If given, finished detail rows are also appended here.
        acc_flat: If given, flat per-level rows are also appended here.

    Returns:
        ``(rows, flat, stats)``:

        - ``rows``: one detail row per root that completed all ``depth`` levels.
        - ``flat``: one row per successful (root, level), including roots
          disqualified at a later level.
        - ``stats``: counters.
            - ``ok``: finished roots.
            - ``no_wiki``, ``below_threshold``, ``facts_empty``,
              ``clique_empty``: per-reason drops.
            - ``dropped_disqualified``: total disqualified roots, any reason.
    """
    init_nli(nli_model_name)
    # The instance is not used below; it is constructed only for whatever
    # side effect NLICache has (presumably configuring a shared cache) — TODO confirm.
    _ = NLICache(batch_size=nli_batch_size)

    detail_jsonl = _jsonl_path(output_json)
    flat_jsonl = _jsonl_path(output_new)
    if write_incremental:
        init_jsonl(detail_jsonl, truncate_existing_jsonl)
        init_jsonl(flat_jsonl, truncate_existing_jsonl)

    roots_set = load_roots_from_json(input_roots_json, limit=roots, offset=offset, shuffle=shuffle)

    class _RootState:
        """Mutable per-root bookkeeping carried across levels."""
        __slots__ = ("root_premise", "root_hypothesis", "root_label",
                     "anchor_p", "per_level", "all_facts_truth_per_level",
                     "all_facts_fusion_per_level", "truth_set_norm",
                     "did_any_fusion", "alive", "disqualified", "hyp_pool")

    states: List[_RootState] = []
    for r in roots_set:
        s = _RootState()
        s.root_premise = r["premise"]
        s.root_hypothesis = r["hypothesis"]
        s.root_label = r["label"]
        s.anchor_p = r["premise"]  # premise fused at the previous level (root premise at level 1)
        s.per_level = []
        s.all_facts_truth_per_level = []
        s.all_facts_fusion_per_level = []
        s.truth_set_norm = {normalize_text(r["premise"])}
        s.did_any_fusion = (depth == 0)  # depth 0: nothing to fuse, root counts as finished
        s.alive = True
        s.disqualified = False
        s.hyp_pool = []  # facts accumulated across levels for hypothesis fusion
        states.append(s)

    rows: List[Dict[str, Any]] = []
    flat: List[Dict[str, Any]] = []
    stats = {"ok": 0, "dropped_disqualified": 0, "no_wiki": 0, "below_threshold": 0,
             "facts_empty": 0, "clique_empty": 0}

    def _drop(state: _RootState, reason: Optional[str] = None) -> None:
        """Disqualify ``state``; bump ``stats[reason]`` when a reason is given."""
        if reason is not None:
            stats[reason] += 1
        state.disqualified = True
        state.alive = False

    for lvl in range(1, depth + 1):
        active_idx = [i for i, s in enumerate(states) if s.alive and not s.disqualified]
        if not active_idx:
            break

        # Step 1: retrieval, one job per active root.
        wiki_jobs = [((states[i].anchor_p, wiki_lang, wiki_k_pages, wiki_sent_max),
                      {"workers": wiki_workers}) for i in active_idx]
        wiki_results = _parallel_map(
            fetch_sentences_for_anchor, wiki_jobs,
            max_workers=api_workers, desc=f"[L{lvl}] wiki",
        )

        # Step 2: de-duplicate and similarity-filter snippets.
        fact_jobs, idx_for_fact = [], []
        selected_snippets_cache: Dict[int, List[str]] = {}
        for slot, i in enumerate(active_idx):
            s = states[i]
            pool = _dedupe_wiki_pool(wiki_results[slot], s.anchor_p)
            if not pool:
                _drop(s, "no_wiki")
                continue
            sims = sim_query_to_cands(s.anchor_p, pool)
            selected_snippets = select_by_sim_threshold(
                s.anchor_p, pool, sims, sim_threshold, max_snippets, fallback_top_m)
            if not selected_snippets:
                _drop(s, "below_threshold")
                continue
            selected_snippets_cache[i] = selected_snippets
            fact_jobs.append(((selected_snippets,), {"max_facts": max_facts_extract}))
            idx_for_fact.append(i)

        # Step 3: fact extraction.
        fact_results = _parallel_map(
            gpt_extract_facts_from_snippets, fact_jobs,
            max_workers=api_workers, desc=f"[L{lvl}] facts",
        )

        # Step 4: conflict filtering and clique selection.
        layer_fuse_jobs, idx_for_fuse = [], []
        meta_cache_layer: Dict[int, Dict[str, Any]] = {}
        facts_for_layer_cache: Dict[int, List[str]] = {}
        for slot, i in enumerate(idx_for_fact):
            s = states[i]
            facts_raw = fact_results[slot] or []
            if not facts_raw:
                _drop(s, "facts_empty")
                continue

            # "Truth" = current anchor premise plus all facts kept at earlier levels.
            truth_texts = [s.anchor_p] + [f for prev in s.all_facts_truth_per_level for f in prev]
            facts_ok = filter_conflict_vs_truth(facts_raw, truth_texts)
            facts_clique = largest_neutral_clique(facts_ok)
            if not facts_clique:
                _drop(s, "clique_empty")
                continue

            facts_kept = facts_clique[:max(0, facts_k)]
            if not facts_kept:  # only reachable when facts_k <= 0
                _drop(s)
                continue

            # k' facts go into the premise fusion. This previously sliced by
            # facts_k and ignored facts_k_fuse_premise entirely.
            k_prime = max(1, min(facts_k_fuse_premise, len(facts_kept)))
            facts_for_layer_fuse = facts_kept[:k_prime]

            s.all_facts_truth_per_level.append(list(facts_clique))
            facts_for_layer_cache[i] = facts_for_layer_fuse
            meta_cache_layer[i] = {
                "level": lvl,
                "snippets_thresholded": selected_snippets_cache.get(i, []),
                "facts_raw": facts_raw,
                "facts_after_truth_conflict": facts_ok,
                "facts_clique_neutral": facts_clique,
                "facts_for_layer_fuse_k_prime": facts_for_layer_fuse,
            }
            layer_fuse_jobs.append(((s.anchor_p, facts_for_layer_fuse), {}))
            idx_for_fuse.append(i)

        # Step 5: premise fusion.
        layer_fuse_results = _parallel_map(
            fuse_layer_premise_from_facts, layer_fuse_jobs,
            max_workers=api_workers, desc=f"[L{lvl}] fuse-premise",
        )

        # Step 6a: pick the facts for hypothesis fusion.
        hyp_fuse_jobs, idx_for_hyp = [], []
        a_prime_cache: Dict[int, str] = {}
        facts_for_b_cache: Dict[int, List[str]] = {}
        for slot, i in enumerate(idx_for_fuse):
            s = states[i]
            # A failed/empty fusion falls back to the current anchor.
            a_prime = (layer_fuse_results[slot] or s.anchor_p).strip()
            use_p = facts_for_layer_cache[i]
            if use_p:
                s.did_any_fusion = True
            s.all_facts_fusion_per_level.append(list(use_p))
            s.hyp_pool.extend(use_p)

            facts_for_b = _select_hyp_facts(s.root_hypothesis, s.hyp_pool)
            a_prime_cache[i] = a_prime
            facts_for_b_cache[i] = facts_for_b
            hyp_fuse_jobs.append(((s.root_hypothesis, facts_for_b), {}))
            idx_for_hyp.append(i)

        # Step 6b: hypothesis fusion.
        hyp_fuse_results = _parallel_map(
            fuse_final_under_hypothesis, hyp_fuse_jobs,
            max_workers=api_workers, desc=f"[L{lvl}] fuse-hypo",
        )

        # Record the level and advance the anchor.
        for slot, i in enumerate(idx_for_hyp):
            s = states[i]
            a_prime = a_prime_cache[i]
            # Assumes a truthy result is a (text, meta) pair; a failed job falls
            # back to the root hypothesis.
            b_prime, _meta_b = hyp_fuse_results[slot] if hyp_fuse_results[slot] else (s.root_hypothesis, {})
            use_p = facts_for_layer_cache[i]
            cache = meta_cache_layer[i]
            fused_hyp_ct = len(facts_for_b_cache[i])

            s.per_level.append({
                **cache,
                "premise_prime": a_prime,
                "hypothesis_prime": b_prime,
                "fused_premise_count": len(use_p),
                "fused_hypothesis_count": fused_hyp_ct,
            })

            hop_row = {
                "level": cache.get("level", None),
                "premise": a_prime,
                "hypothesis": b_prime,
                "label": s.root_label,
                "fused_premise_count": len(use_p),
                "fused_hypothesis_count": fused_hyp_ct,
                "root_sample": {
                    "premise": s.root_premise,
                    "hypothesis": s.root_hypothesis,
                    "label": s.root_label
                }
            }
            flat.append(hop_row)
            if acc_flat is not None:
                acc_flat.append(hop_row)
            if write_incremental:
                append_jsonl(flat_jsonl, hop_row)

            s.truth_set_norm.add(normalize_text(a_prime))
            s.anchor_p = a_prime

        # Any still-alive root that did not record this level is out.
        for s in states:
            if s.alive and not s.disqualified and len(s.per_level) < lvl:
                _drop(s)

    stats["dropped_disqualified"] = sum(1 for s in states if s.disqualified)

    finished_idx = [i for i, s in enumerate(states)
                    if not s.disqualified and s.did_any_fusion and len(s.per_level) == depth]
    for i in finished_idx:
        s = states[i]
        detail_row = {
            # NOTE(review): with shuffle=True this is the position in the shuffled
            # list, not the index in the input file.
            "root_index": offset + i,
            "root_label": s.root_label,
            "root_premise": s.root_premise,
            "root_hypothesis": s.root_hypothesis,
            "depth": depth,
            "per_level": s.per_level,
        }
        rows.append(detail_row)
        if acc_rows is not None:
            acc_rows.append(detail_row)
        if write_incremental:
            append_jsonl(detail_jsonl, detail_row)
        stats["ok"] += 1

    return rows, flat, stats
