# ==========================================================
# Clean pipeline: DP-aware keyword extraction + masked fill-in
# ==========================================================
from __future__ import annotations

import os
import json
import string
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, Iterable
from collections import Counter

import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# --- Project imports (assumed available in your env) ---
from temp import load_jsonl_with_json
from src.gen_util import *
from src.utils import *
from src.fill_in import *
from src.pearl import *
from src.jointEM import joint
from src.confidenceGap import *
from account import find_best_epsilon
from tqdm import tqdm 
from src.fill_in import * 

# ==========================================================
# Utilities
# ==========================================================
def set_seed(seed: int = 42) -> None:
    """Seed Python's ``random`` module and NumPy's legacy global RNG.

    Args:
        seed: Value passed to ``random.seed`` and then ``np.random.seed``.
    """
    import random as _py_random

    # Same order as before: stdlib RNG first, then NumPy's global RNG.
    for seeder in (_py_random.seed, np.random.seed):
        seeder(seed)

def ensure_nltk():
    """Download the minimal NLTK data this pipeline needs, if missing.

    ``word_tokenize`` needs the Punkt sentence tokenizer. Recent NLTK
    releases load it from the ``punkt_tab`` resource instead of the pickled
    ``punkt`` one, so both are checked. ``stopwords`` backs the stopword
    filter in ``normalize_text_tokens``. Each resource is looked up with
    ``nltk.data.find``. Only the ones that raise ``LookupError`` are
    downloaded.
    """
    required = (
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),
        ("corpora/stopwords", "stopwords"),
    )
    for resource_path, package in required:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)

def load_index_and_mapping(index_path: Path, mapping_path: Path) -> tuple[dict, dict]:
    """Parse two UTF-8 JSON files and return their contents as ``(index, mapping)``.

    Args:
        index_path: Path to the JSON index file (read first).
        mapping_path: Path to the JSON mapping file (read second).

    Returns:
        A 2-tuple of the decoded JSON values, in argument order.
    """
    def _read_json(json_path: Path):
        # Decode one file as UTF-8 JSON; the handle is closed on exit.
        with json_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)

    return _read_json(index_path), _read_json(mapping_path)

def normalize_text_tokens(texts: Iterable[str]) -> List[str]:
    """Tokenize texts, lowercase them, and drop stopwords and punctuation.

    Args:
        texts: Strings to tokenize with NLTK ``word_tokenize``.

    Returns:
        Flat list of lowercase tokens from all texts, in input order,
        excluding English NLTK stopwords and tokens made up only of
        punctuation characters.
    """
    stop_set = set(stopwords.words("english"))
    punct_chars = frozenset(string.punctuation)
    tokens: List[str] = []
    for text in texts:
        for raw in word_tokenize(text):
            tok = raw.lower()
            if tok in stop_set:
                continue
            # word_tokenize emits multi-character punctuation tokens such as
            # "``", "''", "..." and "--". A substring test against
            # string.punctuation misses them, so drop any token whose
            # characters are all punctuation (including the empty string).
            if all(ch in punct_chars for ch in tok):
                continue
            tokens.append(tok)
    return tokens

def filtering(result, eps, method='top-k', k=2, thr_b=-0.2, thr_u=0.5):
    """Mask the DP-selected highest- and lowest-scoring sentences of a generation.

    The generation is split into sentences with per-sentence ``avg_cg``
    scores by ``sentence_level_cg``. Indices are then chosen on both ends
    with a differentially private mechanism, and every chosen sentence is
    replaced by "[MASK]".

    Args:
        result: Generation record. Reads ``result['text']`` and
            ``result['CG']``, both passed through to ``sentence_level_cg``.
        eps: Epsilon passed to each DP selection call.
        method: ``'top-k'`` (one-shot Gumbel top/bottom-k) or ``'svt'``
            (sparse vector technique above ``thr_u`` / below ``thr_b``).
        k: Number of indices to select per side (``c`` for SVT).
        thr_b: Lower threshold for the SVT bottom selection.
        thr_u: Upper threshold for the SVT top selection.

    Returns:
        ``(sentences, top_k_idx, bottom_k_idx)``. ``sentences`` has "[MASK]"
        at every selected index. An index selected by both sides is kept
        only in ``bottom_k_idx``. Both index lists hold plain ``int``s.

    Raises:
        NotImplementedError: If ``method`` is not ``'top-k'`` or ``'svt'``.
    """
    text, cg_list = result['text'], result['CG']
    rows = sentence_level_cg(text, cg_list, min_length=8)
    sentences = [row['sentence'] for row in rows]
    cg_values = [row['avg_cg'] for row in rows]

    if method == 'top-k':
        # NOTE(review): both draws use seed=10, so they share noise; confirm
        # this is intended for the privacy accounting.
        top_k_idx, _ = dp_topk_gumbel_one_shot(
            cg_values, k=k, epsilon=eps, sensitivity=0.08, seed=10,
            mode='top', return_noisy=True,
        )
        bottom_k_idx, _ = dp_topk_gumbel_one_shot(
            cg_values, k=k, epsilon=eps, sensitivity=0.08, seed=10,
            mode='bottom', return_noisy=True,
        )
    elif method == 'svt':
        # The upper threshold selects the high side; the lower threshold
        # selects the low side.
        top_k_idx = svt_sparse(cg_values, threshold=thr_u, sensitivity=0.08, epsilon=eps, c=k)
        bottom_k_idx = svt_below_threshold(cg_values, threshold=thr_b, sensitivity=0.08, epsilon=eps, c=k)
    else:
        raise NotImplementedError(f"Unknown filtering method: {method!r}")

    # Plain Python ints, so callers can use `if idx_list:` and set membership
    # even if a selector returns a NumPy array.
    top_k_idx = [int(i) for i in top_k_idx]
    bottom_k_idx = [int(i) for i in bottom_k_idx]

    overlap = set(top_k_idx) & set(bottom_k_idx)
    if overlap:
        # A sentence picked by both sides is treated as a bottom pick only.
        top_k_idx = [i for i in top_k_idx if i not in overlap]
        bottom_k_idx = sorted(set(bottom_k_idx))

    masked_idx = set(top_k_idx) | set(bottom_k_idx)
    sentences = ["[MASK]" if i in masked_idx else s for i, s in enumerate(sentences)]

    return sentences, top_k_idx, bottom_k_idx


def documents_to_keywords(documents: List[str], k: int, eps: float) -> List[str]:
    """Select up to ``k`` frequent unigrams from ``documents`` with the DP ``joint`` mechanism.

    Steps:
    - Tokenize/normalize the documents with ``normalize_text_tokens``
      (lowercase, stopwords and punctuation removed).
    - Count unigrams and order them by descending count, then
      lexicographically, so the candidate order is deterministic.
    - Call ``joint(freqs, k, epsilon=eps, neighbor_type=1)`` and map the
      returned indices back to words, in the order ``joint`` returns them.

    Args:
        documents: Texts to mine for keywords.
        k: Maximum number of keywords; capped at the vocabulary size.
        eps: Epsilon passed to ``joint``.

    Returns:
        Selected keywords. Empty if there are no documents, no surviving
        tokens, or ``k <= 0``.
    """
    # Nothing to select: skip tokenization and the DP call entirely.
    if not documents or k <= 0:
        return []

    tokens = normalize_text_tokens(documents)
    if not tokens:
        return []

    counts = Counter(tokens)
    # Frequency descending, then lexicographic, for a deterministic order.
    items_sorted = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    freqs = np.array([c for _, c in items_sorted], dtype=np.int64)

    k = min(k, len(items_sorted))
    # NOTE(review): neighbor_type=1 is kept from the original; confirm it
    # matches how documents map to privacy units for these raw counts.
    sel_idx = joint(freqs, k, epsilon=eps, neighbor_type=1)
    return [items_sorted[int(i)][0] for i in sel_idx]

def lines_to_slots(lines: List[str], slot_indices: List[int]) -> List[str]:
    """Align parsed output lines with mask slots.

    Lines that are blank after stripping are discarded and the rest are
    stripped. The survivors are truncated to ``len(slot_indices)``, or
    right-padded with empty strings up to that length.

    Returns:
        A new list whose length equals ``len(slot_indices)``.
    """
    target = len(slot_indices)
    kept = [text for text in (ln.strip() for ln in lines) if text][:target]
    # Pad with "" when fewer non-blank lines than slots (no-op otherwise).
    return kept + [""] * (target - len(kept))

def apply_lines(masked: List[str], slot_indices: List[int], lines: List[str]) -> None:
    """Overwrite ``masked`` in place so that ``masked[slot_indices[j]] = lines[j]``.

    ``slot_indices`` must have at least ``len(lines)`` entries; otherwise an
    ``IndexError`` is raised partway through. Returns None.
    """
    for slot_pos, new_text in enumerate(lines):
        target_idx = slot_indices[slot_pos]
        masked[target_idx] = new_text

def fill_masks(
    client,
    masked: List[str],
    query,
    mask_indices: List[int],
    *,
    keywords: Optional[List[str]] = None,
    group: str = "big",
    tone: str = "formal",
    max_tokens: int = 1024,
    temperature: float = .8,
) -> str:
    """Ask the LLM to fill masked sentences and return its raw completion text.

    Builds fill-in prompts with ``build_multi_fillin_prompts_with_keywords``
    (optionally conditioned on ``keywords``), sends the first user prompt to
    ``client.complete``, and returns the generated text unchanged. The result
    is NOT mapped back into ``masked``. ``lines_to_slots`` and ``apply_lines``
    exist for that if per-slot output is needed.

    Args:
        client: Object exposing ``complete(prompt, max_tokens, temperature,
            silent=...)`` returning ``(text, extra)``.
        masked: Sentences, with "[MASK]" placeholders, sent to the prompt builder.
        query: Query passed to the prompt builder.
        mask_indices: Indices in ``masked`` the model should fill.
        keywords: Optional keyword hints, passed as ``keywords_by_freq``.
        group: Prompt grouping mode for the builder.
        tone: Style hint passed as ``{"tone": tone}``.
        max_tokens: Generation length limit.
        temperature: Sampling temperature.

    Returns:
        The completion text from the client.
    """
    prompts = build_multi_fillin_prompts_with_keywords(
        sentences=masked,
        mask_indices=mask_indices,
        query=query,
        keywords_by_freq=keywords,
        group=group,
        style_hints={"tone": tone},
    )

    # NOTE(review): only the first user prompt is sent; presumably
    # group="big" yields a single combined prompt. Confirm for other groups.
    prompt = prompts["user_prompts"][0]
    output_text, _ = client.complete(prompt, max_tokens, temperature, silent=True)
    return output_text

# ==========================================================
# Main
# ==========================================================
def _config_number(name: str, cast=float, *, required: bool = True):
    """Read a numeric config value from environment variable ``name``.

    Returns ``cast(value)``, or None when the variable is unset/blank and
    ``required`` is False.

    Raises:
        RuntimeError: If the variable is unset/blank and ``required`` is True.
        ValueError: If the value cannot be converted with ``cast``.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        if required:
            raise RuntimeError(
                f"Config value {name} is not set; export {name}=<number> before running."
            )
        return None
    return cast(raw)


def main():
    """Fill masked sentences in stored generations and write results as JSONL.

    For each of the first ``nc`` records in ``INPUT_RESULTS``:
    - mask DP-selected sentences with ``filtering``;
    - extract DP keywords from the first four entries of the matching job;
    - remove the sentences at the ``top_k`` indices and re-index ``bottom_k``;
    - if any ``bottom_k`` slots remain, ask the LLM to fill them (the
      completion text becomes the output); otherwise join the remaining
      sentences with spaces;
    - write one JSON row (output, documents, query, gt) to ``OUTPUT_PATH``.

    Privacy settings are read from environment variables:
    EPS_TOTAL, KEYWORDS_EPS, FILTERING_EPS (required) and EPS_STEP,
    EPS_NEW_INIT, NITER_CONSUMED (optional; only stored in ``privacy_params``).
    """
    # Config
    # ==========================================================
    SEED = 42
    EPS_TOTAL = _config_number("EPS_TOTAL")          # overall privacy target; also names the output file
    EPS_STEP = _config_number("EPS_STEP", required=False)
    EPS_NEW_INIT = _config_number("EPS_NEW_INIT", required=False)
    NITER_CONSUMED = _config_number("NITER_CONSUMED", int, required=False)
    DELTA = 1 / 1024
    PROB = 10 / 1024

    KEYWORDS_EPS = _config_number("KEYWORDS_EPS")    # epsilon for DP top-k keyword selection
    FILTERING_EPS = _config_number("FILTERING_EPS")  # epsilon for DP sentence selection in `filtering`

    #INDEX_PATH = Path("corpus/index_rank.json")
    #INDEX_MAPPING_PATH = Path("corpus/index_mapping.json")
    #BASE_PATH = Path("corpus")

    # NOTE(review): the input/output names say "medical" while these corpus
    # paths are financial. Confirm they belong to the same run.
    INDEX_PATH = Path("corpus_financial_chunk/index_rank_new2.json")
    INDEX_MAPPING_PATH = Path("corpus_financial_chunk/index_mapping.json")
    BASE_PATH = Path("corpus_financial_chunk")

    k = 6
    # NOTE(review): this is a plain string, so "{eps=6}" is literally part of
    # the file name. Confirm that is intended.
    INPUT_RESULTS = Path("qwen_medical_result_{eps=6}.jsonl")
    OUTPUT_PATH = Path(f"qwen_medical_result_eps={EPS_TOTAL}_filled_k={k}.jsonl")
    print(OUTPUT_PATH)
    TOP_K_KEYWORDS = 40

    CLIENT_MODEL = "gpt-4o"
    GEN_MAX_TOKENS = 1024
    GEN_TEMP = 0.8
    TONE = "formal"
    nc = 250

    set_seed(SEED)
    ensure_nltk()

    # Privacy params (kept for reference; not actively re-optimized here)
    privacy_params = {
        "eps": EPS_STEP,
        "sigma": 1,
        "prob": PROB,
        "niter": NITER_CONSUMED,
        "eps_new": EPS_NEW_INIT,
        "monotone": False,
    }

    # Load index/mapping (not used further below) and build jobs.
    index, mapping = load_index_and_mapping(INDEX_PATH, INDEX_MAPPING_PATH)
    qs, jobs, gts = create_jobs(str(INDEX_PATH), str(INDEX_MAPPING_PATH), base_path=str(BASE_PATH))

    # Load LLM outputs to post-process
    outputs = load_jsonl_with_json(str(INPUT_RESULTS))

    client = OpenAIClient(CLIENT_MODEL)

    # Write results incrementally
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with OUTPUT_PATH.open("w", encoding="utf-8") as f:
        for i, output in enumerate(tqdm(outputs[:nc])):
            # Guard every per-record lookup the same way.
            documents = jobs[i][:4] if i < len(jobs) else []
            query = qs[i] if i < len(qs) else None
            gt = gts[i] if i < len(gts) else None

            masked, top_k, bottom_k = filtering(output, k=k, eps=FILTERING_EPS, method='top-k')
            top_set = {int(t) for t in top_k}
            bottom_k = [int(b) for b in bottom_k]

            # Keywords from documents (DP top-k)
            keywords = documents_to_keywords(documents, k=TOP_K_KEYWORDS, eps=KEYWORDS_EPS)

            if top_set:
                # Drop the sentences at the top_k indices entirely.
                masked = [m for idx, m in enumerate(masked) if idx not in top_set]
                # Re-index bottom_k: each surviving index shifts left by the
                # number of distinct removed positions before it.
                bottom_k = [
                    idx - sum(1 for t in top_set if t < idx)
                    for idx in bottom_k
                    if idx not in top_set
                ]

            if bottom_k:
                # Fill the remaining slots with keyword hints (group="big").
                final_output = fill_masks(
                    client,
                    masked,
                    query,
                    bottom_k,
                    keywords=keywords,
                    group="big",
                    tone=TONE,
                    max_tokens=GEN_MAX_TOKENS,
                    temperature=GEN_TEMP,
                )
            else:
                final_output = " ".join(masked)

            row = {
                "output": final_output,
                "documents": documents,
                "query": query,
                "gt": gt,
            }
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
            f.flush()  # flush per record so partial runs leave usable output

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
