# -*- coding: utf-8 -*-
from typing import List, Tuple, Dict, Any
from openai_client import chat_json
from prompts import (
    PROMPT_SYS_LAYER_FUSE, build_user_layer_fuse,
    PROMPT_SYS_LAYER_FUSE_REPAIR, build_user_layer_repair,
    PROMPT_SYS_FINAL_FUSE, build_user_final_fuse,
    PROMPT_SYS_FINAL_FUSE_REPAIR, build_user_final_repair
)
from nli import nli_one
from config import TH_PPRIME_ANCHOR, TH_PPRIME_FACT, TH_FINAL_FACT, TH_FINAL_ANCHOR, TH_FINAL_CONTRA_WARN

def _entailed(p: str, h: str, th: float) -> bool:
    """Check whether premise ``p`` entails hypothesis ``h`` with enough confidence.

    Args:
        p: Premise text passed to ``nli_one``.
        h: Hypothesis text passed to ``nli_one``.
        th: Minimum probability the "entailment" label must reach.

    Returns:
        True only when the NLI label is "entailment" and its probability is >= ``th``.
    """
    label, confidence, _unused = nli_one(p, h)
    if label != "entailment":
        return False
    return confidence >= th

def fuse_layer_premise_from_facts(anchor_premise: str, facts_for_fuse: List[str]) -> str:
    """Fuse an anchor premise with a list of facts into a single premise P'.

    The LLM is asked (``PROMPT_SYS_LAYER_FUSE``) for a fused premise under the
    JSON key ``"premise_prime"``. The result is then checked with NLI:

    - P' must entail the anchor premise (threshold ``TH_PPRIME_ANCHOR``);
    - P' must entail every fact (threshold ``TH_PPRIME_FACT``).

    If any check fails, a single repair call (``PROMPT_SYS_LAYER_FUSE_REPAIR``)
    is made, passing the facts P' does not entail. A non-empty repair replaces
    P'; the repaired text is not re-checked.

    Args:
        anchor_premise: Premise the fused text must preserve.
        facts_for_fuse: Facts to fold into the premise.

    Returns:
        The fused (possibly repaired) premise, stripped. Returns the stripped
        anchor premise when there are no facts or the LLM returns no text.
    """
    if not facts_for_fuse:
        return anchor_premise.strip()
    out = chat_json(PROMPT_SYS_LAYER_FUSE, build_user_layer_fuse(anchor_premise, facts_for_fuse), temperature=0.1)
    p_prime = (out.get("premise_prime") or "").strip()
    if not p_prime:
        return anchor_premise.strip()

    # Run NLI once per fact and reuse the result both to decide whether a
    # repair is needed and to tell the repair prompt which facts are missing.
    missing = [f for f in facts_for_fuse if not _entailed(p_prime, f, TH_PPRIME_FACT)]
    anchor_ok = _entailed(p_prime, anchor_premise, TH_PPRIME_ANCHOR)

    if missing or not anchor_ok:
        fix_out = chat_json(PROMPT_SYS_LAYER_FUSE_REPAIR, build_user_layer_repair(anchor_premise, p_prime, missing), temperature=0.0)
        fix = (fix_out.get("premise_prime") or "").strip()
        # Keep the first-pass fusion if the repair came back empty.
        if fix:
            p_prime = fix
    return p_prime

def fuse_final_under_hypothesis(root_hypothesis: str, facts_final: List[str]) -> Tuple[str, Dict[str, Any]]:
    """Fuse layer facts into one text anchored on the root hypothesis.

    Facts are first de-duplicated with ``_dedupe``. The LLM
    (``PROMPT_SYS_FINAL_FUSE``) produces a fused text under the JSON key
    ``"final_text"``, which is then checked with NLI:

    - every fact must be entailed (threshold ``TH_FINAL_FACT``);
    - the text must entail the root hypothesis (threshold ``TH_FINAL_ANCHOR``);
    - neither direction may be a contradiction at or above
      ``TH_FINAL_CONTRA_WARN``.

    If any check fails, or the text is empty, one repair call
    (``PROMPT_SYS_FINAL_FUSE_REPAIR``) is made and its result replaces the text.

    Args:
        root_hypothesis: Hypothesis the fused text must stay consistent with.
        facts_final: Facts to fuse; duplicates under ``normalize_text`` are dropped.

    Returns:
        A pair ``(text, meta)``.

        - ``text``: the fused (possibly repaired) text, or the stripped root
          hypothesis when no text was produced.
        - ``meta``: the bidirectional NLI label/probability between the root
          hypothesis and the first-pass (pre-repair) fused text. When no facts
          remain after de-duplication it is ``{"note": "empty_layer_facts"}``.
    """
    facts_final = _dedupe(facts_final)
    if not facts_final:
        return root_hypothesis.strip(), {"note": "empty_layer_facts"}
    out = chat_json(PROMPT_SYS_FINAL_FUSE, build_user_final_fuse(root_hypothesis, facts_final), temperature=0.1)
    final_text = (out.get("final_text") or "").strip()

    uncovered = [f for f in facts_final if not _entailed(final_text, f, TH_FINAL_FACT)]
    l1, p1, _ = nli_one(root_hypothesis, final_text)
    l2, p2, _ = nli_one(final_text, root_hypothesis)
    # (l2, p2) is exactly what _entailed(final_text, root_hypothesis, ...) would
    # compute, so reuse it rather than issuing a second identical NLI call.
    anchor_bad = not (l2 == "entailment" and p2 >= TH_FINAL_ANCHOR)
    contra_like = (l1 == "contradiction" and p1 >= TH_FINAL_CONTRA_WARN) or (l2 == "contradiction" and p2 >= TH_FINAL_CONTRA_WARN)

    if uncovered or anchor_bad or contra_like or not final_text:
        fix = chat_json(PROMPT_SYS_FINAL_FUSE_REPAIR, build_user_final_repair(root_hypothesis, final_text, uncovered), temperature=0.0)
        # NOTE(review): an empty repair result discards the first-pass text, so
        # the root hypothesis is returned. The layer-level fusion instead keeps
        # its first pass. Confirm this fallback is intended.
        final_text = (fix.get("final_text") or "").strip()

    # meta describes the first-pass text, not the (possibly repaired) returned text.
    meta = {"root_vs_layer_text": {"root->layer": {"lab": l1, "p": p1}, "layer->root": {"lab": l2, "p": p2}}}
    return (final_text or root_hypothesis.strip()), meta

def _dedupe(items: List[str]) -> List[str]:
    """Drop duplicate and blank strings while preserving first-seen order.

    Two strings count as duplicates when ``io_utils.normalize_text`` maps them
    to the same key. Strings whose normalized key is empty are skipped.

    Args:
        items: Strings to de-duplicate; ``None`` is treated as an empty list.

    Returns:
        The surviving strings, each stripped of surrounding whitespace.
    """
    from io_utils import normalize_text
    kept: List[str] = []
    keys = set()
    for raw in items or []:
        key = normalize_text(raw)
        if key and key not in keys:
            keys.add(key)
            kept.append(raw.strip())
    return kept
