#!/usr/bin/env python3
"""
adv_contrastive_generate.py — Adversarial / Contrastive Augmentation for Moderation
Produces two kinds of augmentations linked to originals via aug_of:
  - adv_hard_pos: borderline harmful (Label: FLAGGED)
  - adv_hard_neg: hard negatives near boundary (Label: NOT FLAGGED)

Output CSVs:
  therapy -> response_text,class,source,aug_of
  resumes -> resume_text,class,source,aug_of

Backends:
  --backend none        -> rule template only (deterministic, safe)
  --backend openai      -> OpenAI paraphraser/styler (needs OPENAI_API_KEY)
  --backend llama_cpp   -> local GGUF via llama_cpp (needs --llm-path)
"""
import argparse, os, re, random
from typing import List, Dict, Optional, Tuple
import pandas as pd

# Availability flags for the optional backends. Each one is flipped to True
# only when the corresponding import below succeeds; Styler.__init__ checks
# them before instantiating a backend.
_OPENAI_OK = False
_LLAMA_OK = False
try:
    from openai import OpenAI
    _OPENAI_OK = True
except Exception:
    # The openai package is optional: any import failure just leaves the
    # flag False.
    pass
try:
    from llama_cpp import Llama
    _LLAMA_OK = True
except Exception:
    # llama_cpp is optional as well; same handling as above.
    pass

# Seed the module-level RNG at import time so the random.sample() calls used
# to pick templates are repeatable from run to run.
random.seed(42)

# Seed sentences for "adv_hard_pos" rows. generate_for() samples from this
# list and emits the (optionally restyled) text with class FLAGGED.
TEMPLATES_HARD_POS = [
    # keep intent but avoid graphic content; coded/euphemistic phrasing
    "I feel like I’m at the end of my rope and can’t see a way forward.",
    "Sometimes I think about just disappearing so the pain stops.",
    "I’m so mad I could ruin his whole week if I wanted to.",
    "They make me furious; I want to scare them into leaving me alone.",
    "I could cut corners to get this done faster—nobody would ever notice.",
]
# Seed sentences for "adv_hard_neg" rows. generate_for() samples from this
# list and emits the (optionally restyled) text with class NOT FLAGGED.
TEMPLATES_HARD_NEG = [
    # frustrated but clearly non-harmful; explicitly safe framing
    "I’m overwhelmed and need support, not harm.",
    "I’m angry, but I will not do anything to hurt anyone.",
    "I’ll handle this the right way—no threats, no shortcuts.",
    "I disagree strongly, but I won’t cross any lines.",
    "This is stressful; I’m going to cool down and talk to someone.",
]

def read_csv_safely(path: str) -> pd.DataFrame:
    """Load a CSV file, trying a short list of text encodings in order.

    The encodings "utf-8-sig", "cp1252" and "latin-1" are attempted one
    after another; the first one that decodes without a UnicodeDecodeError
    wins. If every attempt fails to decode, the file is read once more with
    pandas' default encoding and any error from that call propagates.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed DataFrame (always read with the "python" parser engine).
    """
    candidate_codecs = ("utf-8-sig", "cp1252", "latin-1")
    for codec in candidate_codecs:
        try:
            frame = pd.read_csv(path, encoding=codec, engine="python")
        except UnicodeDecodeError:
            # Wrong guess for this file; move on to the next codec.
            continue
        else:
            return frame
    # Last resort: let pandas pick its default encoding.
    return pd.read_csv(path, engine="python")

def detect_cols(df: pd.DataFrame, kind: str) -> Tuple[str,str]:
    """Pick the id column and the text column for a dataset of the given kind.

    For kind "therapy" the preferred names are "response_id" and
    "response_text"; for any other kind they are "resume_id" and
    "resume_text". When the preferred id name is absent the first column is
    used. When the preferred text name is absent, a column called "text" is
    used if present, otherwise the second column.

    Args:
        df: Frame whose columns are inspected.
        kind: "therapy" selects the response_* names; anything else selects
            the resume_* names.

    Returns:
        A ``(id_col, text_col)`` tuple of column labels.
    """
    preferred = {"therapy": ("response_id", "response_text")}
    wanted_id, wanted_text = preferred.get(kind, ("resume_id", "resume_text"))
    available = df.columns

    # Resolve the id column first, then the text column (positional fallbacks
    # index into the column list).
    id_col = wanted_id if wanted_id in available else available[0]

    if wanted_text in available:
        text_col = wanted_text
    elif "text" in available:
        text_col = "text"
    else:
        text_col = available[1]
    return id_col, text_col

def norm_label(x: Optional[str]) -> Optional[str]:
    """Map a raw label value onto "FLAGGED" / "NOT FLAGGED".

    The value is stringified, stripped and lower-cased, then matched against
    the known positive and negative spellings. Values that match neither set
    are additionally tried as numbers, so that 1 / 0 labels which arrive as
    floats (e.g. ``1.0`` or ``"0.0"``, which is what pandas produces for an
    integer column containing blanks) are still recognised.

    Args:
        x: Raw label; may be None, a string, a bool or a number.

    Returns:
        "FLAGGED", "NOT FLAGGED", or None when the value is None or not a
        recognised label.
    """
    if x is None:
        return None
    s = str(x).strip().lower()
    if s in {"1","true","yes","y","flag","flagged","positive","risky","unsafe","danger"}:
        return "FLAGGED"
    if s in {"0","false","no","n","safe","ok","negative","not flagged","not_flagged"}:
        return "NOT FLAGGED"
    # Numeric fallback: accept any spelling of exactly one or exactly zero.
    try:
        number = float(s)
    except ValueError:
        return None
    if number == 1:
        return "FLAGGED"
    if number == 0:
        return "NOT FLAGGED"
    # Other numbers (including NaN, which compares unequal to everything)
    # are not labels.
    return None

class Styler:
    """Rewrites template seed sentences with the selected backend.

    Backends:
        "none"      -> no model; the seed is returned as-is (or with a fixed
                       safe suffix appended).
        "openai"    -> chat completion through an ``OpenAI()`` client.
        "llama_cpp" -> chat completion through a local ``Llama`` model.
    """

    def __init__(self, backend: str, llama_path: Optional[str], chat_format: str, model: str):
        """Set up the backend handle.

        Args:
            backend: "none", "openai" or "llama_cpp".
            llama_path: Path to the local model file; required (and must
                exist) for the "llama_cpp" backend.
            chat_format: Chat format name forwarded to ``Llama``.
            model: Model name sent with OpenAI requests.

        Raises:
            RuntimeError: If the requested backend's package is not
                installed, or the llama model path is missing/invalid.
        """
        self.backend = backend
        self.model = model
        # Define both handles up front so the attributes exist whichever
        # backend is selected.
        self.client = None
        self.llm = None
        if backend == "openai":
            if not _OPENAI_OK:
                raise RuntimeError("openai not installed")
            self.client = OpenAI()
        elif backend == "llama_cpp":
            if not _LLAMA_OK:
                raise RuntimeError("llama_cpp not installed")
            if not llama_path or not os.path.exists(llama_path):
                raise RuntimeError("--llm-path missing/invalid")
            self.llm = Llama(model_path=llama_path, n_ctx=2048, chat_format=chat_format, verbose=False)

    def style(self, seed: str, target: str, keep_harm: bool) -> str:
        """Return a restyled version of ``seed``.

        keep_harm=True  -> borderline harmful phrasing
        keep_harm=False -> clearly safe but semantically close

        Args:
            seed: Template sentence to rewrite.
            target: Text of the original row. Accepted for interface
                compatibility; it is not currently used in the prompt.
            keep_harm: Selects which rewriting goal is put in the prompt.

        Returns:
            The rewritten sentence. The unmodified seed is returned when the
            backend call fails or produces an empty reply, so callers always
            get non-empty text for a non-empty seed.
        """
        if self.backend == "none":
            if keep_harm:
                return seed
            return seed + " (I will not harm anyone.)"

        goal = (
            "borderline harmful (still risky) without graphic detail"
            if keep_harm
            else "clearly non-harmful, safe and compliant"
        )
        prompt = (
            "Rewrite the sentence to be concise and natural.\n"
            f"Goal: {goal}.\n"
            "Do not add unrelated content. Keep semantics close to the seed.\n"
            f"SEED: {seed}\n"
            "Rewrite:\n"
        )
        try:
            if self.backend == "openai":
                resp = self.client.chat.completions.create(
                    model=self.model, messages=[{"role":"user","content":prompt}],
                    max_tokens=64, temperature=0.4
                )
                raw = resp.choices[0].message.content
            else:
                out = self.llm.create_chat_completion(
                    messages=[{"role":"user","content":prompt}],
                    max_tokens=96, temperature=0.4
                )
                raw = out["choices"][0]["message"]["content"]
            rewritten = (raw or "").strip()
        except Exception as exc:
            # Best-effort by design: fall back to the seed, but leave a trace
            # so a misconfigured backend does not silently yield raw templates.
            import logging
            logging.getLogger(__name__).warning(
                "%s styling failed (%s); using unmodified seed", self.backend, exc
            )
            return seed
        # An empty completion would otherwise become an empty training row.
        return rewritten or seed

def generate_for(df: pd.DataFrame, kind: str, backend: str, llama_path: Optional[str], chat_format: str, model: str, per_flagged: int = 2) -> pd.DataFrame:
    """Build adversarial augmentation rows for every FLAGGED row of ``df``.

    For each flagged original, up to ``per_flagged`` hard-positive templates
    (emitted with class FLAGGED, source "adv_hard_pos") and up to
    ``per_flagged`` hard-negative templates (class NOT FLAGGED, source
    "adv_hard_neg") are sampled, passed through the styler, and linked back
    to the original through ``aug_of``.

    Args:
        df: Input data; must contain a "class" or "label" column ("class"
            wins when both exist).
        kind: "therapy" writes the text to "response_text"; any other value
            writes it to "resume_text".
        backend: Styler backend name.
        llama_path: Model path forwarded to the styler.
        chat_format: Chat format forwarded to the styler.
        model: Model name forwarded to the styler.
        per_flagged: Templates of each type per flagged row; capped at the
            number of available templates, and values below zero are
            treated as zero.

    Returns:
        De-duplicated DataFrame with columns
        ``[<text column>, "class", "source", "aug_of"]``. The columns are
        present even when no rows are generated, so the CSV written from it
        still carries a header.

    Raises:
        SystemExit: If ``df`` has neither a "class" nor a "label" column.
    """
    id_col, text_col = detect_cols(df, kind)
    if "class" not in df.columns and "label" not in df.columns:
        raise SystemExit(f"{kind}: need class or label column")
    label_col = "class" if "class" in df.columns else "label"

    out_text_col = "response_text" if kind == "therapy" else "resume_text"
    out_columns = [out_text_col, "class", "source", "aug_of"]

    # Work on an explicit copy so adding the normalised label column never
    # writes through to (or warns about) the caller's frame.
    work = df[[id_col, text_col, label_col]].dropna().copy()
    work["label"] = work[label_col].map(norm_label)
    flagged = work[work["label"] == "FLAGGED"].sample(frac=1.0, random_state=42)

    styler = Styler(backend=backend, llama_path=llama_path, chat_format=chat_format, model=model)

    # random.sample rejects a negative k, so clamp at both ends once.
    n_pos = max(0, min(per_flagged, len(TEMPLATES_HARD_POS)))
    n_neg = max(0, min(per_flagged, len(TEMPLATES_HARD_NEG)))

    rows = []
    for raw_id, original_text in zip(flagged[id_col], flagged[text_col]):
        pid = str(raw_id)
        # hard positives (borderline harmful)
        for seed in random.sample(TEMPLATES_HARD_POS, k=n_pos):
            rows.append({
                out_text_col: styler.style(seed, target=original_text, keep_harm=True),
                "class": "FLAGGED",
                "source": "adv_hard_pos",
                "aug_of": pid,
            })
        # hard negatives (safe but close)
        for seed in random.sample(TEMPLATES_HARD_NEG, k=n_neg):
            rows.append({
                out_text_col: styler.style(seed, target=original_text, keep_harm=False),
                "class": "NOT FLAGGED",
                "source": "adv_hard_neg",
                "aug_of": pid,
            })

    # Passing the column list keeps the schema stable when ``rows`` is empty.
    return pd.DataFrame(rows, columns=out_columns).drop_duplicates()

def main():
    """Command-line entry point.

    Parses the arguments, loads the therapy and resumes CSVs, generates the
    adversarial rows for each, writes both output CSVs (UTF-8, no index) and
    prints a one-line summary per file.
    """
    parser = argparse.ArgumentParser()
    # The four path arguments are all mandatory and otherwise identical.
    for path_flag in ("--therapy", "--resumes", "--out-aug-therapy", "--out-aug-resumes"):
        parser.add_argument(path_flag, required=True)
    parser.add_argument("--backend", choices=["none","openai","llama_cpp"], default="none")
    parser.add_argument("--llm-path", default=None)
    parser.add_argument("--chat-format", default="llama-3")
    parser.add_argument("--openai-model", default="gpt-4o-mini")
    parser.add_argument("--per-flagged", type=int, default=2)
    opts = parser.parse_args()

    therapy_df = read_csv_safely(opts.therapy)
    resumes_df = read_csv_safely(opts.resumes)

    # Backend settings are shared by both datasets.
    backend_args = (opts.backend, opts.llm_path, opts.chat_format, opts.openai_model)
    therapy_aug = generate_for(therapy_df, "therapy", *backend_args, per_flagged=opts.per_flagged)
    resumes_aug = generate_for(resumes_df, "resumes", *backend_args, per_flagged=opts.per_flagged)

    therapy_aug.to_csv(opts.out_aug_therapy, index=False, encoding="utf-8")
    resumes_aug.to_csv(opts.out_aug_resumes, index=False, encoding="utf-8")

    print(f"Saved {len(therapy_aug)} therapy adversarial rows -> {opts.out_aug_therapy}")
    print(f"Saved {len(resumes_aug)} resumes adversarial rows -> {opts.out_aug_resumes}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
