#!/usr/bin/env python3
# augment_v5_semantic_lenaware.py (v5.4)
# Length-aware, semantic-aware augmentation tuned for small corpora:
# - Per-seed production caps (progress guaranteed)
# - Per-seed duplicate caps (controlled reuse, avoids starvation)
# - Hard acceptance guard per item (prevents infinite churn)
# - Adaptive relaxation & quota rescue retained
# - Robust numpy.random.choice usage
#
# Safe defaults keep semantic + length fidelity while completing n on tiny seeds.

import os, re, csv, argparse, random, math
from typing import List, Tuple, Optional, Dict, Any
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_similarity

# ----------------------------- IO -----------------------------
def load_csv_utf8(path: str) -> pd.DataFrame:
    """Read a CSV file, trying progressively more permissive decodings.

    The encodings ``utf-8``, ``utf-8-sig`` and ``latin-1`` are attempted in
    that order with pandas' default parser; if all of them fail to decode or
    tokenise the file, one last attempt is made with the Python engine.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed table.

    Raises:
        Whatever ``pd.read_csv`` raises for problems that another encoding
        cannot fix (for example ``FileNotFoundError``); these now surface on
        the first attempt instead of being retried and swallowed.
    """
    for enc in ("utf-8", "utf-8-sig", "latin-1"):
        try:
            return pd.read_csv(path, encoding=enc)
        except (UnicodeError, pd.errors.ParserError):
            # Only decoding/tokenising failures justify trying the next option.
            continue
    return pd.read_csv(path, engine="python")

def save_csv_utf8(df: pd.DataFrame, path: str) -> None:
    """Write *df* to *path* as a fully quoted UTF-8 (with BOM) CSV.

    The index is omitted, every field is quoted, a backslash is the escape
    character and rows end with a bare line feed.
    """
    write_options = {
        "index": False,
        "encoding": "utf-8-sig",
        "quoting": csv.QUOTE_ALL,
        "escapechar": "\\",
        "lineterminator": "\n",
    }
    df.to_csv(path, **write_options)

def detect_text_col(df: pd.DataFrame, preferred: str) -> str:
    """Pick the column that holds the free text.

    Resolution order: the caller's *preferred* name, then a fixed list of
    well-known names, then the first object-dtype column, and finally the
    first column of the frame.
    """
    well_known = ("text", "response_text", "resume_text", "content", "body")
    for name in (preferred, *well_known):
        if name in df.columns:
            return name
    for name in df.columns:
        if pd.api.types.is_object_dtype(df[name]):
            return name
    return df.columns[0]

def detect_id_col(kind: str, df: pd.DataFrame) -> Optional[str]:
    """Return the first known identifier column present in *df*, else ``None``.

    For ``kind == "therapy"`` the search starts with ``response_id``; for any
    other kind it starts with ``resume_id``. The generic names ``id``, ``uid``
    and ``text_id`` follow in that order.
    """
    leading = "response_id" if kind == "therapy" else "resume_id"
    search_order = (leading, "id", "uid", "text_id")
    return next((name for name in search_order if name in df.columns), None)

def uniq_key(s: str) -> str:
    """Return a lower-cased copy of *s* with whitespace runs collapsed to one space.

    Leading and trailing whitespace is removed, so strings that differ only in
    case or spacing share the same key.
    """
    tokens = s.lower().split()
    return " ".join(tokens)

# ----------------------------- Cleaning -----------------------------
# C1 code points that show up when cp1252 punctuation is decoded as latin-1.
_CP1252_ARTIFACTS = str.maketrans({
    "\x97": "-",
    "\x96": "-",
    "\x95": "-",
    "\x92": "'",
    "\x8a": " ",
    "\x85": " ",
})


def clean_text(s: str) -> str:
    """Normalise *s* to printable ASCII with tidy whitespace.

    Steps, in order:
      1. Convert CRLF and lone CR line endings to LF.
      2. Map a handful of mis-decoded cp1252 punctuation marks to ASCII.
      3. Replace every remaining character outside printable ASCII, newline
         and tab with a space.
      4. Collapse runs of spaces/tabs to one space and runs of three or more
         newlines to exactly two.
      5. Strip leading and trailing whitespace.

    Args:
        s: Any value; it is converted with ``str()`` first.

    Returns:
        The cleaned string.
    """
    # CRLF must be handled before lone CR, otherwise each Windows line break
    # becomes two newlines (i.e. a spurious paragraph break).
    t = str(s).replace("\r\n", "\n").replace("\r", "\n")
    t = t.translate(_CP1252_ARTIFACTS)
    t = re.sub(r"[^\x20-\x7E\n\t]", " ", t)
    t = re.sub(r"[ \t]+", " ", t)
    t = re.sub(r"\n{3,}", "\n\n", t)
    return t.strip()

# ----------------------------- Embedding & Clustering -----------------------------
def build_embedder_hash(
    corpus: pd.Series,
    n_features: int = 4096,
    n_components: int = 10,
    random_state: int = 7,
) -> Tuple[HashingVectorizer, TruncatedSVD, np.ndarray]:
    """Embed *corpus* with hashed word 1-2-grams reduced by truncated SVD.

    Args:
        corpus: Series of texts. Missing entries are embedded as the empty
            string rather than as the literal word "nan".
        n_features: Width of the hashing space.
        n_components: Number of SVD components requested.
        random_state: Seed passed to ``TruncatedSVD``.

    Returns:
        ``(vectorizer, fitted_svd, Z)`` where ``Z`` is the output of
        ``svd.fit_transform`` on the hashed matrix.
    """
    vec = HashingVectorizer(
        ngram_range=(1, 2),
        analyzer="word",
        alternate_sign=False,
        n_features=n_features,
        norm="l2",
    )
    # Blank out missing values *before* casting to str; the reverse order
    # turns NaN into the token "nan" and makes the later fill a no-op.
    texts = corpus.map(lambda v: "" if pd.isna(v) else str(v))
    X = vec.transform(texts)
    svd = TruncatedSVD(n_components=n_components, random_state=random_state)
    Z = svd.fit_transform(X)
    return vec, svd, Z

def fit_gmm_weights(Z: np.ndarray, k_max: int = 6, seed: int = 7) -> Tuple[GaussianMixture, np.ndarray, np.ndarray]:
    """Fit a full-covariance Gaussian mixture to the embedded corpus.

    The component count is roughly one per ten rows, kept between 2 and
    *k_max* (never below 2 unless there are fewer than two rows), and never
    more than the number of rows in *Z*.

    Args:
        Z: Embedding matrix, one row per text.
        k_max: Upper bound on the number of mixture components.
        seed: Random state for the mixture model.

    Returns:
        ``(gmm, resp, weights)``: the fitted model, the per-row component
        probabilities from ``predict_proba``, and their column means.
    """
    n_samples = Z.shape[0]
    k = max(2, min(k_max, max(2, n_samples // 10)))
    # A mixture cannot have more components than samples; without this clamp
    # a one-row corpus fails inside scikit-learn.
    k = max(1, min(k, n_samples))
    gmm = GaussianMixture(n_components=k, covariance_type="full", random_state=seed)
    gmm.fit(Z)
    resp = gmm.predict_proba(Z)
    weights = resp.mean(0)
    return gmm, resp, weights

# ----------------------------- Paraphrasing -----------------------------
# Regex pattern -> candidate replacements used by light_paraphrase.
# Replacements must stay printable ASCII: clean_text() replaces every other
# character with a space, which would mangle them in the final output.
# Replacements that append a captured suffix (\1) must use stems that take a
# regular "-d"/"-s" ending, otherwise "managed" becomes e.g. "overseed".
SYN = {
    r"\bhelped\b": ["assisted", "supported", "aided"],
    r"\bplan\b": ["roadmap", "plan of action", "workplan"],
    r"\bpractice\b": ["rehearse", "drill", "exercise"],
    r"\bresume\b": ["CV", "curriculum vitae"],
    r"\bimproved\b": ["enhanced", "increased", "boosted"],
    r"\breduced\b": ["lowered", "decreased", "dropped"],
    r"\bdesigned\b": ["built", "created", "architected"],
    r"\bbuilt\b": ["developed", "constructed", "assembled"],
    r"\bimplemented\b": ["deployed", "delivered", "executed"],
    r"\banalytics\b": ["analysis", "data analytics", "analytics work"],
    r"\bmodel\b": ["predictive model", "ML model", "statistical model"],
    r"\bETL\b": ["data pipelines", "ingestion pipelines", "data workflows"],
    r"\bled\b": ["headed", "led", "directed"],
    r"\bmanage(d|s)?\b": ["supervise\\1", "manage\\1", "coordinate\\1"],
    r"\boptimi[sz]e(d|s|ing)?\b": ["tune\\1", "refine\\1", "improve\\1"],
}

def light_paraphrase(text: str, swaps_min=2, swaps_max=6, allow_jitter=False) -> str:
    """Return *text* with a random subset of the ``SYN`` patterns substituted.

    A random number of patterns (between *swaps_min* and *swaps_max*,
    inclusive) is drawn from a shuffled copy of the ``SYN`` keys; each one is
    applied case-insensitively with a single randomly chosen replacement.
    When *allow_jitter* is true, a short lead-in phrase and/or a closing
    sentence may additionally be attached, each with probability 0.12.
    """
    patterns = list(SYN)
    random.shuffle(patterns)
    budget = random.randint(swaps_min, swaps_max)

    result = text
    for pattern in patterns[:budget]:
        replacement = random.choice(SYN[pattern])
        result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)

    if allow_jitter:
        # The random draws below happen in the same order as the checks:
        # prefix roll, prefix pick, suffix roll, suffix pick.
        if random.random() < 0.12:
            lead_in = random.choice(["In summary, ", "In practice, ", "For example, ", ""])
            result = lead_in + result
        if random.random() < 0.12:
            result = result + random.choice(["", " The outcome was positive."])
    return result

# ----------------------------- Length utilities -----------------------------
def wc(s: str) -> int:
    """Count the whitespace-separated tokens in *s*."""
    tokens = s.split()
    return len(tokens)

def cc(s: str) -> int:
    """Count the characters in *s*, whitespace included."""
    n_chars = len(s)
    return n_chars

def within_len_bounds(seed: str, cand: str, word_tol: float, char_tol: float) -> bool:
    """Check that *cand* is close to *seed* in both word and character count.

    Each relative deviation is measured against the seed's count (treated as
    at least 1 to avoid dividing by zero) and must not exceed its tolerance.
    """
    seed_words = max(1, wc(seed))
    seed_chars = max(1, cc(seed))
    word_drift = abs(wc(cand) - seed_words) / seed_words
    char_drift = abs(cc(cand) - seed_chars) / seed_chars
    words_ok = word_drift <= word_tol
    chars_ok = char_drift <= char_tol
    return words_ok and chars_ok

# ----------------------------- Quotas -----------------------------
def compute_class_component_quotas(df: pd.DataFrame, resp: np.ndarray, n_each: int) -> Dict[Tuple[Optional[str], int], int]:
    """Split *n_each* items across (class, mixture-component) buckets.

    The total is first divided between classes in proportion to their row
    counts, then each class share is divided between components in
    proportion to the mean responsibility of that class's rows. Rounding
    residue is pushed onto the largest classes/components.

    Rows whose class is missing, or every row when *df* has no ``class``
    column, form one bucket keyed by ``None``; previously such rows received
    no quota at all, which left unlabeled corpora with an empty result.
    Components that no row of a class is hard-assigned to (argmax of *resp*)
    receive no quota, because a bucket without seeds can never be filled.

    Args:
        df: Seed table; an optional ``class`` column provides the labels.
        resp: Array of per-row component probabilities, one row per row of
            *df* and one column per component.
        n_each: Total number of items to distribute.

    Returns:
        Mapping ``(class_value_or_None, component_index) -> target count``
        containing only buckets with a positive target.
    """
    n_rows = len(df)
    k = resp.shape[1]
    labels = resp.argmax(1)

    # One boolean row mask per class; missing labels share the None bucket.
    masks: Dict[Any, np.ndarray] = {}
    if "class" in df.columns:
        classes = df["class"].astype("category")
        for cls in classes.cat.categories.tolist():
            masks[cls] = (classes == cls).to_numpy()
        missing = classes.isna().to_numpy()
    else:
        missing = np.ones(n_rows, dtype=bool)
    if missing.any():
        masks[None] = missing

    counts = {cls: int(mask.sum()) for cls, mask in masks.items()}
    total = max(1, sum(counts.values()))
    class_targets = {cls: int(round(cnt / total * n_each)) for cls, cnt in counts.items()}

    # Hand the rounding residue to the most frequent classes, one item each.
    diff = n_each - sum(class_targets.values())
    if diff != 0:
        step = 1 if diff > 0 else -1
        by_size = sorted(class_targets, key=lambda cls: -counts[cls])
        for cls in by_size[:abs(diff)]:
            class_targets[cls] += step

    quotas: Dict[Tuple[Optional[str], int], int] = {}
    for cls, mask in masks.items():
        if counts[cls] == 0:
            continue
        target = class_targets.get(cls, 0)
        comp_props = resp[mask].mean(0)
        # Keep only components that actually own a seed of this class.
        occupied = np.bincount(labels[mask], minlength=k) > 0
        comp_props = np.where(occupied, comp_props, 0.0)
        comp_targets = np.round(comp_props / max(1e-12, comp_props.sum()) * target).astype(int)
        d = target - int(comp_targets.sum())
        if d != 0:
            bump = np.argsort(-comp_props)[:abs(d)]
            comp_targets[bump] += int(np.sign(d))
        for c in range(k):
            tgt = int(comp_targets[c])
            if tgt > 0:
                quotas[(cls, c)] = tgt
    return quotas

# ----------------------------- Core Augmentation -----------------------------
def augment_lenaware(
    sheet: pd.DataFrame,
    text_col: str,
    kind: str,
    id_col: Optional[str],
    n_each: int = 500,
    seed: int = 7,
    lo: float = 0.84,
    hi: float = 0.99,
    min_lo: float = 0.80,
    nearest_tol: float = 0.0,
    batch_size: int = 64,
    max_tries_per_item: int = 6,
    k_max: int = 6,
    swaps_min: int = 3,
    swaps_max: int = 7,
    len_word_tol: float = 0.15,
    len_char_tol: float = 0.15,
    allow_jitter: bool = False,
    # adaptivity
    adapt_after_attempts: int = 2000,
    relax_lo_step: float = 0.005,
    widen_len_step: float = 0.02,
    len_tol_cap: float = 0.30,
    rescue_after_failures: int = 8,
    # per-seed strategy
    max_per_seed: int = 5,
    per_seed_dupe_times: int = 2,
    hard_attempt_budget: int = 10000,
    # fallback
    fallback_after_rescues: int = 2,
    # termination safety
    max_seed_failures: int = 8,
) -> pd.DataFrame:
    """Generate up to *n_each* paraphrased rows from the seed rows of *sheet*.

    Seeds are embedded and clustered, grouped into (class, component)
    buckets, and sampled according to per-bucket quotas. For each picked seed
    batches of paraphrases are generated; a candidate is accepted when its
    cosine similarity to the seed lies in the ``[lo, hi]`` window, its length
    is within tolerance, and the per-seed duplicate cap allows it.

    Fixes relative to the previous revision:
      * quota consumption is tracked (``produced`` is incremented on
        acceptance), so bucket weights reflect what is still missing;
      * rows with a missing class form a ``None`` bucket, and empty or
        exhausted quotas switch to global sampling instead of stopping with
        fewer than *n_each* rows;
      * a seed that fails *max_seed_failures* items in a row is set aside
        (and re-admitted whenever thresholds are relaxed), which guarantees
        that the loop terminates;
      * missing texts are treated as empty strings, not the word "nan";
      * an empty *sheet* returns an empty frame instead of failing.

    Args:
        sheet: Seed table. A ``class`` column is added (all missing) if absent.
        text_col: Name of the text column.
        kind: ``"therapy"`` selects the ``response_id`` prefix for synthetic
            ids; anything else selects ``resume_id``.
        id_col: Existing id column, or ``None``.
        n_each: Number of augmented rows to produce.
        seed: Seed for ``random`` and ``numpy.random``.
        lo, hi: Similarity window; ``min_lo`` is the floor ``lo`` may relax to
            and ``nearest_tol`` is extra slack below the lower bound.
        batch_size: Paraphrases generated per try.
        max_tries_per_item: Tries per picked seed before the item fails.
        k_max: Upper bound on mixture components.
        swaps_min, swaps_max: Bounds forwarded to ``light_paraphrase``.
        len_word_tol, len_char_tol: Relative length tolerances.
        allow_jitter: Forwarded to ``light_paraphrase``.
        adapt_after_attempts, relax_lo_step, widen_len_step, len_tol_cap:
            Control the global relaxation of thresholds.
        rescue_after_failures: Bucket failures before its quota may be moved.
        max_per_seed: Cap on rows per seed (raised to the required average).
        per_seed_dupe_times: How often one seed may emit the same string.
        hard_attempt_budget: Candidate checks per item before the best
            candidate may be force-accepted.
        fallback_after_rescues: Rescues after which global sampling starts.
        max_seed_failures: Consecutive failed items after which a seed is set
            aside; values below 1 are treated as 1.

    Returns:
        A frame with the columns of the (cleaned) seed table plus ``source``
        and ``aug_of``. It may hold fewer than *n_each* rows when no seed can
        produce further acceptable candidates.
    """
    random.seed(seed)
    np.random.seed(seed)
    sheet = sheet.copy()
    if "class" not in sheet.columns:
        sheet["class"] = pd.NA

    # Blank out missing texts before casting; str(NaN) would yield "nan".
    corpus = sheet[text_col].map(lambda v: "" if pd.isna(v) else str(v)).map(clean_text)
    sheet[text_col] = corpus

    orig_cols = list(sheet.columns)
    extra_cols = [c for c in ["source", "aug_of"] if c not in orig_cols]

    def finalize(rows: List[pd.Series]) -> pd.DataFrame:
        """Assemble accepted rows into the output frame."""
        out_df = pd.DataFrame(rows)
        if out_df.empty:
            out_df = sheet.head(0).copy()
        if "source" not in out_df.columns:
            out_df["source"] = "augmented"
        if "aug_of" not in out_df.columns:
            out_df["aug_of"] = pd.NA
        out_df[text_col] = out_df[text_col].astype(str).fillna("").map(clean_text)
        # NOTE(review): when id_col is None the synthetic id column created
        # per row is not part of orig_cols and is therefore dropped here
        # (unchanged from the previous revision) -- confirm this is intended.
        return out_df[orig_cols + extra_cols]

    n_seeds = len(corpus)
    if n_seeds == 0:
        return finalize([])

    vec, _svd, Z = build_embedder_hash(corpus)
    k_max = min(k_max, max(2, len(corpus) // 15))
    gmm, resp, _weights = fit_gmm_weights(Z, k_max=k_max, seed=seed)
    labels = resp.argmax(1)
    k = gmm.n_components

    # Row masks per class; rows with a missing class share the None bucket.
    classes = sheet["class"].astype("category")
    cls_masks: List[Tuple[Any, np.ndarray]] = [
        (cls, (classes == cls).to_numpy()) for cls in classes.cat.categories.tolist()
    ]
    missing_cls = classes.isna().to_numpy()
    if missing_cls.any():
        cls_masks.append((None, missing_cls))

    # Seeds grouped by (class, component), plus the reverse lookup.
    seed_indices: Dict[Tuple[Optional[str], int], List[int]] = {}
    bucket_of_seed: Dict[int, Tuple[Optional[str], int]] = {}
    for c in range(k):
        comp_mask = labels == c
        for cls, cls_mask in cls_masks:
            members = np.where(cls_mask & comp_mask)[0]
            if members.size == 0:
                continue
            key = (cls, c)
            seed_indices[key] = [int(i) for i in members]
            for i in members:
                bucket_of_seed[int(i)] = key

    quotas = compute_class_component_quotas(sheet, resp, n_each)
    produced = {key: 0 for key in quotas}
    failures = {key: 0 for key in quotas}
    rescue_count = 0
    global_fallback_mode = False

    # per-seed caps & counters
    target_per_seed = int(math.ceil(n_each / max(1, n_seeds)))  # e.g., ceil(300/80)=4
    max_per_seed = max(max_per_seed, target_per_seed)           # ensure not below needed avg
    seed_failure_limit = max(1, max_seed_failures)
    produced_per_seed: Dict[int, int] = {i: 0 for i in range(n_seeds)}
    dupes_per_seed: Dict[int, Dict[str, int]] = {i: {} for i in range(n_seeds)}
    seed_failures: Dict[int, int] = {i: 0 for i in range(n_seeds)}

    out_rows: List[pd.Series] = []
    overall = tqdm(total=n_each, desc=f"Generating {n_each} items", leave=True)
    synth_col = "response_id" if kind == "therapy" else "resume_id"

    def synth_id_for(kind_local: str, base: str, idx_local: int) -> str:
        prefix = "response_id" if kind_local == "therapy" else "resume_id"
        return f"{prefix}::AUG::{base}::{idx_local}"

    def seed_available(seed_i: int) -> bool:
        """A seed is usable while under its row cap and not set aside."""
        return produced_per_seed[seed_i] < max_per_seed and seed_failures[seed_i] < seed_failure_limit

    def can_accept_seed_text(seed_i: int, s: str) -> bool:
        ksig = uniq_key(s)
        return dupes_per_seed[seed_i].get(ksig, 0) < per_seed_dupe_times

    def record_seed_text(seed_i: int, s: str) -> None:
        ksig = uniq_key(s)
        dupes_per_seed[seed_i][ksig] = dupes_per_seed[seed_i].get(ksig, 0) + 1
        produced_per_seed[seed_i] += 1

    def emit(seed_i: int, seed_row: pd.Series, cand: str) -> None:
        """Append an accepted candidate and update all bookkeeping."""
        record_seed_text(seed_i, cand)
        origin = seed_row.get(id_col, seed_i) if id_col is not None else seed_i
        new_row = seed_row.copy()
        new_row[text_col] = cand
        new_row["source"] = "augmented"
        new_row["aug_of"] = origin
        if id_col is not None:
            new_row[id_col] = synth_id_for(kind, str(origin), len(out_rows))
        else:
            new_row[synth_col] = synth_id_for(kind, str(seed_i), len(out_rows))
        out_rows.append(new_row)
        # Consume the bucket quota so remaining-quota weights stay meaningful.
        key = bucket_of_seed.get(seed_i)
        if key in produced:
            produced[key] += 1
        seed_failures[seed_i] = 0
        overall.update(1)

    def rescue_and_report() -> bool:
        """Move unmet quota from starved buckets to buckets still progressing."""
        nonlocal rescue_count
        starving = [b for b, tgt in quotas.items() if produced.get(b, 0) < tgt and failures.get(b, 0) >= rescue_after_failures]
        donors = [b for b, tgt in quotas.items() if produced.get(b, 0) < tgt and failures.get(b, 0) < rescue_after_failures and len(seed_indices.get(b, [])) > 0]
        if not starving or not donors:
            return False
        to_move = sum(quotas[s] - produced.get(s, 0) for s in starving)
        for s in starving:
            quotas[s] = produced.get(s, 0)
        if to_move <= 0:
            return False
        weights_donors = np.array([max(1, len(seed_indices[d])) for d in donors], dtype=float)
        if not np.isfinite(weights_donors.sum()) or weights_donors.sum() <= 0:
            return False
        weights_donors /= weights_donors.sum()
        add = np.random.multinomial(to_move, weights_donors)
        for d, inc in zip(donors, add):
            quotas[d] += int(inc)
        rescue_count += 1
        print(f"[RESCUE] Reallocated {to_move} items from starved buckets to donors. (rescue_count={rescue_count})")
        return True

    total_attempts = 0

    # Every iteration either accepts a row or adds a failure to a usable
    # seed, and failure counters are only reset on acceptance or on an actual
    # threshold relaxation (finitely many), so the loop always terminates.
    while len(out_rows) < n_each:
        if global_fallback_mode:
            # any seed that is still usable
            candidates = [i for i in range(n_seeds) if seed_available(i)]
            if not candidates:
                break
            seed_i = random.choice(candidates)
        else:
            # buckets with unmet quota that still have a usable seed
            bucket_choices = []
            bucket_weights = []
            for key, tgt in quotas.items():
                rem = tgt - produced.get(key, 0)
                if rem <= 0:
                    continue
                idxs = [i for i in seed_indices.get(key, []) if seed_available(i)]
                if idxs:
                    bucket_choices.append((key, idxs))
                    bucket_weights.append(rem * math.log(2 + len(idxs)))
            if not bucket_choices:
                # Covers empty quotas, exhausted quotas and infeasible buckets.
                if rescue_and_report() and any(tgt - produced.get(key, 0) > 0 for key, tgt in quotas.items()):
                    continue
                global_fallback_mode = True
                print("[FALLBACK] Switching to global sampling to complete remaining items.")
                continue
            weights_choice = np.asarray(bucket_weights, dtype=float)
            weight_total = weights_choice.sum()
            if weight_total > 0 and np.isfinite(weight_total):
                idx_pick = int(np.random.choice(len(bucket_choices), p=weights_choice / weight_total))
            else:
                idx_pick = int(np.random.choice(len(bucket_choices)))
            _, idxs = bucket_choices[idx_pick]
            seed_i = int(np.random.choice(np.array(idxs, dtype=int)))

        seed_row = sheet.iloc[seed_i]
        seed_text = seed_row[text_col]
        seed_vec_row = vec.transform([seed_text])

        accepted = False
        local_len_tol_w = len_word_tol
        local_len_tol_c = len_char_tol
        local_lo = lo
        attempts_here = 0
        best_cand = None
        best_sim = -1.0

        for t in range(max_tries_per_item):
            local_swaps_max = min(swaps_max + t // 2, swaps_max + 3)
            cands = [
                light_paraphrase(seed_text, swaps_min=swaps_min, swaps_max=local_swaps_max, allow_jitter=allow_jitter)
                for _ in range(batch_size)
            ]
            X = vec.transform(cands)
            sims = cosine_similarity(seed_vec_row, X).ravel()
            order = np.argsort(-sims)

            # per-try micro-relax
            if t >= 2:
                local_len_tol_w = min(len_tol_cap, local_len_tol_w + 0.02)
                local_len_tol_c = min(len_tol_cap, local_len_tol_c + 0.02)
            if t >= 3:
                local_lo = max(min_lo, local_lo - 0.005)

            for j in order:
                sim = float(sims[j])
                cand = cands[j]
                if sim > best_sim:
                    best_sim, best_cand = sim, cand
                if not (local_lo - nearest_tol <= sim <= hi):
                    continue
                if not within_len_bounds(seed_text, cand, local_len_tol_w, local_len_tol_c):
                    continue
                if not can_accept_seed_text(seed_i, cand):
                    continue
                emit(seed_i, seed_row, cand)
                accepted = True
                break

            attempts_here += batch_size
            total_attempts += batch_size
            if accepted:
                break

            # hard acceptance guard for this seed/item
            # NOTE(review): the guard checks neither `hi` nor the duplicate
            # cap (unchanged from the previous revision), so it can accept a
            # verbatim copy of the seed -- confirm this is intended.
            if attempts_here >= hard_attempt_budget and best_cand is not None:
                # enforce length; relax similarity a touch for this one-off
                len_ok = within_len_bounds(
                    seed_text,
                    best_cand,
                    min(local_len_tol_w, len_tol_cap),
                    min(local_len_tol_c, len_tol_cap),
                )
                if len_ok and best_sim >= (lo - 2 * nearest_tol):
                    emit(seed_i, seed_row, best_cand)
                    accepted = True
                    print(f"[GUARD] Accepted best-forced cand (sim={best_sim:.3f}) for seed {seed_i} after {attempts_here} attempts.")
                    break

        if not accepted:
            # Count the failed item against the seed (for retirement) and,
            # outside fallback mode, against its bucket (for quota rescue).
            seed_failures[seed_i] += 1
            if not global_fallback_mode:
                key_bucket = bucket_of_seed.get(seed_i)
                if key_bucket in failures:
                    failures[key_bucket] += 1
                    if rescue_after_failures > 0 and failures[key_bucket] % rescue_after_failures == 0:
                        rescue_and_report()

        # global adaptivity
        if total_attempts >= adapt_after_attempts and (len(out_rows) % 25 == 0):
            prev_lo, prev_wtol, prev_ctol = lo, len_word_tol, len_char_tol
            lo = max(min_lo, lo - relax_lo_step)
            len_word_tol = min(len_tol_cap, len_word_tol + widen_len_step)
            len_char_tol = min(len_tol_cap, len_char_tol + widen_len_step)
            if (lo, len_word_tol, len_char_tol) != (prev_lo, prev_wtol, prev_ctol):
                # Looser thresholds may unblock seeds that were set aside.
                for i in seed_failures:
                    seed_failures[i] = 0
            rescue_and_report()
            print(f"[ADAPT] lo {prev_lo:.3f}->{lo:.3f} | wtol {prev_wtol:.2f}->{len_word_tol:.2f} | ctol {prev_ctol:.2f}->{len_char_tol:.2f} | progress {len(out_rows)}/{n_each}")
            total_attempts = 0
            if rescue_count >= fallback_after_rescues and not global_fallback_mode:
                global_fallback_mode = True
                print("[FALLBACK] Enabling global sampling due to repeated infeasible buckets.")

    overall.close()

    if len(out_rows) < n_each:
        print(f"[WARN] Produced {len(out_rows)}/{n_each} items; no eligible seeds remain.")

    return finalize(out_rows)

# ----------------------------- High-level wrapper -----------------------------
def run_one(in_path, out_path, preferred_col, kind, n_each, seed,
            lo, hi, min_lo, nearest_tol, batch_size, max_tries_per_item,
            k_max, swaps_min, swaps_max, len_word_tol, len_char_tol, allow_jitter,
            max_per_seed, per_seed_dupe_times, hard_attempt_budget, fallback_after_rescues):
    """Augment one CSV end to end: load, detect columns, augment, save, report.

    The text and id columns are auto-detected, a ``class`` column is added
    (all missing) when the input has none, and every tuning argument is
    forwarded unchanged to ``augment_lenaware``.
    """
    print(f"\n==> Augmenting {kind}: {os.path.basename(in_path)} → {os.path.basename(out_path)}")
    df = load_csv_utf8(in_path)

    text_col = detect_text_col(df, preferred_col)
    id_col = detect_id_col(kind, df)

    if "class" not in df.columns:
        df["class"] = pd.NA

    print(f"Detected text column: '{text_col}' | id column: '{id_col or '(none)'}' | rows={len(df)}")

    # Tuning knobs collected once so the call below stays readable.
    tuning = {
        "n_each": n_each,
        "seed": seed,
        "lo": lo,
        "hi": hi,
        "min_lo": min_lo,
        "nearest_tol": nearest_tol,
        "batch_size": batch_size,
        "max_tries_per_item": max_tries_per_item,
        "k_max": k_max,
        "swaps_min": swaps_min,
        "swaps_max": swaps_max,
        "len_word_tol": len_word_tol,
        "len_char_tol": len_char_tol,
        "allow_jitter": allow_jitter,
        "max_per_seed": max_per_seed,
        "per_seed_dupe_times": per_seed_dupe_times,
        "hard_attempt_budget": hard_attempt_budget,
        "fallback_after_rescues": fallback_after_rescues,
    }
    out = augment_lenaware(df, text_col, kind, id_col, **tuning)

    save_csv_utf8(out, out_path)
    print(f"Saved {out_path}  (rows={len(out)})")
    if "class" in out.columns:
        print(out["class"].value_counts(dropna=False))

# ----------------------------- CLI -----------------------------
def main():
    """Parse the command line and augment both input sheets in turn.

    Sheet 1 is processed as the ``therapy`` corpus (preferred text column
    ``response_text``), sheet 2 as the ``resumes`` corpus (``resume_text``).
    Both runs share every tuning option. Jitter is on unless
    ``--disable_jitter`` is given.
    """
    ap = argparse.ArgumentParser(description="v5.4 Length-aware semantic augmentation (per-seed caps, guard, UTF-8 safe).")

    def add_typed(flag, value_type, default, help_text=None):
        ap.add_argument(flag, type=value_type, default=default, help=help_text)

    for required_flag in ("--sheet1", "--sheet2", "--out1", "--out2"):
        ap.add_argument(required_flag, required=True)
    add_typed("--n", int, 500)
    add_typed("--seed", int, 7)

    # Similarity window
    add_typed("--lo", float, 0.84)
    add_typed("--hi", float, 0.99)
    add_typed("--min_lo", float, 0.80)
    add_typed("--nearest_tol", float, 0.0)

    # Performance
    add_typed("--batch_size", int, 64)
    add_typed("--max_tries_per_item", int, 6)

    # GMM & paraphrasing
    add_typed("--k_max", int, 6)
    add_typed("--swaps_min", int, 3)
    add_typed("--swaps_max", int, 7)

    # Form fidelity
    add_typed("--len_word_tol", float, 0.15)
    add_typed("--len_char_tol", float, 0.15)
    ap.add_argument("--disable_jitter", action="store_true")

    # Per-seed strategy
    add_typed("--max_per_seed", int, 5, "Max augmented rows per original seed.")
    add_typed("--per_seed_dupe_times", int, 2, "Max identical augmented string repeats per seed.")
    add_typed("--hard_attempt_budget", int, 10000, "Candidate checks before guard forces best acceptance for the current item.")

    # Fallback
    add_typed("--fallback_after_rescues", int, 2, "Enable global fallback after this many quota rescues.")

    args = ap.parse_args()

    # Positional tail shared by both run_one calls, in run_one's parameter order.
    shared = (
        args.n, args.seed,
        args.lo, args.hi, args.min_lo, args.nearest_tol,
        args.batch_size, args.max_tries_per_item,
        args.k_max, args.swaps_min, args.swaps_max,
        args.len_word_tol, args.len_char_tol, not args.disable_jitter,
        args.max_per_seed, args.per_seed_dupe_times,
        args.hard_attempt_budget, args.fallback_after_rescues,
    )
    jobs = (
        (args.sheet1, args.out1, "response_text", "therapy"),
        (args.sheet2, args.out2, "resume_text", "resumes"),
    )
    for in_path, out_path, preferred_col, kind in jobs:
        run_one(in_path, out_path, preferred_col, kind, *shared)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
