from __future__ import annotations

import sys
from typing import Dict, List, Tuple, Union

import torch
from rich import print
from rich.markup import escape
from rich.table import Table

from .atomize import extract_atoms
from .chunking import chunk_text, detect_lang
from refinement.aligner import Aligner
from .te_models import DualTE
from .priority import priority_rule
from config import CFG_TE, CFG_CHUNK

def _auto_device(devstr: str):
    """Return device string 'cuda:0' if available and requested, else 'cpu'."""
    if devstr == "auto":
        return "cuda:0" if torch.cuda.is_available() else "cpu"
    return devstr

def normalize_atoms(
    atoms: Union[Dict[str, List[str]], List[str]]
) -> Tuple[List[str], List[Tuple[str, str]]]:
    """
    Normalize atoms to a flat list and keep (category, atom) metadata.

    Accepted inputs:
      - dict mapping category -> list/tuple of atoms (a bare string value is
        treated as a single atom; None is treated as empty);
      - list or tuple of atoms, all tagged with category "unknown".
    Empty / whitespace-only / None atoms are dropped; the rest are stripped.

    Returns:
      flat_atoms: List[str]          -> e.g., ["cat", "dog", "beach", ...]
      meta:       List[(cat, atom)]  -> e.g., [("characters","cat"), ...]

    Raises:
      TypeError: if ``atoms`` is neither a dict nor a list/tuple.
    """
    flat: List[str] = []
    meta: List[Tuple[str, str]] = []

    def _collect(cat: str, items) -> None:
        # A bare string must not be iterated character by character.
        if isinstance(items, str):
            items = (items,)
        for it in items or ():
            s = (it or "").strip()
            if s:
                flat.append(s)
                meta.append((cat, s))

    if isinstance(atoms, dict):
        for cat, items in atoms.items():
            _collect(cat, items)
    elif isinstance(atoms, (list, tuple)):
        _collect("unknown", atoms)
    else:
        raise TypeError(f"Unsupported atoms type: {type(atoms)}")
    return flat, meta

def verbalize_atom(atom: str, cat: str) -> str:
    """
    Wrap a short atom in a template sentence so TE models see a full clause.

    The template is selected by the category, compared case-insensitively;
    a missing or unrecognized category uses the generic
    "The prompt mentions ..." sentence.
    """
    templates = {
        "characters": "There is a {}.",
        "objects": "A {} is present.",
        # The atom is inserted as-is; no check is made that it is verb-like.
        "actions": "Someone is {}.",
        "locations": "The scene is at a {}.",
        "scenery": "The scenery includes {}.",
    }
    key = (cat or "").lower()
    return templates.get(key, "The prompt mentions {}.").format(atom)

def _aggregate_labels(labels: List[str]) -> str:
    """
    Collapse the labels of all Top-K candidates for one atom into one verdict.

    Conservative precedence: any CONTRADICTION wins, then ENTAILMENT, then
    WEAK_ENTAIL; an empty list (or only other labels) yields NEUTRAL.
    """
    for label in ("CONTRADICTION", "ENTAILMENT", "WEAK_ENTAIL"):
        if label in labels:
            return label
    return "NEUTRAL"


def _print_report(per_atom: List[Dict], coverage: float, contrad: float) -> None:
    """
    Print the per-atom verification table followed by the summary metrics.

    Prompt-derived cell text is passed through ``escape`` so that square
    brackets are shown literally instead of being parsed as rich markup
    (unescaped, "[cat]" would vanish and "[/x]" would raise a MarkupError).
    """
    table = Table(title="Verification Report (Entailment-centric)")
    table.add_column("i", justify="right")
    table.add_column("Category")
    table.add_column("Atom")
    table.add_column("Labels@TopK")
    table.add_column("Final")
    for i, r in enumerate(per_atom):
        table.add_row(
            str(i),
            escape(str(r["category"])),
            escape(r["atom"]),
            escape(", ".join(r["labels"])),
            r["final"],
        )
    print(table)
    print(f"[bold]Coverage:[/bold] {coverage:.3f}    [bold]Contradiction rate:[/bold] {contrad:.3f}")


def verify(
    original_prompt: str,
    expanded_prompt: str,
    cfg=CFG_TE,
    chunk_cfg=CFG_CHUNK,
    *,
    debug: bool = False,
) -> Dict:
    """
    Verify that an expanded prompt preserves the content of the original one.

    Pipeline:
    1) Atomize the original prompt into minimal hypotheses.
    2) Chunk the expanded prompt (mode taken from ``chunk_cfg``) with overlap.
    3) Retrieve Top-K candidate chunks per atom via embeddings (recall only).
    4) Run dual TE per (chunk, atom) pair; model2 is consulted only when
       model1 is uncertain, and the labels are arbitrated with priority_rule.
    5) Conservative aggregation per atom
       (CONTRADICTION > ENTAILMENT > WEAK_ENTAIL > NEUTRAL).
    6) Compute Coverage and Contradiction metrics and print a report.

    Args:
        original_prompt: Prompt whose atoms must be preserved.
        expanded_prompt: Refined/expanded prompt checked against the atoms.
        cfg: TE config; reads use_atomize_llm, max_chunk_chars, chunk_overlap,
            device, emb_model, topk, te_model1, te_model2, uncertain_maxp.
        chunk_cfg: Chunking config; reads lang, mode, max_words,
            overlap_words, min_words.
        debug: If True, print the raw atoms and every (premise, hypothesis)
            pair sent to the TE models.

    Returns:
        Dict with keys "per_atom", "coverage", "contradiction_rate", "atoms",
        "atoms_flat", "atoms_meta", "chunks", "topk_indices", "lang".
    """
    # 1) Atomic extraction (from original prompt)
    atoms_raw = extract_atoms(original_prompt, use_llm=cfg.use_atomize_llm)
    if debug:
        # Non-str objects are pretty-printed by rich, not parsed as markup.
        print("[dim]atoms_raw:[/dim]", atoms_raw)
    atoms_flat, atoms_meta = normalize_atoms(atoms_raw)

    # Early exit if no atoms
    if not atoms_flat:
        print("[yellow]No atoms extracted; coverage is 0, contradiction is 0.[/yellow]")
        return {
            "per_atom": [],
            "coverage": 0.0,
            "contradiction_rate": 0.0,
            "atoms": atoms_raw,
            "atoms_flat": atoms_flat,
            "atoms_meta": atoms_meta,
            "chunks": [],
            "topk_indices": [],
            "lang": "unknown",
        }

    # 2) Chunking (of expanded prompt).
    # NOTE(review): char-level sizes come from cfg (CFG_TE), word-level ones
    # from chunk_cfg -- confirm this split across the two configs is intended.
    lang = chunk_cfg.lang if chunk_cfg.lang != "auto" else detect_lang(expanded_prompt)
    chunks: List[str] = chunk_text(
        expanded_prompt,
        mode=chunk_cfg.mode,
        lang=lang,
        max_words=chunk_cfg.max_words,
        overlap_words=chunk_cfg.overlap_words,
        max_chars=cfg.max_chunk_chars,
        overlap_chars=cfg.chunk_overlap,
        min_words=chunk_cfg.min_words,
    )

    # 3) Embedding retrieval (Top-K per atom) + 4) TE model setup.
    if chunks:
        device = _auto_device(cfg.device)
        aligner = Aligner(cfg.emb_model, device=device)
        topk_indices = aligner.topk_indices(atoms_flat, chunks, cfg.topk)
        te = DualTE(cfg.te_model1, cfg.te_model2, device=device)
    else:
        print("[yellow]No chunks produced from expanded prompt; nothing to verify.[/yellow]")
        # With no candidates every atom aggregates to NEUTRAL, so skip loading
        # the embedding and TE models entirely.
        topk_indices = [[] for _ in atoms_flat]
        te = None

    # 4) Dual textual entailment with uncertainty trigger
    per_atom = []
    for i, (cat, atom) in enumerate(atoms_meta):
        # The hypothesis depends only on the atom, not on the candidate chunk.
        hyp = verbalize_atom(atom, cat)
        labels_for_atom = []
        for j in topk_indices[i]:
            cj = chunks[j]  # premise is the chunk; hypothesis is the verbalized atom
            if debug:
                print("[dim]premise:[/dim]", escape(cj), "[dim]hypothesis:[/dim]", escape(hyp))
            s1, s2 = te.infer(premise=cj, hypothesis=hyp)

            # Uncertainty trigger: if model1's larger of P(entail)/P(contradict)
            # is below the threshold, arbitrate with model2 as well.
            # NOTE(review): a confidently-neutral model1 also lands here, since
            # the neutral score is not considered -- confirm this is intended.
            maxp1 = max(s1.get("entailment", 0.0), s1.get("contradiction", 0.0))
            if maxp1 < cfg.uncertain_maxp:
                lab = priority_rule(s1, s2)
            else:
                lab = priority_rule(s1, s1)  # rely on model1 only
            labels_for_atom.append(lab)

        # 5) Conservative aggregation across candidates
        per_atom.append({
            "category": cat,
            "atom": atom,
            "labels": labels_for_atom,
            "final": _aggregate_labels(labels_for_atom),
        })

    # 6) Metrics (per_atom is non-empty here; the guard is defensive)
    n = len(per_atom) or 1
    coverage = sum(1 for r in per_atom if r["final"] in ("ENTAILMENT", "WEAK_ENTAIL")) / n
    contrad = sum(1 for r in per_atom if r["final"] == "CONTRADICTION") / n

    _print_report(per_atom, coverage, contrad)

    return {
        "per_atom": per_atom,
        "coverage": coverage,
        "contradiction_rate": contrad,
        "atoms": atoms_raw,          # original dict or list (for debugging)
        "atoms_flat": atoms_flat,    # flattened list used for retrieval
        "atoms_meta": atoms_meta,    # (category, atom) pairs
        "chunks": chunks,
        "topk_indices": topk_indices,
        "lang": lang,
    }