#!/usr/bin/env python
"""
Linear probe on frozen BioClinicalBERT for 4-way PED-X classification.
Embeddings = encoder pooler_output (768-d).
Classifier  = scikit-learn multinomial logistic regression.
"""

import argparse, json, numpy as np, pandas as pd, torch, joblib
from pathlib import Path
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm.auto import tqdm

# --------------------------------------------------------
# Class-name -> integer id used as the classification target.
LABEL2ID = {
    "NotExtrapolated": 0,
    "Partial": 1,
    "Full": 2,
    "Unlabeled": 3,
}
# Inverse lookup: integer id -> class-name.
ID2LABEL = {idx: name for name, idx in LABEL2ID.items()}

# ---------- helpers -------------------------------------
def _canon_txt(txt_dir: Path, cid: str) -> Path:
    """Return the path of the ``<cid>.txt`` file inside *txt_dir*."""
    file_name = f"{cid}.txt"
    return txt_dir.joinpath(file_name)

def _filter_ok(df: pd.DataFrame, txt_dir: Path) -> pd.DataFrame:
    """Keep only rows whose text file exists and encode their labels as ints.

    Reads the ``canon_id`` and ``label`` columns of *df*. Rows without a
    ``canon_id`` or without a matching ``<canon_id>.txt`` file in *txt_dir*
    are dropped. Missing or empty labels are treated as ``"NotExtrapolated"``.

    Returns:
        A new DataFrame with exactly two columns: ``txt_file`` (path as
        ``str``) and ``label`` (``int32`` id taken from ``LABEL2ID``).

    Raises:
        ValueError: if no rows survive the filtering, or if a label is not
            one of the keys of ``LABEL2ID``.
    """
    df = df.dropna(subset=["canon_id"]).copy()
    df["txt_file"] = df["canon_id"].apply(lambda c: _canon_txt(txt_dir, c))

    # astype(bool) keeps the mask boolean even when the frame is empty
    # (an empty ``apply`` result is object-dtype and is not a valid mask).
    has_txt = df["txt_file"].apply(Path.exists).astype(bool)
    # Copy so the column assignments below do not write into a slice.
    df = df[has_txt].copy()
    if df.empty:
        raise ValueError("No rows left after filtering!")

    # Empty CSV cells are read as NaN under pandas' default NA handling, so
    # normalise NaN to "" before applying the empty-label default.
    labels = df["label"].fillna("").replace("", "NotExtrapolated")
    label_ids = labels.map(LABEL2ID)
    unknown = labels[label_ids.isna()].unique().tolist()
    if unknown:
        # Fail with the offending values instead of an opaque int-cast error.
        raise ValueError(f"Unknown label(s): {unknown}")

    df["label"] = label_ids.astype("int32")
    df["txt_file"] = df["txt_file"].astype(str)
    return df[["txt_file", "label"]]

def load_splits(split_dir: Path, txt_dir: Path):
    """Load ``train.csv``, ``dev.csv`` and ``test.csv`` from *split_dir*.

    Every column is read as a string. Each frame is passed through
    ``_filter_ok`` and converted to a ``Dataset``.

    Returns:
        A list ``[train, dev, test]`` of ``Dataset`` objects.
    """
    csv_paths = [split_dir / f"{name}.csv" for name in ("train", "dev", "test")]
    # Read all three files first, then filter/convert, in that order.
    frames = [pd.read_csv(csv_path, dtype=str) for csv_path in csv_paths]
    return [Dataset.from_pandas(_filter_ok(frame, txt_dir)) for frame in frames]

# ---------- embed all samples ----------------------------
@torch.no_grad()
def embed_dataset(dataset: Dataset, tokenizer, model, batch_size=8, device="cuda"):
    """Encode every text file referenced by *dataset* with the given model.

    The dataset is walked in slices of *batch_size*. For each slice the files
    listed under ``txt_file`` are read, tokenized (truncated and padded to
    512 tokens) and passed through *model* on *device*; the model's
    ``pooler_output`` is collected together with the ``label`` column.

    Returns:
        ``(embeddings, labels)``: a 2-D array with one pooled embedding per
        sample (768-d for the encoder named in the module docstring) and a
        1-D array of the matching labels.
    """
    embeddings, labels = [], []
    starts = range(0, len(dataset), batch_size)
    for start in tqdm(starts, desc="Embedding"):
        chunk = dataset[start : start + batch_size]

        documents = []
        for file_name in chunk["txt_file"]:
            # Undecodable bytes are dropped rather than raising.
            documents.append(Path(file_name).read_text(errors="ignore"))

        encoded = tokenizer(
            documents,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = encoded.to(device)

        pooled = model(**encoded).pooler_output
        embeddings.append(pooled.detach().cpu().numpy())
        labels.append(np.asarray(chunk["label"]))

    return np.vstack(embeddings), np.concatenate(labels)

# ---------------------------------------------------------
def _evaluate_split(clf, X, y, n_classes):
    """Return accuracy, macro/weighted F1 and macro AUC of *clf* on one split."""
    pred = clf.predict(X)
    probs = clf.predict_proba(X)
    # One-hot targets, one column per class id, compared against the
    # per-class probabilities.
    one_hot = np.eye(n_classes)[y]
    return {
        "accuracy"   : accuracy_score(y, pred),
        "macro_f1"   : f1_score(y, pred, average="macro"),
        "weighted_f1": f1_score(y, pred, average="weighted"),
        "auc"        : roc_auc_score(one_hot, probs, average="macro", multi_class="ovr"),
    }


def main():
    """Embed the train/dev/test splits and fit a logistic-regression probe.

    Command-line arguments:
        --split_dir   directory holding ``train.csv``, ``dev.csv``, ``test.csv``
        --txt_dir     directory holding the ``<canon_id>.txt`` documents
        --out_dir     output directory (created if missing)
        --C           inverse regularisation strength of the classifier
        --batch_size  number of documents encoded per forward pass

    Prints dev and test metrics, then writes to ``out_dir``:
    ``linear_probe.joblib``, ``train_emb.npy``, ``dev_emb.npy``,
    ``test_emb.npy`` and ``test_metrics.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--split_dir", required=True, type=Path)
    parser.add_argument("--txt_dir",   required=True, type=Path)
    parser.add_argument("--out_dir",   default=Path("checkpoints/bioclinicalbert_4cls_lp"), type=Path)
    parser.add_argument("--C",         default=1.0, type=float, help="Inverse reg-strength for LR")
    parser.add_argument("--batch_size", default=8, type=int)
    args = parser.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    # ---------- data ----------
    ds_train, ds_dev, ds_test = load_splits(args.split_dir, args.txt_dir)

    # ---------- frozen encoder ----------
    # Use the GPU when one is present; otherwise fall back to CPU instead of
    # crashing on the hard-coded "cuda" device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    encoder   = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT",
                                          add_pooling_layer=True).eval().to(device)

    # ---------- embed ----------
    X_tr, y_tr = embed_dataset(ds_train, tokenizer, encoder, args.batch_size, device)
    X_dev, y_dev = embed_dataset(ds_dev, tokenizer, encoder, args.batch_size, device)
    X_te, y_te = embed_dataset(ds_test, tokenizer, encoder, args.batch_size, device)

    # ---------- linear probe ----------
    # `multi_class` is intentionally not passed: it is deprecated in recent
    # scikit-learn releases, and lbfgs already fits a multinomial model for
    # problems with three or more classes.
    clf = LogisticRegression(
        C=args.C,
        penalty="l2",
        solver="lbfgs",
        max_iter=1000,
        class_weight="balanced",
        n_jobs=-1,
    ).fit(X_tr, y_tr)

    n_classes = len(LABEL2ID)

    # ---------- dev ----------
    dev_metrics = _evaluate_split(clf, X_dev, y_dev, n_classes)
    print("Dev metrics:", json.dumps(dev_metrics, indent=2))

    # ---------- test ----------
    test_metrics = _evaluate_split(clf, X_te, y_te, n_classes)
    print("\nTest metrics:", json.dumps(test_metrics, indent=2))

    # ---------- save ----------
    joblib.dump(clf, args.out_dir / "linear_probe.joblib")
    np.save(args.out_dir / "train_emb.npy", X_tr)
    np.save(args.out_dir / "dev_emb.npy",   X_dev)
    np.save(args.out_dir / "test_emb.npy",  X_te)
    with open(args.out_dir / "test_metrics.json", "w") as f:
        json.dump(test_metrics, f, indent=2)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
