#!/usr/bin/env python3
# kl_validation_v2.py
# Upgraded validation:
# - Higher-rank SVD (default 50)
# - Per-component (GMM) word/char JS to couple semantics with form
# - Optional union-SVD mode (off by default)
# - Removes redundant embed call from v1

import argparse
import json
import os
from typing import Tuple, Dict, Any

import numpy as np
import pandas as pd
from scipy import sparse

from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.decomposition import TruncatedSVD
from sklearn.mixture import GaussianMixture

# ----------------------------- IO -----------------------------
def load_csv_safe(path: str) -> pd.DataFrame:
    """Read a CSV file, trying several text encodings before a lenient fallback.

    Each encoding (utf-8, utf-8-sig, latin-1) is tried in turn with pandas'
    default parser. Any failure moves on to the next encoding. If every attempt
    fails, the file is read once more with the python parser engine and no
    explicit encoding. Errors from that final attempt propagate to the caller.
    """
    encodings = ["utf-8", "utf-8-sig", "latin-1"]
    for encoding in encodings:
        try:
            frame = pd.read_csv(path, encoding=encoding)
        except Exception:
            # Deliberately broad: any failure just means "try the next encoding".
            continue
        return frame
    return pd.read_csv(path, engine="python")

def detect_text_col(df: pd.DataFrame, preferred: str) -> str:
    """Return ``preferred`` if it names a column of ``df``, else the first column."""
    if preferred not in df.columns:
        # No column with the requested name: fall back to the leftmost one.
        return df.columns[0]
    return preferred

def ensure_outdir(path: str):
    """Create directory ``path``, including parents. Does nothing if it already exists."""
    os.makedirs(name=path, exist_ok=True)

# ----------------------------- Basic Features -----------------------------
def word_counts(series: pd.Series) -> np.ndarray:
    """Return the number of whitespace-separated tokens in each entry of ``series``.

    Missing entries (NaN/None/NA) count as 0 words. They are blanked *after*
    the string cast, using the original missing-value mask. Casting first and
    then calling ``fillna`` would be a no-op, because ``astype(str)`` has
    already turned NaN into the literal "nan", which would count as one word.
    """
    s = series.astype(str).where(series.notna(), "")
    return s.str.split().str.len().to_numpy()

def char_counts(series: pd.Series) -> np.ndarray:
    """Return the character length of each entry in ``series``.

    Missing entries (NaN/None/NA) count as 0 characters. They are blanked
    using the original missing-value mask; a plain ``astype(str).fillna("")``
    would instead count "nan" (3) or "None" (4).
    """
    s = series.astype(str).where(series.notna(), "")
    return s.str.len().to_numpy()

# ----------------------------- Divergence Helpers -----------------------------
def _normalize(p: np.ndarray, eps: float) -> np.ndarray:
    p = np.asarray(p, dtype=float)
    p = np.clip(p, eps, None)
    p = p / p.sum()
    return p

def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Kullback-Leibler divergence KL(p || q), in nats (natural log).

    Both inputs go through ``_normalize`` first: each is floored at ``eps``
    and rescaled to sum to 1. Raw histogram counts are therefore accepted,
    and empty bins never produce inf.
    """
    p_n = _normalize(p, eps)
    q_n = _normalize(q, eps)
    log_ratio = np.log(p_n) - np.log(q_n)
    return float((p_n * log_ratio).sum())

def js_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Jensen-Shannon divergence between ``p`` and ``q``, in nats.

    It is the mean of KL(p || m) and KL(q || m), where m is the pointwise
    average of the two normalized inputs. It is symmetric in p and q and
    bounded above by ln 2.
    """
    p_n, q_n = _normalize(p, eps), _normalize(q, eps)
    midpoint = 0.5 * (p_n + q_n)
    term_p = 0.5 * kl_divergence(p_n, midpoint, eps)
    term_q = 0.5 * kl_divergence(q_n, midpoint, eps)
    return term_p + term_q

# ----------------------------- Histograms -----------------------------
def histogram_probs(a: np.ndarray, b: np.ndarray, bins: int = 30,
                    upper_pct: float = 99.5) -> Tuple[np.ndarray, np.ndarray]:
    """Bin two samples on a shared grid and return their per-bin counts.

    The grid has ``bins`` equal-width bins spanning
    ``[min(a ∪ b), percentile(a ∪ b, upper_pct)]``. Capping the upper edge
    below the maximum keeps a few extreme lengths from stretching every bin.
    Values above the edge are *clipped* onto it, so they land in the last bin.
    ``np.histogram`` would otherwise silently drop them, and a sample whose
    mass lies mostly in the tail would then look artificially empty (and,
    after eps-smoothing, uniform).

    Args:
        a, b: 1-D numeric samples (any shape is flattened).
        bins: number of histogram bins.
        upper_pct: percentile of the pooled sample used as the upper edge.

    Returns:
        ``(counts_a, counts_b)`` as float arrays of length ``bins``. These are
        raw counts, not normalized; the divergence helpers normalize them.
        If both inputs are empty, two all-zero arrays are returned.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    pooled = np.concatenate([a, b])
    if pooled.size == 0:
        # np.percentile cannot handle an empty array; there is nothing to bin.
        return np.zeros(bins), np.zeros(bins)
    lo = float(pooled.min())
    hi = float(np.percentile(pooled, upper_pct))
    if hi <= lo:
        # Degenerate (constant) data: give the grid a non-zero width.
        hi = lo + 1.0
    edges = np.linspace(lo, hi, bins + 1)
    ha, _ = np.histogram(np.clip(a, lo, hi), bins=edges)
    hb, _ = np.histogram(np.clip(b, lo, hi), bins=edges)
    return ha.astype(float), hb.astype(float)

def divergence_report_hist(a: np.ndarray, b: np.ndarray, bins: int, eps: float) -> Dict[str, float]:
    """Histogram ``a`` (P) and ``b`` (Q) on a shared grid and compare them.

    Returns a dict with KL in both directions ("KL(P||Q)", "KL(Q||P)") and
    the symmetric "JS" divergence of the two histograms.
    """
    counts_a, counts_b = histogram_probs(a, b, bins=bins)
    report: Dict[str, float] = {}
    report["KL(P||Q)"] = kl_divergence(counts_a, counts_b, eps)
    report["KL(Q||P)"] = kl_divergence(counts_b, counts_a, eps)
    report["JS"] = js_divergence(counts_a, counts_b, eps)
    return report

# ----------------------------- Embedding & GMM -----------------------------
def build_space(series_orig: pd.Series, series_aug: pd.Series, svd_dim: int, seed: int, union: bool=False):
    """Embed original and augmented texts into a shared low-rank space.

    Texts are hashed into word unigram+bigram features (4096 non-negative
    features, L2-normalized rows) and projected with TruncatedSVD.

    Args:
        series_orig: original texts.
        series_aug: augmented texts.
        svd_dim: number of SVD components to keep.
        seed: ``random_state`` for TruncatedSVD.
        union: if True, fit the SVD on original and augmented rows together.
            Otherwise fit on the originals only and project the augmented texts
            into that basis.

    Returns:
        ``(Z_orig, Z_aug)``: the projected embeddings, one row per input text.
    """
    def _clean(texts: pd.Series) -> pd.Series:
        # Blank missing values using the pre-cast mask. astype(str).fillna("")
        # would embed NaN as the token "nan".
        return texts.astype(str).where(texts.notna(), "")

    vec = HashingVectorizer(ngram_range=(1, 2), analyzer="word", alternate_sign=False,
                            n_features=4096, norm="l2")
    # Hash each side exactly once and reuse the matrices for fitting and projecting.
    X_orig = vec.transform(_clean(series_orig))
    X_aug = vec.transform(_clean(series_aug))
    if union:
        # The hashed matrices are scipy sparse. np.vstack does not understand
        # them (it builds an object array that SVD cannot fit), so stack sparsely.
        X_fit = sparse.vstack([X_orig, X_aug], format="csr")
    else:
        X_fit = X_orig
    svd = TruncatedSVD(n_components=svd_dim, random_state=seed).fit(X_fit)
    Z_orig = svd.transform(X_orig)
    Z_aug = svd.transform(X_aug)
    return Z_orig, Z_aug

def gmm_on_original(Z_orig: np.ndarray, k: int, seed: int) -> GaussianMixture:
    """Fit a full-covariance Gaussian mixture to the original-text embeddings.

    The requested component count ``k`` is clamped to the range [2, 12]. It is
    also capped at the number of rows in ``Z_orig``, because GaussianMixture
    refuses to fit more components than samples. With fewer than 2 rows the
    fit still raises, since at least 2 components are always requested.

    Args:
        Z_orig: embeddings of the original texts, one row per text.
        k: requested number of mixture components.
        seed: ``random_state`` for the mixture fit.

    Returns:
        The fitted GaussianMixture. Its ``n_components`` attribute holds the
        effective ``k``.
    """
    n_samples = len(Z_orig)
    k = max(2, min(12, k, n_samples))
    gmm = GaussianMixture(n_components=k, covariance_type="full", random_state=seed)
    gmm.fit(Z_orig)
    return gmm

def mixture_weights(gmm: GaussianMixture, Z: np.ndarray) -> np.ndarray:
    """Return each mixture component's posterior responsibility, averaged over the rows of ``Z``."""
    responsibilities = gmm.predict_proba(Z)
    return np.mean(responsibilities, axis=0)

# ----------------------------- Per-component length JS -----------------------------
def per_component_length_js(gmm: GaussianMixture, Z_orig: np.ndarray, Z_aug: np.ndarray,
                            words_orig: np.ndarray, words_aug: np.ndarray,
                            chars_orig: np.ndarray, chars_aug: np.ndarray,
                            bins: int, eps: float) -> Dict[str, Any]:
    """Compare word- and char-length distributions within each GMM component.

    Every row of ``Z_orig``/``Z_aug`` is hard-assigned to its most probable
    component (the argmax of ``predict_proba``). For each component that holds
    at least one original row and one augmented row, the histogram JS
    divergence is computed between the two groups, once for word counts and
    once for char counts. Components missing either side get ``None`` for both.

    Returns:
        ``{"per_component": [{"comp", "word_JS", "char_JS", "orig_n", "aug_n"}, ...],
        "summary": {"word_JS_max", "word_JS_mean", "char_JS_max", "char_JS_mean"}}``.
        The summary statistics skip ``None`` entries and are ``None`` themselves
        when no component qualified.
    """
    labels_orig = np.argmax(gmm.predict_proba(Z_orig), axis=1)
    labels_aug = np.argmax(gmm.predict_proba(Z_aug), axis=1)

    rows = []
    for comp in range(gmm.n_components):
        mask_o = labels_orig == comp
        mask_a = labels_aug == comp
        n_o = int(mask_o.sum())
        n_a = int(mask_a.sum())
        if n_o and n_a:
            word_js = float(divergence_report_hist(words_orig[mask_o], words_aug[mask_a],
                                                   bins=bins, eps=eps)["JS"])
            char_js = float(divergence_report_hist(chars_orig[mask_o], chars_aug[mask_a],
                                                   bins=bins, eps=eps)["JS"])
        else:
            # One side has no rows in this component: there is nothing to compare.
            word_js = char_js = None
        rows.append({"comp": int(comp), "word_JS": word_js, "char_JS": char_js,
                     "orig_n": n_o, "aug_n": n_a})

    def _max_and_mean(key):
        present = [row[key] for row in rows if row[key] is not None]
        if not present:
            return None, None
        return float(max(present)), float(np.mean(present))

    word_max, word_mean = _max_and_mean("word_JS")
    char_max, char_mean = _max_and_mean("char_JS")
    summary = {
        "word_JS_max": word_max,
        "word_JS_mean": word_mean,
        "char_JS_max": char_max,
        "char_JS_mean": char_mean,
    }
    return {"per_component": rows, "summary": summary}

# ----------------------------- Pair validation -----------------------------
def validate_pair(df_orig: pd.DataFrame, df_aug: pd.DataFrame, text_col: str,
                  bins: int, k: int, eps: float, seed: int, svd_dim: int, union_svd: bool) -> Dict[str, Any]:
    """Run every length and semantic drift check for one original/augmented pair.

    Returns a dict with:
      - "hist_word" / "hist_char": KL in both directions plus JS between the
        global word-count and char-count histograms;
      - "semantic_weights": KL/JS between the average GMM responsibilities of
        the original and augmented texts, plus the effective k, svd_dim and
        union_svd settings;
      - "per_component_length_JS": word/char JS inside each GMM component;
      - "attention_flags": 1.0/0.0 flags marking metrics above 0.05;
      - "counts": row counts of both frames.
    """
    text_o = df_orig[text_col]
    text_a = df_aug[text_col]

    # Global length distributions.
    words_o, words_a = word_counts(text_o), word_counts(text_a)
    chars_o, chars_a = char_counts(text_o), char_counts(text_a)
    hist_word = divergence_report_hist(words_o, words_a, bins=bins, eps=eps)
    hist_char = divergence_report_hist(chars_o, chars_a, bins=bins, eps=eps)

    # Semantic space, with the mixture fitted on the originals only.
    Z_orig, Z_aug = build_space(text_o, text_a, svd_dim=svd_dim, seed=seed, union=union_svd)
    gmm = gmm_on_original(Z_orig, k=k, seed=seed)
    w_orig = mixture_weights(gmm, Z_orig)
    w_aug = mixture_weights(gmm, Z_aug)
    sem = {
        "KL(weights P||Q)": kl_divergence(w_orig, w_aug, eps),
        "KL(weights Q||P)": kl_divergence(w_aug, w_orig, eps),
        "JS(weights)": js_divergence(w_orig, w_aug, eps),
        "k_components": float(gmm.n_components),
        "svd_dim": float(svd_dim),
        "union_svd": bool(union_svd),
    }

    percomp = per_component_length_js(gmm, Z_orig, Z_aug, words_o, words_a, chars_o, chars_a,
                                      bins=bins, eps=eps)

    # Flags are floats (1.0 / 0.0). A missing (None) per-component max counts as 0.
    threshold = 0.05
    pc_summary = percomp["summary"]
    attn = {
        "word_JS_gt_0.05": float(hist_word["JS"] > threshold),
        "char_JS_gt_0.05": float(hist_char["JS"] > threshold),
        "weights_JS_gt_0.05": float(sem["JS(weights)"] > threshold),
        "percomp_word_JS_max_gt_0.05": float((pc_summary["word_JS_max"] or 0) > threshold),
        "percomp_char_JS_max_gt_0.05": float((pc_summary["char_JS_max"] or 0) > threshold),
    }

    counts = {"orig_n": int(df_orig.shape[0]), "aug_n": int(df_aug.shape[0])}
    return {
        "hist_word": hist_word,
        "hist_char": hist_char,
        "semantic_weights": sem,
        "per_component_length_JS": percomp,
        "attention_flags": attn,
        "counts": counts,
    }

# ----------------------------- CLI -----------------------------
def main():
    """Command-line entry point for the v2 KL/JS validation.

    Loads the two original sheets and their augmented counterparts. It then
    validates sheet1 against aug1 and sheet2 against aug2 with
    ``validate_pair``. Finally it writes ``kl_validation_report.json`` (full
    results plus config) and ``kl_validation_summary.md`` (a human-readable
    digest) into ``--outdir`` and prints both paths.
    """
    p = argparse.ArgumentParser(description="v2 KL/JS validation with per-component length checks.")
    p.add_argument("--sheet1", required=True)
    p.add_argument("--sheet2", required=True)
    p.add_argument("--aug1",   required=True)
    p.add_argument("--aug2",   required=True)
    p.add_argument("--text1",  default="response_text")
    p.add_argument("--text2",  default="resume_text")
    p.add_argument("--bins",   type=int, default=30)
    p.add_argument("--k",      type=int, default=5)
    p.add_argument("--eps",    type=float, default=1e-8)
    p.add_argument("--seed",   type=int, default=7)
    p.add_argument("--svd_dim", type=int, default=50)
    p.add_argument("--union_svd", action="store_true")
    p.add_argument("--outdir", required=True)
    args = p.parse_args()

    ensure_outdir(args.outdir)

    s1 = load_csv_safe(args.sheet1)
    s2 = load_csv_safe(args.sheet2)
    a1 = load_csv_safe(args.aug1)
    a2 = load_csv_safe(args.aug2)

    # The text column is resolved on the original sheet and then used for the augmented sheet too.
    col1 = detect_text_col(s1, args.text1)
    col2 = detect_text_col(s2, args.text2)

    rep1 = validate_pair(s1, a1, col1, bins=args.bins, k=args.k, eps=args.eps, seed=args.seed, svd_dim=args.svd_dim, union_svd=args.union_svd)
    rep2 = validate_pair(s2, a2, col2, bins=args.bins, k=args.k, eps=args.eps, seed=args.seed, svd_dim=args.svd_dim, union_svd=args.union_svd)

    report = {
        "Sheet_1": rep1,
        "Sheet_2": rep2,
        "config": {
            "bins": args.bins, "k": args.k, "eps": args.eps, "seed": args.seed,
            "sheet1": args.sheet1, "sheet2": args.sheet2, "aug1": args.aug1, "aug2": args.aug2,
            "text1": col1, "text2": col2, "svd_dim": args.svd_dim, "union_svd": args.union_svd
        }
    }

    json_path = os.path.join(args.outdir, "kl_validation_report.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)

    md_path = os.path.join(args.outdir, "kl_validation_summary.md")

    def fmt_opt(value: Any) -> str:
        # Per-component summary stats are None when no GMM component had both
        # original and augmented rows. Render those as NA and everything else
        # with the same 4-decimal precision as the other metrics below.
        return "NA" if value is None else f"{value:.4f}"

    def fmt_blk(title: str, d: Dict[str, Any]) -> str:
        pc = d["per_component_length_JS"]["summary"]
        return (
            f"### {title}\n"
            f"- Word lengths: KL(P||Q)={d['hist_word']['KL(P||Q)']:.4f}, "
            f"KL(Q||P)={d['hist_word']['KL(Q||P)']:.4f}, JS={d['hist_word']['JS']:.4f}\n"
            f"- Char lengths: KL(P||Q)={d['hist_char']['KL(P||Q)']:.4f}, "
            f"KL(Q||P)={d['hist_char']['KL(Q||P)']:.4f}, JS={d['hist_char']['JS']:.4f}\n"
            f"- Semantic GMM weights: KL(P||Q)={d['semantic_weights']['KL(weights P||Q)']:.4f}, "
            f"KL(Q||P)={d['semantic_weights']['KL(weights Q||P)']:.4f}, "
            f"JS={d['semantic_weights']['JS(weights)']:.4f}, "
            f"k={int(d['semantic_weights']['k_components'])}, "
            f"svd_dim={int(d['semantic_weights']['svd_dim'])}, union_svd={d['semantic_weights']['union_svd']}\n"
            f"- Per-component length JS (summary): "
            f"word_JS_max={fmt_opt(pc['word_JS_max'])}, "
            f"word_JS_mean={fmt_opt(pc['word_JS_mean'])}, "
            f"char_JS_max={fmt_opt(pc['char_JS_max'])}, "
            f"char_JS_mean={fmt_opt(pc['char_JS_mean'])}\n"
            f"- Attention flags: word_JS>0.05={int(d['attention_flags']['word_JS_gt_0.05'])}, "
            f"char_JS>0.05={int(d['attention_flags']['char_JS_gt_0.05'])}, "
            f"weights_JS>0.05={int(d['attention_flags']['weights_JS_gt_0.05'])}, "
            f"percomp_word_JS_max>0.05={int(d['attention_flags']['percomp_word_JS_max_gt_0.05'])}, "
            f"percomp_char_JS_max>0.05={int(d['attention_flags']['percomp_char_JS_max_gt_0.05'])}\n"
            f"- Counts: orig={d['counts']['orig_n']}, aug={d['counts']['aug_n']}\n"
        )

    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# KL/JS Validation Summary (v2)\n\n")
        f.write(fmt_blk("Sheet_1 (therapy)", rep1) + "\n")
        f.write(fmt_blk("Sheet_2 (resumes)", rep2) + "\n")
        f.write("\n**Notes**\n")
        f.write("- v2 enforces tighter semantic space and checks per-component form drift.\n")
        f.write("- Aim for JS(weights) ≤ 0.05 and per-component word/char JS ≤ 0.05.\n")

    print("Saved report:")
    print(" -", json_path)
    print(" -", md_path)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
