import re
import os
import json
import time
import argparse
import difflib
import hashlib
from pathlib import Path
from typing import List, Dict, Tuple, Optional

_HAVE_SK = False
try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    _HAVE_SK = True
except Exception:
    pass

_HAVE_ST = False
try:
    from sentence_transformers import SentenceTransformer
    import numpy as np
    _HAVE_ST = True
except Exception:
    import numpy as np  # ensure numpy exists
    _HAVE_ST = False

from utils import load_file_as_string, remote_chat
try:
    from eval_citation_modified_single import nli, get_refs
except Exception:
    from main import nli, get_refs

def get_refs_enhanced(paper_dir: Path, index_type: str = "full_text") -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Build the retrieval index from a directory of reference JSON files.

    Each file is expected to hold a JSON object with a ``bib_name`` plus an
    ``abstract`` and/or ``md_text`` field. Files that cannot be read or parsed,
    are not JSON objects, or lack a usable ``bib_name``/text are skipped.

    Args:
        paper_dir: Directory containing the reference JSON files.
        index_type: ``"abstract"`` / ``"full_text"`` index one entry per paper
            (``abstract`` / ``md_text``) keyed by bib name; ``"segments"`` splits
            ``md_text`` with ``split_into_segments`` and keys entries as
            ``{bib_name}_seg_{i}``.

    Returns:
        ``(bibname2content, bibname_mapping)``; ``bibname_mapping`` maps every
        index key back to the bib name of the paper it came from.

    Raises:
        ValueError: if *index_type* is unknown (checked up front, even when the
            directory is empty).
    """
    if index_type not in ("abstract", "full_text", "segments"):
        raise ValueError(f"Unknown index_type: {index_type}")
    text_field = "abstract" if index_type == "abstract" else "md_text"

    def as_text(value) -> str:
        # Tolerate null / non-string JSON fields instead of crashing on .strip().
        return value.strip() if isinstance(value, str) else ""

    bibname2content: Dict[str, str] = {}
    bibname_mapping: Dict[str, str] = {}
    paper_dir = Path(paper_dir)

    # Sorted for a deterministic index order (os.listdir order is arbitrary); the
    # embedding cache also compares key order, so this keeps caches reusable.
    for file in sorted(os.listdir(paper_dir)):
        try:
            with open(paper_dir / file, "r", encoding="utf-8") as f:
                paper_dict = json.load(f)
        except (OSError, ValueError):
            # Unreadable entry (e.g. a directory), non-UTF-8 bytes or invalid JSON.
            continue
        if not isinstance(paper_dict, dict):
            continue

        bib_name = as_text(paper_dict.get("bib_name"))
        content = as_text(paper_dict.get(text_field))
        if not bib_name or not content:
            continue

        if index_type == "segments":
            for i, segment in enumerate(split_into_segments(content)):
                if segment.strip():
                    seg_key = f"{bib_name}_seg_{i}"
                    bibname2content[seg_key] = segment.strip()
                    bibname_mapping[seg_key] = bib_name  # maps back to original bibname
        else:
            bibname2content[bib_name] = content
            bibname_mapping[bib_name] = bib_name

    return bibname2content, bibname_mapping

def split_into_segments(content: str, max_segment_length: int = 1000) -> List[str]:
    """
    Split Markdown *content* into retrieval segments.

    Sections are cut before every ATX header line of any level (``#`` to
    ``######`` followed by whitespace). A section longer than
    *max_segment_length* characters is re-packed paragraph by paragraph
    (blank-line separated) into segments of at most that length. A single
    paragraph that is itself longer than the limit is kept whole, so a segment
    can still exceed *max_segment_length*.

    Returns the non-empty, whitespace-stripped segments in document order.
    """
    segments: List[str] = []

    # Split before header lines of any level. Matching only "#\s" missed "## ..."
    # subsections, which is how most converted papers mark their sections.
    for section in re.split(r'\n(?=#{1,6}\s)', content):
        section = section.strip()
        if not section:
            continue

        if len(section) <= max_segment_length:
            segments.append(section)
            continue

        # Too long: greedily pack whole paragraphs up to the length limit.
        current = ""
        for para in re.split(r'\n\s*\n', section):
            para = para.strip()
            if not para:
                continue
            if not current:
                current = para
            elif len(current) + 2 + len(para) > max_segment_length:  # 2 == len("\n\n")
                segments.append(current)
                current = para
            else:
                current += "\n\n" + para
        if current:
            segments.append(current)

    return segments

# ---------------- Parse sentences and cites ----------------
# One "sentence" = the maximal run of characters without '.' or newline that
# contains at least one \cite{...}.
SENT_WITH_CITE = re.compile(r'([^.\n]*\\cite\{[^}]+\}[^.\n]*)')

def find_cited_sentences(tex_content: str) -> List[Tuple[str, Tuple[int,int]]]:
    """Return ``(sentence_text, (start, end))`` for each cited sentence, in document order."""
    found: List[Tuple[str, Tuple[int, int]]] = []
    for match in SENT_WITH_CITE.finditer(tex_content):
        found.append((match.group(0), match.span()))
    return found

def extract_keys_from_sentence(sentence: str) -> List[str]:
    """Return every citation key from all ``\\cite{a, b}`` commands, in order, whitespace-trimmed."""
    return [
        key
        for group in re.findall(r'\\cite\{([^}]+)\}', sentence)
        for key in (part.strip() for part in group.split(','))
        if key
    ]

def clean_claim_text(sentence: str) -> str:
    """
    Strip ``\\cite{...}`` commands from *sentence* and normalise whitespace.

    A LaTeX tie (``~``) directly before a citation is removed with it, so
    ``"works~\\cite{a}"`` yields ``"works"`` rather than ``"works~"``.
    """
    s = re.sub(r'~?\\cite\{[^}]+\}', '', sentence)
    s = re.sub(r'\s+', ' ', s).strip()
    return s

def replace_citation_keys(sentence: str, new_keys: List[str]) -> str:
    """Replace all \\cite{...} in the sentence with \\cite{new_keys}.

    Every citation command is rewritten to the same comma-joined key list.
    """
    replacement = "\\cite{%s}" % ",".join(new_keys)
    # A callable replacement is inserted verbatim, so the backslash needs no escaping.
    return re.sub(r'\\cite\{[^}]+\}', lambda _match: replacement, sentence)

def replace_claim_text(sentence: str, new_claim: str) -> str:
    """
    Replace the prose of a cited *sentence* with *new_claim*, keeping its first citation.

    The result is ``"<leading ws><new_claim> <first \\cite{...}>"``; any later
    citation commands in *sentence* are dropped. The leading whitespace of
    *sentence* is preserved, because the span produced by ``find_cited_sentences``
    usually starts with the space after the previous sentence's period. A single
    trailing period on *new_claim* is removed, because the original sentence's
    period lies outside the replaced span and would otherwise be duplicated
    (``"claim. \\cite{x}."``). Without any citation, only the prefixed claim is returned.
    """
    claim = str(new_claim or "").strip()
    if claim.endswith("."):
        claim = claim[:-1].rstrip()
    leading = sentence[: len(sentence) - len(sentence.lstrip())]
    m = re.search(r'\\cite\{[^}]+\}', sentence)
    if not m:
        return leading + claim
    return f"{leading}{claim} {m.group(0)}"

# ---------------- Embedding retrieval (with cache) ----------------
def _hash_corpus(bib_map: Dict[str, str]) -> str:
    h = hashlib.sha1()
    for k in sorted(bib_map.keys()):
        h.update(k.encode('utf-8', errors='ignore'))
        h.update((bib_map[k] or "").encode('utf-8', errors='ignore'))
    return h.hexdigest()

def build_or_load_embeddings(
    bibname2abs: Dict[str, str],
    model_name: str,
    cache_dir: Path,
    use_cache: bool = True,
    index_type: str = "full_text"
) -> Tuple[np.ndarray, List[str], Optional["SentenceTransformer"]]:
    """
    Embed every indexed document, reusing an on-disk cache when possible.

    Returns: (emb_matrix [N, D] float32, keys_order, model)
    Cache at cache_dir/.emb_cache_{index_type}_{hash}.npz. The hash covers both
    the corpus and *model_name*, because vectors from different models are not
    interchangeable (they may not even share a dimension). A cache hit also
    requires the stored key order to equal the current one.
    Without sentence-transformers, returns an all-zero [N, 1] matrix and no model.
    """
    keys = list(bibname2abs.keys())
    docs = [(bibname2abs.get(k) or "").strip() for k in keys]
    cache_hash = hashlib.sha1(
        f"{model_name}\0{_hash_corpus(bibname2abs)}".encode("utf-8", errors="ignore")
    ).hexdigest()
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = cache_dir / f".emb_cache_{index_type}_{cache_hash}.npz"

    if not _HAVE_ST:
        return np.zeros((len(keys), 1), dtype="float32"), keys, None

    # Loaded even on a cache hit: the caller needs it to embed queries.
    model = SentenceTransformer(model_name)

    if use_cache and cache_path.exists():
        try:
            # Context manager closes the underlying zip file handle.
            with np.load(cache_path) as data:
                emb = data["emb"]
                cached_keys = json.loads(data["keys_json"].tobytes().decode("utf-8"))
            if cached_keys == keys:
                return emb, keys, model  # cache hit
        except Exception as e:  # best effort: any unreadable cache is simply rebuilt
            print(f"[WARN] Ignoring unreadable embedding cache {cache_path}: {e}")

    emb = np.asarray(
        model.encode(docs, show_progress_bar=False, normalize_embeddings=True)
    ).astype("float32")
    if use_cache:
        np.savez_compressed(
            cache_path,
            emb=emb,
            keys_json=np.frombuffer(json.dumps(keys).encode("utf-8"), dtype="uint8")
        )
    return emb, keys, model

def embed_query(text: str, model: Optional["SentenceTransformer"]) -> np.ndarray:
    """
    Encode *text* as one L2-normalised float32 row, via ``model.encode([text])``.

    Without sentence-transformers or a model, returns a ``(1, 1)`` zero placeholder.
    """
    if _HAVE_ST and model is not None:
        vec = model.encode([text], show_progress_bar=False, normalize_embeddings=True)
        return vec.astype("float32")
    return np.zeros((1, 1), dtype="float32")

def cosine_topk(query_vec: np.ndarray, emb_matrix: np.ndarray, top_k: int = 30) -> np.ndarray:
    """
    Return row indices of *emb_matrix*, highest dot product with *query_vec* first.

    The dot product equals cosine similarity for the normalised embeddings this
    file produces. At most *top_k* indices are returned.
    """
    scores = np.dot(emb_matrix, query_vec.T).ravel()
    ranking = np.argsort(-scores)
    return ranking[:top_k]

# ---------------- TF-IDF retrieval ----------------
def tfidf_retrieve(query: str, bibname2abs: Dict[str,str], top_k: int = 30) -> List[Tuple[str, str, float]]:
    """
    Rank indexed documents against *query* by TF-IDF cosine similarity.

    Returns up to *top_k* ``(key, text, score)`` tuples, best first. When
    scikit-learn is unavailable, the index is empty, the query has no indexable
    tokens, or vectorisation fails, the first *top_k* documents are returned in
    index order with score 0.0 (best effort: retrieval never aborts the caller).
    """
    keys = list(bibname2abs.keys())
    docs = [(bibname2abs.get(k) or "").strip() for k in keys]
    top_k = max(0, top_k)

    def unranked() -> List[Tuple[str, str, float]]:
        return [(keys[i], docs[i], 0.0) for i in range(min(top_k, len(keys)))]

    if not _HAVE_SK or not keys:
        return unranked()

    try:
        # Inside a raw string \b and \w must be written once. Doubling them matches a
        # literal backslash, which leaves the vocabulary empty and makes every
        # query fall back to unranked index order.
        vect = TfidfVectorizer(max_features=20000, ngram_range=(1, 2), token_pattern=r"(?u)\b\w+\b")
        X = vect.fit_transform([query] + docs)
        if X[0].nnz == 0:  # query shares no vocabulary with the corpus
            return unranked()
        sims = cosine_similarity(X[0], X[1:]).ravel()
    except Exception as e:
        print(f"[WARN] TF-IDF retrieval failed ({e}); using unranked candidates.")
        return unranked()

    idx = np.argsort(-sims)[:top_k]
    return [(keys[i], docs[i], float(sims[i])) for i in idx]

# ---------------- NLI rerank function ----------------
def nli_rerank(claim: str, ranked: List[Tuple[str, str, float]], max_keep: int = 10) -> List[Tuple[str, str, float]]:
    """
    Filter *ranked* candidates to those whose text entails *claim*, keeping their order.

    Stops after *max_keep* hits. NLI errors count as "not supported". If no
    candidate passes, the first *max_keep* of *ranked* are returned unchanged.
    """
    supported: List[Tuple[str, str, float]] = []
    for key, text, score in ranked:
        try:
            hit = bool(nli(claim, text))
        except Exception:
            continue
        if hit:
            supported.append((key, text, score))
            if len(supported) >= max_keep:
                break
    return supported or ranked[:max_keep]

# ---------------- Sliding-window NLI ----------------
def sliding_nli_pick(
    claim: str,
    cand_keys: List[str],
    cand_texts: List[str],
    bibname_mapping: Dict[str, str],  # maps segment keys back to bibnames
    window: int = 3,
    max_windows: int = 5,
    use_nli_rerank: bool = False,
    max_picks: int = 3,
) -> List[str]:
    """
    Pick up to *max_picks* distinct bib names whose candidate text entails *claim*.

    Candidates are NLI-checked in order, in windows of *window* items, for at
    most *max_windows* windows, so at most ``window * max_windows`` NLI calls are
    made. Segment keys are mapped back to their paper via *bibname_mapping*. An
    NLI error on a candidate skips it. If nothing is entailed, falls back to the
    distinct papers among the first *max_picks* candidates.

    *use_nli_rerank* is accepted for call-site compatibility and is not used here.
    *max_picks* should be >= 1; the default of 3 matches the previous fixed limit.
    """
    picked_bibnames: List[str] = []
    total = len(cand_keys)

    for w in range(max_windows):
        s, e = w * window, min((w + 1) * window, total)
        if s >= e:
            break
        for i in range(s, e):
            k, txt = cand_keys[i], cand_texts[i]
            try:
                if txt and nli(claim, txt):
                    # Convert segment key back to original bibname
                    original_bibname = bibname_mapping.get(k, k)
                    if original_bibname not in picked_bibnames:
                        picked_bibnames.append(original_bibname)
                        if len(picked_bibnames) >= max_picks:
                            return picked_bibnames
            except Exception:
                continue

    if not picked_bibnames:
        for k in cand_keys[:max_picks]:
            original_bibname = bibname_mapping.get(k, k)
            if original_bibname not in picked_bibnames:
                picked_bibnames.append(original_bibname)

    return picked_bibnames

def choose_cite_keys_hybrid_enhanced(
    claim: str,
    orig_keys: List[str],
    bibname2abs: Dict[str, str],
    bibname_mapping: Dict[str, str],
    keep_supported: bool,
    retrieval: str,            # 'auto' | 'embed' | 'tfidf'
    topk: int,
    nli_window_size: int,
    max_windows: int,
    emb_matrix: Optional[np.ndarray],
    emb_keys: Optional[List[str]],
    emb_model: Optional["SentenceTransformer"],
    use_nli_rerank: bool = False,
) -> List[str]:
    """
    Choose up to three bib names to cite for *claim*.

    1. With *keep_supported*, original keys whose text entails the claim are
       returned directly (at most three).
    2. Otherwise candidates come from embedding search (when requested and
       available) or TF-IDF. NLI reranking applies only on the plain TF-IDF path.
    3. ``sliding_nli_pick`` chooses the final bib names from the candidates.
    """
    def text_for(key: str) -> str:
        # Whole-paper entry if indexed directly, else the paper's first segment.
        if key in bibname2abs:
            return bibname2abs[key]
        first_segment = next((sk for sk in bibname2abs if bibname_mapping.get(sk) == key), None)
        return "" if first_segment is None else bibname2abs[first_segment]

    # Step 1: reuse the author's own citations when they already support the claim.
    if keep_supported and orig_keys:
        supported = []
        for key in orig_keys:
            evidence = text_for(key)
            try:
                if evidence and nli(claim, evidence):
                    supported.append(key)
            except Exception:
                pass
        if supported:
            return supported[:3]

    # Step 2: candidate retrieval.
    embed_ready = _HAVE_ST and emb_matrix is not None and emb_keys is not None
    wants_embed = retrieval == "embed" or (retrieval == "auto" and embed_ready)

    if wants_embed and embed_ready and emb_model is not None:
        query_vec = embed_query(claim, emb_model)
        order = cosine_topk(query_vec, emb_matrix, top_k=topk)
        cand_keys = [emb_keys[int(j)] for j in order]
        cand_texts = [bibname2abs.get(k, "") or "" for k in cand_keys]
    else:
        # TF-IDF: either chosen outright or as the fallback when embeddings are missing.
        pairs = tfidf_retrieve(claim, bibname2abs, top_k=topk)
        if use_nli_rerank and not wants_embed:
            pairs = nli_rerank(claim, pairs, max_keep=min(10, topk))
        cand_keys = [k for k, _, _ in pairs]
        cand_texts = [t for _, t, _ in pairs]

    # Step 3: sliding-window NLI pick (segment keys mapped back to bib names).
    return sliding_nli_pick(
        claim=claim,
        cand_keys=cand_keys,
        cand_texts=cand_texts,
        bibname_mapping=bibname_mapping,
        window=nli_window_size,
        max_windows=max_windows,
        use_nli_rerank=use_nli_rerank
    )

# ---------------- LLM dialogs (unchanged) ----------------
def ask_llm_for_error_reason(claim: str, ctx_snippets: List[str]) -> str:
    """Ask the LLM why *claim* is not supported by (at most 8 of) *ctx_snippets*; returns stripped text."""
    if ctx_snippets:
        context = "\n\n---\n\n".join(ctx_snippets[:8])
    else:
        context = "(no sources text found)"
    prompt = f"""You are an NLI judge and scientific editor.

Explain WHY the claim is NOT supported by the provided SOURCES/CONTEXT.
- Be specific (3–6 bullet points): missing evidence, scope mismatch, contradictions, over-claiming, etc.

Claim:
{claim}

SOURCES/CONTEXT:
{context}
"""
    response = remote_chat(prompt)
    return response.strip()

def ask_llm_for_how_to_fix(claim: str, ctx_snippets: List[str]) -> Dict:
    joined = "\n\n---\n\n".join(ctx_snippets[:8]) if ctx_snippets else "(no sources text found)"
    prompt = f"""You are an NLI judge and scientific editor.

Propose a FIX that makes the claim fully supported by the SOURCES/CONTEXT.
Return JSON ONLY:
- "rewrite": ONE single sentence strictly supported by given SOURCES/CONTEXT.
- "style_notes": 2-4 brief tips to keep tone factual and precise.
- "alt_phrasings": 2 single-sentence alternatives, also strictly supported.

Claim:
{claim}

SOURCES/CONTEXT:
{joined}
"""
    raw = remote_chat(prompt).strip()
    try:
        return json.loads(raw)
    except Exception:
        return {"rewrite": raw, "style_notes": [], "alt_phrasings": []}

# ---------------- Progress saving/loading ----------------
def save_progress(session_records: List[Dict], progress_path: Path, current_sentence_idx: int):
    """
    Save current progress to allow resuming.

    Writes ``{"current_sentence_idx", "session_records", "timestamp"}`` as JSON.
    The data goes to a sibling ``<name>.tmp`` file that is then renamed over
    *progress_path*, so an interrupted write cannot leave a truncated progress
    file behind.
    """
    progress_data = {
        "current_sentence_idx": current_sentence_idx,
        "session_records": session_records,
        "timestamp": time.time()
    }
    progress_path = Path(progress_path)
    tmp_path = progress_path.with_name(progress_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(progress_data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, progress_path)  # atomic on POSIX and Windows
    except BaseException:
        # Don't leave a half-written temp file behind; the old progress file is intact.
        if tmp_path.exists():
            tmp_path.unlink()
        raise

def load_progress(progress_path: Path) -> Tuple[int, List[Dict]]:
    """
    Load progress from file, return (start_sentence_idx, session_records).

    A missing file, or one that cannot be read/parsed, yields ``(0, [])``;
    read/parse failures also print a warning.
    """
    if progress_path.exists():
        try:
            with open(progress_path, "r", encoding="utf-8") as fh:
                saved = json.load(fh)
            resume_idx = saved.get("current_sentence_idx", 0)
            records = saved.get("session_records", [])
        except Exception as err:
            print(f"[WARN] Failed to load progress: {err}")
            return 0, []
        return resume_idx, records
    return 0, []

# ---------------- Main ----------------
def _index_text_for_key(key: str, bibname2abs: Dict[str, str], bibname_mapping: Dict[str, str]) -> str:
    """Return the indexed text for a cite key: its own entry, else its first segment, else ""."""
    if key in bibname2abs:
        return bibname2abs[key]
    for seg_key, text in bibname2abs.items():
        if bibname_mapping.get(seg_key) == key:
            return text
    return ""


def _index_snippets_for_key(key: str, bibname2abs: Dict[str, str], bibname_mapping: Dict[str, str],
                            max_segments: int = 2) -> List[str]:
    """Return source snippets for a cite key: its own entry, else up to *max_segments* segments."""
    if key in bibname2abs:
        return [bibname2abs[key]]
    segments = [text for seg_key, text in bibname2abs.items() if bibname_mapping.get(seg_key) == key]
    return segments[:max_segments]


def _apply_replace(sentence_text: str, new_claim_text: Optional[str], new_keys: Optional[List[str]]) -> str:
    """Optionally replace the claim prose, then optionally swap in *new_keys* as the citations."""
    s = replace_claim_text(sentence_text, new_claim_text) if new_claim_text is not None else sentence_text
    if new_keys:
        s = replace_citation_keys(s, new_keys)
    return s


def _replay_records(buf: List[str], records: List[Dict],
                    processed: List[Tuple[str, Tuple[int, int]]]) -> None:
    """
    Re-apply edits recorded by an earlier run to *buf*, so a resumed run writes them out too.

    Records carry the ``span`` of their sentence in the original text. Records
    from older progress files that lack it are matched to an already-processed
    sentence by ``orig_sentence``. Edits are applied from the end of the
    document backwards so earlier offsets stay valid.
    """
    unused = list(processed)
    located = []
    for rec in records:
        text = rec.get("final_sentence")
        if text is None:
            continue
        span = rec.get("span")
        if span is None:
            match = next((item for item in unused if item[0] == rec.get("orig_sentence")), None)
            if match is None:
                print(f"[WARN] Could not locate recorded sentence on resume; not re-applied: "
                      f"{str(rec.get('orig_sentence'))[:80]}")
                continue
            unused.remove(match)
            span = match[1]
        located.append((int(span[0]), int(span[1]), text))
    for st, ed, text in sorted(located, key=lambda item: item[0], reverse=True):
        buf[st:ed] = list(text)


def _process_sentence(sent: str, args, bibname2abs: Dict[str, str], bibname_mapping: Dict[str, str],
                      emb_matrix, emb_keys, emb_model) -> Optional[Dict]:
    """
    NLI-check one cited sentence and, if needed, repair it according to ``args.strategy``.

    Returns the session record of a pruned/repaired sentence; its
    ``final_sentence`` is what the caller splices into the document. Returns
    None when the sentence is left untouched.
    """
    orig_keys = extract_keys_from_sentence(sent)
    claim = clean_claim_text(sent)
    print(f"[DEBUG] Original keys: {orig_keys}")
    print(f"[DEBUG] Claim: {claim[:150]}..." if len(claim) > 150 else f"[DEBUG] Claim: {claim}")

    # NLI on original keys - mapped to their indexed text (whole paper or first segment)
    tf = []
    for k in orig_keys:
        txt = _index_text_for_key(k, bibname2abs, bibname_mapping)
        try:
            ok = nli(claim, txt) if txt else False
            tf.append(ok)
            print(f"[NLI] Key {k}: {ok}")
        except Exception as e:
            print(f"[ERROR] NLI failed for key {k}: {e}")
            tf.append(False)

    t_count = tf.count(True)
    print(f"[INFO] NLI results: {t_count}/{len(orig_keys)} keys support the claim")

    # Already supported and not forcing: optionally prune, then skip
    if t_count > 0 and not args.force_all:
        print("[INFO] Claim is already supported, checking if need to prune...")
        if not (args.prune_unsupported_refs and len(orig_keys) > 1):
            print("[INFO] Skipping sentence (already supported)")
            return None
        kept_keys = [k for k, ok in zip(orig_keys, tf) if ok]
        if not (kept_keys and len(kept_keys) < len(orig_keys)):
            print("[INFO] No pruning needed")
            return None
        print(f"[PRUNE] Keeping {len(kept_keys)}/{len(orig_keys)} supported keys: {kept_keys}")
        return {
            "orig_sentence": sent,
            "orig_keys": orig_keys,
            "orig_nli": tf,
            "diagnosis": "(auto-prune unsupported refs)",
            "proposed_fix": {},
            "chosen_action": "prune_refs",
            "selected_keys": kept_keys,
            "final_sentence": replace_citation_keys(sent, kept_keys)
        }

    # Unsupported or forced: ask LLM for diagnosis & fix
    print("[INFO] Claim is not supported, asking LLM for diagnosis and fix...")
    src_snippets: List[str] = []
    for k in orig_keys:
        src_snippets.extend(_index_snippets_for_key(k, bibname2abs, bibname_mapping))

    try:
        diagnosis = ask_llm_for_error_reason(claim, src_snippets)
        print(f"[LLM] Diagnosis received: {diagnosis[:100]}..." if len(diagnosis) > 100 else f"[LLM] Diagnosis: {diagnosis}")
    except Exception as e:
        print(f"[ERROR] Failed to get diagnosis: {e}")
        diagnosis = f"Error getting diagnosis: {e}"

    try:
        fix_json = ask_llm_for_how_to_fix(claim, src_snippets)
        print("[LLM] Fix proposal received")
    except Exception as e:
        print(f"[ERROR] Failed to get fix: {e}")
        fix_json = {"rewrite": claim, "style_notes": [], "alt_phrasings": []}
    if not isinstance(fix_json, dict):
        # Guard against a JSON reply that is not an object (list, string, ...).
        fix_json = {"rewrite": claim, "style_notes": [], "alt_phrasings": []}

    rewrite = fix_json.get("rewrite") or claim
    if not isinstance(rewrite, str):
        rewrite = claim

    def pick(claim_text: str) -> List[str]:
        """Retrieve bib names supporting *claim_text*; [] (logged) if retrieval fails."""
        try:
            return choose_cite_keys_hybrid_enhanced(
                claim=claim_text, orig_keys=orig_keys, bibname2abs=bibname2abs, bibname_mapping=bibname_mapping,
                keep_supported=args.keep_supported,
                retrieval=args.retrieval, topk=args.topk,
                nli_window_size=args.nli_window_size, max_windows=args.max_windows,
                emb_matrix=emb_matrix, emb_keys=emb_keys, emb_model=emb_model,
                use_nli_rerank=args.use_nli_rerank
            )
        except Exception as e:
            print(f"[ERROR] Failed to choose citation keys: {e}")
            return []

    chosen_action = None
    final_sentence = sent
    selected_keys = orig_keys

    # Strategy handling with enhanced retrieval
    print(f"[STRATEGY] Using strategy: {args.strategy}")
    if args.strategy == "ref_first":
        print("[STRATEGY] Trying to find better citation keys first...")
        cand_keys = pick(claim)
        print(f"[STRATEGY] Found candidate keys: {cand_keys}")

        if cand_keys:
            final_sentence = _apply_replace(sent, None, cand_keys)
            # Verify at least one new key supports the original claim
            ok_new = False
            for k in cand_keys:
                txt = _index_text_for_key(k, bibname2abs, bibname_mapping)
                try:
                    if txt and nli(claim, txt):
                        ok_new = True
                        break
                except Exception:
                    continue

            if ok_new or args.ref_only_fallback_ok:
                selected_keys = cand_keys
                chosen_action = "update_refs"
            else:
                cand_keys2 = pick(rewrite)
                selected_keys = cand_keys2 if cand_keys2 else cand_keys
                final_sentence = _apply_replace(sent, rewrite, selected_keys)
                chosen_action = "update_refs_then_claim"
        else:
            cand_keys2 = pick(rewrite)
            selected_keys = cand_keys2 if cand_keys2 else orig_keys
            final_sentence = _apply_replace(sent, rewrite, selected_keys)
            chosen_action = "update_claim_only"

    elif args.strategy == "claim_first":
        cand_keys = pick(rewrite)
        selected_keys = cand_keys if cand_keys else orig_keys
        final_sentence = _apply_replace(sent, rewrite, selected_keys)
        chosen_action = "update_claim_then_refs" if cand_keys else "update_claim_only"
        if not cand_keys and args.ref_fallback_if_claim_fail:
            cand_keys2 = pick(claim)
            if cand_keys2:
                final_sentence = _apply_replace(sent, None, cand_keys2)
                selected_keys = cand_keys2
                chosen_action = "fallback_update_refs"

    elif args.strategy == "ref_only":
        cand_keys = pick(claim)
        if cand_keys:
            final_sentence = _apply_replace(sent, None, cand_keys)
            selected_keys = cand_keys
            chosen_action = "update_refs"
        else:
            chosen_action = "no_change_refs_not_found"

    elif args.strategy == "claim_only":
        final_sentence = _apply_replace(sent, rewrite, None)
        chosen_action = "update_claim_only"

    else:  # fallback default (unreachable through the CLI, whose choices are fixed)
        cand_keys = pick(claim)
        if cand_keys:
            final_sentence = _apply_replace(sent, None, cand_keys)
            selected_keys = cand_keys
            chosen_action = "update_refs"
        else:
            final_sentence = _apply_replace(sent, rewrite, orig_keys)
            chosen_action = "update_claim_only"

    print(f"[INFO] Applied action: {chosen_action}")
    return {
        "orig_sentence": sent,
        "orig_keys": orig_keys,
        "orig_nli": tf,
        "diagnosis": diagnosis,
        "proposed_fix": fix_json,
        "chosen_action": chosen_action,
        "selected_keys": selected_keys,
        "final_sentence": final_sentence
    }


def main(args):
    """
    Check every cited sentence of ``args.tex`` with NLI and repair unsupported ones.

    Writes the repaired LaTeX to ``args.outtex`` and a JSON log of every change
    to ``args.session_json``. Optionally saves or resumes progress and emits a
    unified diff.
    """
    tex_path = Path(args.tex)
    ref_dir = Path(args.refdir)
    out_tex = Path(args.outtex)
    session_json = Path(args.session_json)

    print(f"[INFO] Using index type: {args.index_type}")
    print(f"[INFO] Using retrieval method: {args.retrieval}")

    # Load refs with enhanced multi-level indexing
    bibname2abs, bibname_mapping = get_refs_enhanced(ref_dir, args.index_type)
    print(f"[INFO] Loaded {len(bibname2abs)} indexed items from {len(set(bibname_mapping.values()))} unique papers")

    original_content = load_file_as_string(tex_path)

    # Embedding index (only if possibly used)
    emb_matrix = None
    emb_keys: Optional[List[str]] = None
    emb_model = None
    will_use_embed = (args.retrieval == "embed") or (args.retrieval == "auto" and _HAVE_ST)
    if will_use_embed:
        try:
            cache_dir = ref_dir if args.use_cache else Path(".")
            emb_matrix, emb_keys, emb_model = build_or_load_embeddings(
                bibname2abs, args.embed_model, cache_dir, use_cache=args.use_cache, index_type=args.index_type
            )
            print(f"[INFO] Built embedding index with {emb_matrix.shape[0]} items")
        except Exception as e:
            print(f"[WARN] Embedding build failed ({e}); fallback to TF-IDF.")
            emb_matrix, emb_keys, emb_model = None, None, None

    # Process cited sentences from the end backwards so span replacement never
    # shifts the offsets of sentences still to be handled.
    sents = find_cited_sentences(original_content)
    sents_sorted = sorted(sents, key=lambda x: x[1][0], reverse=True)
    print(f"[INFO] Found {len(sents_sorted)} sentences with citations")

    buf = list(original_content)
    session_records: List[Dict] = []

    # Progress handling
    start_sentence_idx = 0
    progress_path = None
    if args.resume:
        progress_path = Path(args.resume)
        start_sentence_idx, session_records = load_progress(progress_path)
        # The buffer starts from the original text; re-apply the earlier run's edits,
        # otherwise they would be missing from the output.
        _replay_records(buf, session_records, sents_sorted[:start_sentence_idx])
        print(f"[INFO] Resuming from sentence {start_sentence_idx + 1}")
    elif args.save_progress:
        progress_path = ref_dir / f".progress_{tex_path.stem}.json"
        print(f"[INFO] Progress will be saved to: {progress_path}")

    for i, (sent, (st, ed)) in enumerate(sents_sorted):
        # Skip sentences that were already processed
        if i < start_sentence_idx:
            continue

        print(f"\n[PROGRESS] Processing sentence {i+1}/{len(sents_sorted)}")
        print(f"[DEBUG] Sentence: {sent[:100]}..." if len(sent) > 100 else f"[DEBUG] Sentence: {sent}")

        record = _process_sentence(sent, args, bibname2abs, bibname_mapping, emb_matrix, emb_keys, emb_model)
        if record is not None:
            # Remember the original span so a resumed run can replay this edit.
            record["span"] = [st, ed]
            buf[st:ed] = list(record["final_sentence"])
            session_records.append(record)

        # Checked for every sentence (skipped ones too) so saves happen every 10 sentences.
        if args.save_progress and progress_path and (i + 1) % 10 == 0:
            save_progress(session_records, progress_path, i + 1)
            print(f"[PROGRESS] Saved progress at sentence {i + 1}")

    fixed_text = "".join(buf)
    with open(out_tex, "w", encoding="utf-8") as f:
        f.write(fixed_text)

    with open(session_json, "w", encoding="utf-8") as f:
        json.dump({
            "file": str(tex_path),
            "strategy": args.strategy,
            "keep_supported": args.keep_supported,
            "retrieval": args.retrieval,
            "index_type": args.index_type,
            "topk": args.topk,
            "nli_window_size": args.nli_window_size,
            "max_windows": args.max_windows,
            "embed_model": args.embed_model,
            "use_cache": args.use_cache,
            "use_nli_rerank": args.use_nli_rerank,
            "prune_unsupported_refs": args.prune_unsupported_refs,
            "repairs": session_records
        }, f, ensure_ascii=False, indent=2)

    print(f"[OK] New tex written: {out_tex}")
    print(f"[OK] Session log JSON: {session_json}")
    print(f"[OK] Total modified/pruned sentences: {len(session_records)}")

    # Clean up progress file on successful completion
    if args.save_progress and progress_path and progress_path.exists():
        progress_path.unlink()
        print(f"[INFO] Removed progress file: {progress_path}")

    if args.emit_diff:
        diff_path = Path(args.diff_path) if args.diff_path else Path(str(out_tex) + ".patch")
        # The input lines keep their own newlines, so the default lineterm="\n" is
        # required; lineterm="" left the ---/+++/@@ header lines unterminated.
        diff_lines = difflib.unified_diff(
            original_content.splitlines(keepends=True),
            fixed_text.splitlines(keepends=True),
            fromfile=str(tex_path),
            tofile=str(out_tex),
        )
        with open(diff_path, "w", encoding="utf-8") as df:
            df.writelines(diff_lines)
        print(f"[OK] Diff patch written: {diff_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # I/O paths
    parser.add_argument("--tex", required=True, help="Input LaTeX main .tex")
    parser.add_argument("--refdir", required=True,
                        help="Directory of reference JSONs (bib_name, md_text)")
    parser.add_argument("--outtex", required=True, help="Output fixed .tex path")
    parser.add_argument("--session_json", required=True, help="Output session/fix log JSON")

    # Repair strategy
    parser.add_argument("--strategy",
                        choices=["ref_first", "claim_first", "ref_only", "claim_only"],
                        default="ref_first")
    parser.add_argument("--keep_supported", action="store_true",
                        help="Keep original NLI=True keys when possible")
    parser.add_argument("--prune_unsupported_refs", action="store_true",
                        help="If claim already supported, drop NLI=False refs")

    # Multi-level indexing option
    parser.add_argument("--index_type",
                        choices=["abstract", "full_text", "segments"],
                        default="full_text",
                        help="How to build index: abstract only, full text, or split into segments")

    # Retrieval choice and params
    parser.add_argument("--retrieval", choices=["auto", "embed", "tfidf"], default="auto",
                        help="Which retrieval method to use (auto: prefer embedding if available)")
    parser.add_argument("--topk", type=int, default=30,
                        help="Candidate count from retrieval")
    parser.add_argument("--nli_window_size", type=int, default=3,
                        help="How many candidates to NLI-check per window")
    parser.add_argument("--max_windows", type=int, default=5,
                        help="Max windows to check (total NLI <= window*max)")

    # Embedding settings
    parser.add_argument("--embed_model", type=str,
                        default="sentence-transformers/all-MiniLM-L6-v2",
                        help="Sentence-Transformers model name (only used if retrieval=embed/auto and available)")
    parser.add_argument("--use_cache", action="store_true",
                        help="Cache reference embeddings under refdir")
    parser.add_argument("--use_nli_rerank", action="store_true",
                        help="Rerank/prune TF-IDF candidates with NLI to keep only supported ones")

    # Strategy auxiliaries
    parser.add_argument("--force_all", action="store_true",
                        help="Even if supported, still run full fix flow")
    parser.add_argument("--ref_only_fallback_ok", action="store_true",
                        help="In ref_first, accept refs-only even if imperfect")
    parser.add_argument("--ref_fallback_if_claim_fail", action="store_true",
                        help="In claim_first, fallback to refs-only once")

    # Progress and resuming
    parser.add_argument("--save_progress", action="store_true",
                        help="Save progress periodically to allow resuming")
    parser.add_argument("--resume", type=str, default="",
                        help="Resume from progress file (path to progress JSON)")

    # Diff export
    parser.add_argument("--emit_diff", action="store_true", help="Export unified diff patch")
    parser.add_argument("--diff_path", type=str, default="",
                        help="Path for diff patch (default: outtex + .patch)")

    main(parser.parse_args())