#!/usr/bin/env bash
# Pipeline driver. Every setting below may be overridden from the environment:
#   DATA_TRAIN  training JSONL file   (default: train_augmented.jsonl)
#   DATA_TEST   test JSONL file       (default: test_augmented.jsonl)
#   OUT_DIR     output directory      (default: outputs)
set -o errexit -o nounset -o pipefail

# ---------- Config ----------
# ":=" assigns the default only when the variable is unset or empty.
: "${DATA_TRAIN:=train_augmented.jsonl}"
: "${DATA_TEST:=test_augmented.jsonl}"
: "${OUT_DIR:=outputs}"
mkdir -p "$OUT_DIR" "$OUT_DIR/figures"

#######################################
# Print how many lines a file contains, as "<n> rows loaded from <path>".
# Arguments: $1 - path to a UTF-8 text file (one record per line)
# Outputs:   one summary line on stdout
# Returns:   non-zero if Python cannot open or decode the file
#######################################
count_rows() {
  python - "$1" <<'PY'
import sys

path = sys.argv[1]
# Context manager so the handle is closed as soon as counting finishes.
with open(path, "r", encoding="utf-8") as fh:
    n_rows = sum(1 for _ in fh)
print(f"{n_rows} rows loaded from {path}")
PY
}

echo "---------------------------------------------"
count_rows "$DATA_TRAIN"
echo "---------------------------------------------"
count_rows "$DATA_TEST"
echo

# ---------- [1/6] Train C2 (distilroberta-base) ----------
echo "### [1/6] Training C2 (distilroberta-base) ..."
# Training is best-effort: a non-zero exit must not abort the pipeline, but it
# must not pass unnoticed either, so the status is captured and reported.
train_status=0
python 2NEW_train_supervised_roberta.py \
  --train "$DATA_TRAIN" \
  --test "$DATA_TEST" \
  --model distilroberta-base \
  --epochs 6 \
  --batch-size 16 \
  --lr 3e-5 \
  --fp16 \
  --out-dir "$OUT_DIR/roberta_cls" || train_status=$?
if (( train_status != 0 )); then
  printf 'WARNING: training exited with status %d; continuing with whatever is in %s\n' \
    "$train_status" "$OUT_DIR/roberta_cls" >&2
fi
echo

# ---------- [2/6] Inference with C2 ----------
# Both inference runs share the same script, model directory and test set;
# only the --out destination differs.
# NOTE(review): because both runs load "$OUT_DIR/roberta_cls", preds_c.jsonl and
# preds_c2.jsonl come from the same checkpoint — confirm this is intended for
# the later C1 + C2 blend.
infer_cmd=(
  python 2NEW_pipeline_c_infer.py
  --model-dir "$OUT_DIR/roberta_cls"
  --test "$DATA_TEST"
)

echo "### [2/6] Inference with C2 (best checkpoint → preds_c2.jsonl) ..."
"${infer_cmd[@]}" --out "$OUT_DIR/preds_c2.jsonl"
echo

# Regenerate the C1 predictions too, replacing any preds_c.jsonl left behind
# by the training step.
echo "### [2/6b] Inference with C1 (best checkpoint → preds_c.jsonl) ..."
"${infer_cmd[@]}" --out "$OUT_DIR/preds_c.jsonl"
echo



# ---------- [3/6] Enrich C1/C2 predictions for blending ----------

#######################################
# Pair each C1 prediction with the gold row at the same position and write an
# enriched JSONL row with dataset/text_id/true/p_flagged/score/pred.
# Blank lines in either input are ignored.
# Arguments: $1 - gold test JSONL
#            $2 - C1 predictions JSONL (rows need "p_flagged" and "pred")
#            $3 - output JSONL path
# Outputs:   "[DONE] wrote <path>" on stdout
# Returns:   non-zero on a row-count mismatch or unreadable/invalid input
#######################################
enrich_c1_preds() {
  python - "$1" "$2" "$3" <<'PY'
import json
import sys


def read_jsonl(path):
    """Parse one JSON object per non-blank line; the file is closed on return."""
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


test_path, c1_path, out_path = sys.argv[1:4]
gold = read_jsonl(test_path)
c1 = read_jsonl(c1_path)
# Rows are matched purely by position, so the counts must agree.
if len(gold) != len(c1):
    raise SystemExit(f"Length mismatch: gold={len(gold)} vs c1={len(c1)}")
with open(out_path, "w", encoding="utf-8") as out:
    for i, (g, r) in enumerate(zip(gold, c1)):
        p = float(r["p_flagged"])
        row = {
            "dataset": "test",
            # Fall back to the row index when the gold row has no text_id.
            "text_id": g.get("text_id", i),
            # The gold label may be stored under "true" or under "label".
            "true": g.get("true", g.get("label", None)),
            "p_flagged": p,
            "score": p,
            "pred": r["pred"],
        }
        out.write(json.dumps(row, ensure_ascii=False) + "\n")
print(f"[DONE] wrote {out_path}")
PY
}

#######################################
# Copy a predictions JSONL, adding "score" = float(p_flagged) to every row.
# Blank lines in the input are ignored.
# Arguments: $1 - input predictions JSONL (rows need "p_flagged")
#            $2 - output JSONL path
# Outputs:   "[DONE] wrote <path>" on stdout
# Returns:   non-zero on unreadable/invalid input
#######################################
add_blend_score() {
  python - "$1" "$2" <<'PY'
import json
import sys

inp, outp = sys.argv[1:3]
with open(inp, "r", encoding="utf-8") as fh:
    rows = [json.loads(line) for line in fh if line.strip()]
for r in rows:
    r["score"] = float(r["p_flagged"])
with open(outp, "w", encoding="utf-8") as out:
    for r in rows:
        out.write(json.dumps(r, ensure_ascii=False) + "\n")
print(f"[DONE] wrote {outp}")
PY
}

echo "### [3/6] Enriching C1/C2 predictions for blending ..."
# Enrich C1 (preds_c.jsonl) with dataset/id/true/score
enrich_c1_preds "$DATA_TEST" "$OUT_DIR/preds_c.jsonl" "$OUT_DIR/preds_c_enriched.jsonl"

# Add "score" to C2 for blender
add_blend_score "$OUT_DIR/preds_c2.jsonl" "$OUT_DIR/preds_c2_forblend.jsonl"
echo

# ---------- [4/6] Blend C1 + C2 ----------
echo "### [4/6] Blending C1 + C2 ..."
# Arguments are collected in an array so each option sits on its own line
# without backslash continuations.
blend_args=(
  --c1 "$OUT_DIR/preds_c_enriched.jsonl"
  --c2 "$OUT_DIR/preds_c2_forblend.jsonl"
  --w1 0.5
  --w2 0.5
  --thresh 0.50
  --out "$OUT_DIR/preds_c_blend.jsonl"
)
python 2NEW_blend_classifiers.py "${blend_args[@]}"
echo

# ---------- [5/6] Rank+Veto sweep ----------
echo "### [5/6] Rank+Veto sweep over (alpha, thresh, margin) ..."
alphas=(0.70 0.80 0.85 0.90 0.92)
thresholds=(0.58 0.60 0.62 0.64 0.66)
margins=(0.14 0.16 0.18 0.20)

grid_csv="$OUT_DIR/ensemble_twoC_rankveto_grid.csv"
sweep_preds="$OUT_DIR/preds_ensemble.jsonl"

# Start the grid file with its header; each combination appends one row.
printf '%s\n' "alpha,thresh,margin,acc,rec_flagged" > "$grid_csv"

# For every (alpha, thresh, margin): run the ensemble (its stdout discarded),
# then score the predictions it wrote. The scorer's stdout is the CSV row,
# captured by the append redirection on the outermost loop.
for alpha in "${alphas[@]}"; do
  for thresh in "${thresholds[@]}"; do
    for margin in "${margins[@]}"; do
      python 2NEW_ensemble_rank_veto.py \
        --bplus ../contrastive_augmentation_pipeline/outputs/pipeline_b_plus/preds.jsonl \
        --c "$OUT_DIR/preds_c_blend.jsonl" \
        --alpha "$alpha" --thresh "$thresh" --margin "$margin" \
        --out "$sweep_preds" >/dev/null
      python - "$alpha" "$thresh" "$margin" "$sweep_preds" <<'PY'
import json
import sys

alpha, thresh, margin = (float(v) for v in sys.argv[1:4])
preds_path = sys.argv[4]
records = [json.loads(line) for line in open(preds_path, 'r', encoding='utf-8')]

# Binarise: 1 means FLAGGED, 0 means anything else.
gold = [int(rec.get("true") == "FLAGGED") for rec in records]
pred = [int(rec["pred"] == "FLAGGED") for rec in records]

n_correct = sum(int(g == p) for g, p in zip(gold, pred))
acc = n_correct / len(gold)

true_pos = sum(1 for g, p in zip(gold, pred) if g == 1 and p == 1)
n_pos = sum(gold)
# max(1, ...) keeps recall at 0 instead of dividing by zero when no row is FLAGGED.
rec = true_pos / max(1, n_pos)

print(f"{alpha:.2f},{thresh:.2f},{margin:.2f},{acc:.4f},{rec:.4f}")
PY
    done
  done
done >> "$grid_csv"

echo
echo "Top 12 by accuracy (twoC + rank+veto):"
# Rank the sweep grid by accuracy, list the first twelve rows, then report the
# single most accurate combination.
python - "$OUT_DIR/ensemble_twoC_rankveto_grid.csv" <<'PY'
import csv
import sys


def accuracy(row):
    """Numeric value of the row's "acc" column."""
    return float(row["acc"])


grid = list(csv.DictReader(open(sys.argv[1])))
grid.sort(key=accuracy, reverse=True)

for row in grid[:12]:
    recall = float(row['rec_flagged'])
    print(f"{row['alpha']},{row['thresh']},{row['margin']},acc={accuracy(row):.4f},rec_FLAGGED={recall:.4f}")

best = max(grid, key=accuracy)
print("\n[SELECTED BEST]", best)
print(f"\nBest params → alpha={best['alpha']}  thresh={best['thresh']}  margin={best['margin']}")
PY

# Re-run best combo and keep it as final ensemble.
# OUT_DIR is passed on argv so the grid that is read and the predictions that
# are written follow an overridden output directory instead of a fixed
# "outputs/" path.
python - "$OUT_DIR" <<'PY'
import csv
import os
import subprocess
import sys

out_dir = sys.argv[1]
grid_path = os.path.join(out_dir, "ensemble_twoC_rankveto_grid.csv")

with open(grid_path, newline="") as fh:
    rows = list(csv.DictReader(fh))
if not rows:
    raise SystemExit(f"No sweep results in {grid_path}; cannot pick a best combo")

# Highest accuracy wins; on ties max() keeps the first such row in the grid.
best = max(rows, key=lambda r: float(r["acc"]))

# An argv list (no shell) keeps paths containing spaces or metacharacters intact.
cmd = [
    "python", "2NEW_ensemble_rank_veto.py",
    "--bplus", "../contrastive_augmentation_pipeline/outputs/pipeline_b_plus/preds.jsonl",
    "--c", os.path.join(out_dir, "preds_c_blend.jsonl"),
    "--alpha", best["alpha"],
    "--thresh", best["thresh"],
    "--margin", best["margin"],
    "--out", os.path.join(out_dir, "preds_ensemble.jsonl"),
]
# check_output raises CalledProcessError on a non-zero exit, which fails this step.
print(subprocess.check_output(cmd, text=True))
PY
echo

# ---------- [6/6] Final evaluation ----------
printf '%s\n' "### [6/6] Final evaluation (A/B) ..."
# Each entry is "source:destination" inside OUT_DIR: the blended classifier
# becomes side A and the rank+veto ensemble becomes side B.
# NOTE(review): NEW_evaluate_fair.py is invoked without arguments; presumably it
# reads preds_a.jsonl / preds_b.jsonl from a location of its own — confirm it
# follows OUT_DIR when that is overridden.
for pair in "preds_c_blend.jsonl:preds_a.jsonl" "preds_ensemble.jsonl:preds_b.jsonl"; do
  cp -f "$OUT_DIR/${pair%%:*}" "$OUT_DIR/${pair##*:}"
done
python NEW_evaluate_fair.py
printf '\n'

# ---------- (Optional) Distinct prediction set sanity ----------
# Re-runs the ensemble for every grid row, hashes each predictions file, and
# groups the parameter combinations that produced identical output.
# OUT_DIR is passed on argv so the grid and the tmp_*.jsonl files follow an
# overridden output directory instead of a fixed "outputs/" path.
echo "Distinct prediction sets across sweep (sanity):"
python - "$OUT_DIR" <<'PY'
import csv
import hashlib
import os
import subprocess
import sys

out_dir = sys.argv[1]
bplus_path = "../contrastive_augmentation_pipeline/outputs/pipeline_b_plus/preds.jsonl"

with open(os.path.join(out_dir, "ensemble_twoC_rankveto_grid.csv"), newline="") as fh:
    rows = list(csv.DictReader(fh))

seen = {}  # content hash -> tags of the combos that produced that content
for r in rows:
    tag = f"A{float(r['alpha']):.2f}_T{float(r['thresh']):.2f}_M{float(r['margin']):.2f}"
    tmp_path = os.path.join(out_dir, f"tmp_{tag}.jsonl")
    # check=True stops on a failed run instead of hashing a stale or missing file.
    subprocess.run(
        [
            "python", "2NEW_ensemble_rank_veto.py",
            "--bplus", bplus_path,
            "--c", os.path.join(out_dir, "preds_c_blend.jsonl"),
            "--alpha", r["alpha"],
            "--thresh", r["thresh"],
            "--margin", r["margin"],
            "--out", tmp_path,
        ],
        stdout=subprocess.DEVNULL,
        check=True,
    )
    with open(tmp_path, encoding="utf-8") as fh:
        bits = "".join(line.strip() for line in fh)
    # md5 serves only as a content fingerprint here, not for security.
    digest = hashlib.md5(bits.encode()).hexdigest()
    seen.setdefault(digest, []).append(tag)

print(f"Distinct sets: {len(seen)}")
for i, tags in enumerate(seen.values(), 1):
    head = ", ".join(tags[:8])
    print(f"Set {i} ({len(tags)} combos): {head}" + (" ..." if len(tags) > 8 else ""))
PY

echo
echo "[DONE] Repro pipeline completed."
