#!/usr/bin/env python3
"""
prepare_data_augmented.py — split originals first, then attach augmented rows to their parents (no leakage)

Low-impact retention tweak:
- New flag: --drop-aug-equal-to-parent
  If set, we ONLY remove augmented rows whose cleaned text is exactly equal to the cleaned text
  of *their own parent original* (same parent_id). We do NOT deduplicate among augmented rows.
  Originals are never removed. This preserves nearly all augmented data while avoiding trivial copies.

Usage example:
  python prepare_data_augmented.py \
    --therapy ../data/Sheet_1.csv \
    --resumes ../data/Sheet_2.csv \
    --aug-therapy ../data/10KSheet_1_augmented_v5_lenaware.csv \
    --aug-resumes ../data/10KSheet_2_augmented_v5_lenaware.csv \
    --drop-aug-equal-to-parent \
    --test-size 0.10 --seed 42
"""

import argparse
import json
import re
from typing import Optional, Tuple, List, Set

import pandas as pd
from sklearn.model_selection import train_test_split


# ---------------------------- cleaning & labeling ----------------------------
def clean_text(s: Optional[str]) -> str:
    """Normalize raw text to printable ASCII with tidy whitespace.

    Steps, in order:
      1. Missing values (None, float NaN) and falsy inputs yield "".
      2. Line endings are unified: "\\r\\n" and lone "\\r" both become "\\n".
      3. A few C1 code points are mapped to ASCII punctuation or a space
         (presumably Windows-1252 punctuation that arrived through an
         ISO-8859-1 decode -- confirm against the data).
      4. Any remaining character outside printable ASCII, newline and tab
         is replaced by a space.
      5. Runs of spaces/tabs collapse to one space; runs of three or more
         newlines collapse to exactly two.
      6. Leading and trailing whitespace is stripped.

    Args:
        s: Raw text; may be None, NaN or a non-string object.

    Returns:
        The cleaned text, possibly empty.
    """
    # NaN is truthy, so it must be rejected explicitly or it is cleaned to "nan".
    if s is None or (isinstance(s, float) and s != s):
        return ""
    if not s:
        return ""
    # Collapse CRLF first: replacing "\r" alone would turn every Windows line
    # ending into a blank line.
    s = str(s).replace("\r\n", "\n").replace("\r", "\n")
    s = s.replace("\x97", "-").replace("\x96", "-").replace("\x95", "-").replace("\x92", "'")
    s = s.replace("\x8a", " ").replace("\x85", " ")
    s = re.sub(r"[^\x20-\x7E\n\t]", " ", s)
    s = re.sub(r"[ \t]+", " ", s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s.strip()


def norm_label(lbl: Optional[str]) -> Optional[str]:
    """Map a raw class value onto "FLAGGED" / "NOT FLAGGED".

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        lbl: Raw label; may be a string, number, bool, None or NaN.

    Returns:
        "FLAGGED", "NOT FLAGGED", or None when the value is missing or not
        one of the recognized spellings.
    """
    if lbl is None:
        return None
    # An integer 0/1 column that contains blanks is read as float, and
    # str(1.0) == "1.0" would otherwise fall through to None.
    if isinstance(lbl, float) and lbl.is_integer():
        lbl = int(lbl)
    s = str(lbl).strip().lower()
    if s in {"1", "true", "yes", "y", "flag", "flagged", "positive", "risky", "unsafe", "danger"}:
        return "FLAGGED"
    if s in {"0", "false", "no", "n", "safe", "ok", "negative", "not flagged", "not_flagged"}:
        return "NOT FLAGGED"
    return None


# ---------------------------- helpers ----------------------------
def read_csv_safely(path: str) -> pd.DataFrame:
    """Read a CSV as utf-8-sig (python engine); on any failure retry as ISO-8859-1.

    The retry uses pandas' default engine. An exception raised by the retry
    itself propagates to the caller.
    """
    try:
        frame = pd.read_csv(path, encoding="utf-8-sig", engine="python")
    except Exception:
        # Best-effort fallback: whatever went wrong, try the permissive encoding.
        frame = pd.read_csv(path, encoding="ISO-8859-1")
    return frame


def detect_id_col(kind: str, df: pd.DataFrame) -> Optional[str]:
    """Return the first id-like column present in ``df``, or None.

    The dataset-specific name is tried first ("response_id" for the
    "therapy" kind, "resume_id" for anything else), followed by the generic
    names "id", "uid" and "text_id".
    """
    preferred = "response_id" if kind == "therapy" else "resume_id"
    candidates = (preferred, "id", "uid", "text_id")
    return next((name for name in candidates if name in df.columns), None)


def ensure_text_id(df: pd.DataFrame, id_col: Optional[str], dataset: str) -> pd.Series:
    """Build a string ``text_id`` Series aligned with ``df.index``.

    When ``id_col`` names a column of ``df``, its values are stringified and
    any missing entry is replaced by a synthesized id of the form
    ``"<dataset>::ID::<index label>"``. When the column is absent, every row
    gets a synthesized id.

    Args:
        df: Source frame.
        id_col: Name of the id column, or None.
        dataset: Dataset tag embedded in synthesized ids.

    Returns:
        Series of ids sharing ``df``'s index.
    """
    if id_col and id_col in df.columns:
        raw = df[id_col]

        def _render(value) -> str:
            # An integer id column that contains blanks is read as float;
            # render 12.0 as "12" so it matches ids written elsewhere as ints.
            if isinstance(value, float) and value.is_integer():
                return str(int(value))
            return str(value)

        tid = raw.map(_render)
        # Detect missing ids on the raw values: after stringification None and
        # pd.NA look like the real ids "None" / "<NA>". The literal strings
        # "" and "nan" are still treated as missing, as before.
        missing = raw.isna() | tid.isin(["", "nan"])
        if missing.any():
            tid.loc[missing] = [f"{dataset}::ID::{i}" for i in df.index[missing]]
        return tid
    return pd.Series([f"{dataset}::ID::{i}" for i in df.index], index=df.index)


def safe_str(x) -> Optional[str]:
    """Return ``str(x)``, or None when ``x`` is None or pandas reports it as missing."""
    is_missing = x is None
    if not is_missing:
        try:
            is_missing = bool(pd.isna(x))
        except Exception:
            # pd.isna gave something with no single truth value; treat as present.
            is_missing = False
    return None if is_missing else str(x)


# ---------------------------- loaders ----------------------------
def load_original(path: str, kind: str) -> pd.DataFrame:
    """Load an originals CSV and return it in the normalized schema.

    The text is taken from "response_text" when ``kind`` is "therapy" and
    from "resume_text" otherwise. The result has the columns text_id, text,
    class, dataset, source ("original") and parent_id (always NA). Rows
    whose cleaned text is empty or whose label cannot be normalized to
    "FLAGGED" / "NOT FLAGGED" are dropped.
    """
    raw = read_csv_safely(path)
    print("---------------------------------------------")
    print(len(raw), f"rows loaded from {path}")

    # Discard index columns that were written out by an earlier to_csv.
    unnamed = [col for col in raw.columns if col.startswith("Unnamed")]
    raw = raw.drop(columns=unnamed, errors="ignore")

    if kind == "therapy":
        text_column, fallback_id = "response_text", "response_id"
    else:
        text_column, fallback_id = "resume_text", "resume_id"
    id_col = detect_id_col(kind, raw) or fallback_id

    frame = pd.DataFrame(
        {
            "text_id": ensure_text_id(raw, id_col, kind),
            "text": raw.get(text_column),
            "class": raw.get("class"),
            "dataset": kind,
            "source": "original",
            "parent_id": pd.NA,   # originals have no parent
        }
    )
    frame["text"] = frame["text"].fillna("").map(clean_text)
    frame["class"] = frame["class"].map(norm_label)

    has_text = frame["text"] != ""
    has_label = frame["class"].isin(["FLAGGED", "NOT FLAGGED"])
    return frame[has_text & has_label].copy()


def load_augmented(path: str, kind: str) -> pd.DataFrame:
    """
    Loads labels-preserving augmented CSV: expects same schema as input + source='augmented', aug_of=<parent id>.
    Produces normalized columns:
        - text_id: current row id (response_id/resume_id or synthesized by augmenter)
        - parent_id: from 'aug_of' if present, otherwise NA
        - source: from 'source' if present else 'augmented'

    The text column is "response_text" (therapy) or "resume_text" (other kinds),
    falling back to "text", then to the first object-dtype column, then to the
    first column. Rows with missing/empty cleaned text or a label that cannot be
    normalized to "FLAGGED" / "NOT FLAGGED" are dropped.
    """
    df = read_csv_safely(path)
    print("---------------------------------------------")
    print(len(df), f"rows loaded from {path}")
    drop_cols = [c for c in df.columns if c.startswith("Unnamed")]
    df = df.drop(columns=drop_cols, errors="ignore")

    # Detect text/id columns
    if kind == "therapy":
        text_col = "response_text" if "response_text" in df.columns else ("text" if "text" in df.columns else None)
    else:
        text_col = "resume_text" if "resume_text" in df.columns else ("text" if "text" in df.columns else None)
    if text_col is None:
        obj_cols = [c for c in df.columns if pd.api.types.is_object_dtype(df[c])]
        text_col = obj_cols[0] if obj_cols else df.columns[0]

    id_col = detect_id_col(kind, df)
    text_id = ensure_text_id(df, id_col, kind)
    parent_id = df["aug_of"] if "aug_of" in df.columns else pd.Series([pd.NA] * len(df), index=df.index)
    source = df["source"] if "source" in df.columns else pd.Series(["augmented"] * len(df), index=df.index)

    out = pd.DataFrame(
        {
            "text_id": text_id,
            "text": df[text_col],
            "class": df.get("class"),
            "dataset": kind,
            "source": source,
            "parent_id": parent_id,
        }
    )
    # Record missing text BEFORE stringifying: astype(str) turns NaN into the
    # literal "nan", which would pass the non-empty filter below.
    missing_text = out["text"].isna()
    out["text"] = out["text"].astype(str).map(clean_text)
    out.loc[missing_text, "text"] = ""
    out["class"] = out["class"].map(norm_label)
    out = out[(out["text"] != "") & (out["class"].isin(["FLAGGED", "NOT FLAGGED"]))].copy()
    return out


# ---------------------------- jsonl writer ----------------------------
def to_jsonl(df: pd.DataFrame, path: str):
    """Write ``df`` to ``path`` as UTF-8 JSON Lines, one record per row.

    Each record carries text_id, text, label, dataset, source and parent_id.
    The label comes from the "label" column when present, else from "class".
    A missing "source" column yields "original"; a missing or NA parent_id
    is written as null.
    """
    label_column = "label" if "label" in df.columns else "class"
    has_source = "source" in df.columns
    has_parent = "parent_id" in df.columns

    with open(path, "w", encoding="utf-8") as sink:
        for _, row in df.iterrows():
            text_value = row["text"]
            parent_value = row["parent_id"] if has_parent else None
            record = {
                "text_id": safe_str(row["text_id"]),
                "text": str(text_value) if not pd.isna(text_value) else "",
                "label": safe_str(row[label_column]),
                "dataset": safe_str(row["dataset"]),
                "source": safe_str(row["source"]) if has_source else "original",
                "parent_id": safe_str(parent_value) if has_parent and not pd.isna(parent_value) else None,
            }
            sink.write(json.dumps(record, ensure_ascii=False) + "\n")


# ---------------------------- main logic ----------------------------
_SPLIT_COLUMNS = ["text_id", "text", "label", "dataset", "source", "parent_id"]


def _canonical_id(value) -> Optional[str]:
    """Return a comparable string form of an id, or None when it is missing."""
    if value is None:
        return None
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        pass
    # A numeric id column containing blanks is read as float: 12.0 must match "12".
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    text = str(value).strip()
    if re.fullmatch(r"-?\d+\.0+", text):
        text = text.split(".", 1)[0]
    return text or None


def _split_keys(df: pd.DataFrame, id_column: str) -> List[Tuple[str, Optional[str]]]:
    """Return one (dataset, canonical id) key per row of ``df``, in row order."""
    return [(str(ds), _canonical_id(raw)) for ds, raw in zip(df["dataset"], df[id_column])]


def _parent_text_map(originals: pd.DataFrame) -> dict:
    """Map (dataset, canonical text_id) -> cleaned text for a frame of originals."""
    keys = _split_keys(originals, "text_id")
    return {key: text for key, text in zip(keys, originals["text"]) if key[1] is not None}


def _split_orphans(rows: pd.DataFrame, test_size: float, seed: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Randomly split parentless augmented rows; fall back instead of crashing on tiny sets."""
    attempts = []
    if rows["label"].nunique() >= 2:
        attempts.append({"stratify": rows["label"]})
    attempts.append({})
    for extra in attempts:
        try:
            train_part, test_part = train_test_split(rows, test_size=test_size, random_state=seed, **extra)
            return train_part, test_part
        except ValueError:
            # Too few rows for this strategy (e.g. a singleton class or an empty side).
            continue
    # Nothing could be split: keep every row for training.
    return rows, rows.iloc[0:0]


def _drop_aug_equal_to_parent(df_split: pd.DataFrame, parent_text_map: dict, split_name: str) -> pd.DataFrame:
    """Remove augmented rows whose text is identical to their own parent's text."""
    is_aug = df_split["source"].astype(str).eq("augmented")
    df_aug = df_split[is_aug]
    df_non_aug = df_split[~is_aug]

    keys = _split_keys(df_aug, "parent_id")
    texts = df_aug["text"].astype(str)
    # A row without a known parent maps to None, which never equals a string.
    eq_mask = pd.Series(
        [parent_text_map.get(key) == text for key, text in zip(keys, texts)],
        index=df_aug.index,
        dtype=bool,
    )
    removed = int(eq_mask.sum())
    if removed > 0:
        print(f"[Filter equal-to-parent] Removed {removed} {split_name} augmented rows identical to their parent text.")
    return pd.concat([df_non_aug, df_aug[~eq_mask]], ignore_index=True)


def main():
    """Build leakage-free train/test JSONL files from originals plus augmented rows.

    Originals are split first (stratified by label). Each augmented row then
    follows its parent original into the same split; parents are matched on
    (dataset, id) so equal ids in the therapy and resumes files cannot collide.
    Augmented rows with no matching parent are split randomly, with a warning.
    Augmented rows whose parent id is present in both splits are dropped.
    Writes ``train_augmented.jsonl`` and ``test_augmented.jsonl`` to the current
    directory and prints a summary.

    Raises:
        SystemExit: if no usable originals remain or only one label is present.
    """
    ap = argparse.ArgumentParser(description="Prepare stratified train/test from originals; attach aug to parent split.")
    ap.add_argument("--therapy", required=True, help="Path to Sheet_1*.csv (therapy originals)")
    ap.add_argument("--resumes", required=True, help="Path to Sheet_2*.csv (resumes originals)")

    # Optional augmented inputs (labels-preserving from augmenter)
    ap.add_argument("--aug-therapy", default=None, help="Path to labels-preserving augmented therapy CSV")
    ap.add_argument("--aug-resumes", default=None, help="Path to labels-preserving augmented resumes CSV")

    ap.add_argument("--test-size", type=float, default=0.10)
    ap.add_argument("--seed", type=int, default=42)

    # New: only drop augmented rows that exactly equal their parent's original text
    ap.add_argument("--drop-aug-equal-to-parent", action="store_true",
                    help="If set, remove only aug rows whose cleaned text equals their parent's cleaned original text.")
    # Keep the old switch but discourage using it together
    ap.add_argument("--dedup-text", action="store_true",
                    help="[Deprecated here] Previous global dedup. Prefer --drop-aug-equal-to-parent.")
    args = ap.parse_args()

    # --- 1) Load originals and normalize ---
    th_orig = load_original(args.therapy, "therapy")
    rs_orig = load_original(args.resumes, "resumes")
    orig_df = pd.concat([th_orig, rs_orig], ignore_index=True)
    orig_df = orig_df.rename(columns={"class": "label"})

    if orig_df.empty or orig_df["label"].nunique() < 2:
        raise SystemExit(
            "No usable ORIGINAL data or only one label present after cleaning. "
            "Check CSVs and label normalization."
        )

    # --- 2) Split ONLY the originals (stratified) ---
    train_orig, test_orig = train_test_split(
        orig_df, test_size=args.test_size, stratify=orig_df["label"], random_state=args.seed
    )

    # (dataset, id) -> parent text; the keys double as the parent sets used
    # for anti-leak assignment of augmented rows.
    train_parent_text = _parent_text_map(train_orig)
    test_parent_text = _parent_text_map(test_orig)

    # --- 3) Load augmented (optional) ---
    aug_parts: List[pd.DataFrame] = []
    if args.aug_therapy:
        aug_parts.append(load_augmented(args.aug_therapy, "therapy"))
    if args.aug_resumes:
        aug_parts.append(load_augmented(args.aug_resumes, "resumes"))

    if aug_parts:
        aug_df = pd.concat(aug_parts, ignore_index=True)
        aug_df = aug_df.rename(columns={"class": "label"})
    else:
        aug_df = pd.DataFrame(columns=_SPLIT_COLUMNS)

    # --- 4) Assign augmented rows to parent split (anti-leakage) ---
    if not aug_df.empty:
        # Canonicalize parent ids but keep missing ones missing, so they are
        # written as null instead of the literal string "nan".
        aug_df["parent_id"] = aug_df["parent_id"].map(_canonical_id)
        aug_keys = _split_keys(aug_df, "parent_id")

        mask_train_parent = pd.Series(
            [key in train_parent_text for key in aug_keys], index=aug_df.index, dtype=bool
        )
        mask_test_parent = pd.Series(
            [key in test_parent_text for key in aug_keys], index=aug_df.index, dtype=bool
        )

        # Only possible when the originals contain duplicate ids within one
        # dataset; such rows cannot be placed without leaking.
        ambiguous = mask_train_parent & mask_test_parent
        if ambiguous.any():
            print(
                f"[WARN] Dropping {int(ambiguous.sum())} augmented rows whose parent id occurs in both "
                "train and test originals (duplicate original ids)."
            )

        aug_train = aug_df[mask_train_parent & ~ambiguous].copy()
        aug_test = aug_df[mask_test_parent & ~ambiguous].copy()

        aug_rest = aug_df[~(mask_train_parent | mask_test_parent)].copy()
        if not aug_rest.empty:
            print(
                f"[WARN] {len(aug_rest)} augmented rows have no matching original parent; "
                "splitting them randomly (leakage cannot be ruled out for these rows)."
            )
            aug_train_rest, aug_test_rest = _split_orphans(aug_rest, args.test_size, args.seed)
            aug_train = pd.concat([aug_train, aug_train_rest], ignore_index=True)
            aug_test = pd.concat([aug_test, aug_test_rest], ignore_index=True)
    else:
        aug_train = pd.DataFrame(columns=_SPLIT_COLUMNS)
        aug_test = pd.DataFrame(columns=_SPLIT_COLUMNS)

    # --- 5) Final splits = originals + augmented (same-parent) ---
    train_df = pd.concat([train_orig, aug_train], ignore_index=True)
    test_df = pd.concat([test_orig, aug_test], ignore_index=True)

    # --- 6) Optional *minimal* filtering: only drop aug == parent original (per parent_id) ---
    if args.drop_aug_equal_to_parent:
        train_df = _drop_aug_equal_to_parent(train_df, train_parent_text, "train")
        test_df = _drop_aug_equal_to_parent(test_df, test_parent_text, "test")

    if args.dedup_text:
        print("[WARN] --dedup-text is deprecated in this script. Prefer --drop-aug-equal-to-parent.")

    # Sanity checks
    if train_df["label"].nunique() < 2 or test_df["label"].nunique() < 2:
        print("[WARN] One of the splits ended up with a single class after augmentation/filters.")

    # --- 7) Save JSONL ---
    to_jsonl(train_df, "train_augmented.jsonl")
    to_jsonl(test_df, "test_augmented.jsonl")

    # --- 8) Helpful summary ---
    print(f"Saved train_augmented.jsonl ({len(train_df)}) and test_augmented.jsonl ({len(test_df)})")
    print("\n[Train] Label distribution:")
    print(train_df["label"].value_counts(normalize=True).round(3))
    print("\n[Test] Label distribution:")
    print(test_df["label"].value_counts(normalize=True).round(3))

    print("\n[Train] By dataset+source:")
    print(train_df.groupby(["dataset", "source", "label"]).size())

    print("\n[Test] By dataset+source:")
    print(test_df.groupby(["dataset", "source", "label"]).size())


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

