#!/usr/bin/env python
"""
Linear probe on *frozen* Clinical-ModernBERT (Simonlee711/Clinical_ModernBERT)
for the 4-way PED-X classification task.
Only the final linear layer is trained.
Metrics: accuracy, macro-F1, weighted-F1, macro-AUC.
"""

from pathlib import Path
import argparse, json, numpy as np, pandas as pd, torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModel,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
    EarlyStoppingCallback,
)
from torch import nn
import evaluate
from sklearn.metrics import roc_auc_score

# ---------------------------------------------------------------------
# Label vocabulary for the 4-way task; tuple position defines the integer id.
_LABEL_ORDER = ("NotExtrapolated", "Partial", "Full", "Unlabeled")
LABEL2ID = {name: idx for idx, name in enumerate(_LABEL_ORDER)}
# Inverse mapping (id -> label name), same ordering as LABEL2ID.
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))
# Number of output classes, derived from the vocabulary so the two cannot drift.
N_LABELS = len(LABEL2ID)

# ---------------------------------------------------------------------
def _canon_txt(txt_dir: Path, cid: str) -> Path:
    """Return the path ``<txt_dir>/<cid>.txt`` for canonical id *cid*."""
    file_name = "{}.txt".format(cid)
    return txt_dir / file_name

def _filter_ok(df: pd.DataFrame, txt_dir: Path) -> pd.DataFrame:
    """Keep rows whose text file exists and map their labels to integer ids.

    Rows without a ``canon_id`` or whose ``<txt_dir>/<canon_id>.txt`` file is
    missing are dropped. Blank/missing labels are filled with
    ``"NotExtrapolated"`` before mapping through ``LABEL2ID``.

    Returns:
        A frame with columns ``txt_file`` (str path) and ``label`` (int32 id),
        re-indexed 0..n-1.

    Raises:
        ValueError: if no rows survive filtering, or a label is not in LABEL2ID.
    """
    df = df.dropna(subset=["canon_id"]).copy()
    df["txt_file"] = df["canon_id"].apply(lambda c: _canon_txt(txt_dir, c))
    # astype(bool) keeps this a boolean mask even when the frame is empty.
    exists = df["txt_file"].apply(Path.exists).astype(bool)
    df = df.loc[exists].copy()
    if df.empty:
        raise ValueError("No rows left after filtering!")

    # read_csv(dtype=str) yields NaN (not "") for empty cells, so normalise both
    # NaN and whitespace-only values to "" before filling the default class.
    raw = df["label"].fillna("").str.strip().replace("", "NotExtrapolated")
    ids = raw.map(LABEL2ID)
    unknown = sorted(raw[ids.isna()].unique())
    if unknown:
        raise ValueError(
            f"Unknown label(s) {unknown}; expected one of {list(LABEL2ID)}"
        )
    df["label"] = ids.astype("int32")
    df["txt_file"] = df["txt_file"].astype(str)
    # Reset the index so Dataset.from_pandas gets a plain RangeIndex.
    return df[["txt_file", "label"]].reset_index(drop=True)

def load_splits(split_dir: Path, txt_dir: Path):
    """Load train/dev/test CSVs from *split_dir* as filtered ``Dataset`` objects.

    All three CSVs are read (as strings) before any filtering happens; each
    frame is then passed through ``_filter_ok`` and converted to a Dataset.
    Returns a list ``[train, dev, test]``.
    """
    csv_paths = [split_dir / f"{name}.csv" for name in ("train", "dev", "test")]
    frames = [pd.read_csv(path, dtype=str) for path in csv_paths]
    splits = []
    for frame in frames:
        kept = _filter_ok(frame, txt_dir)
        splits.append(Dataset.from_pandas(kept))
    return splits

# ---------------------------------------------------------------------
class ModernBERTLinearProbe(nn.Module):
    """Frozen ModernBERT encoder + trainable linear head.

    The encoder's parameters have ``requires_grad=False`` and it is always kept
    in eval mode (even when this module is switched to training mode), so
    dropout inside the encoder does not perturb the features during training.
    Only ``classifier`` receives gradients.
    """

    def __init__(self, encoder_name="Simonlee711/Clinical_ModernBERT", n_labels=N_LABELS):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        for p in self.encoder.parameters():       # freeze encoder
            p.requires_grad = False
        self.encoder.eval()
        hid = self.encoder.config.hidden_size
        self.classifier = nn.Linear(hid, n_labels)

    def train(self, mode: bool = True):
        """Set train/eval mode on the probe while pinning the encoder to eval mode.

        Trainer calls ``model.train()`` before training steps; without this
        override the frozen encoder's dropout would be re-enabled.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return ``{"logits"}``, plus ``"loss"`` (cross-entropy) when labels are given."""
        # The encoder is frozen: skip building an autograd graph through it.
        with torch.no_grad():
            enc_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = enc_out.last_hidden_state[:, 0]     # first-token ([CLS]) embedding
        logits = self.classifier(cls)
        if labels is not None:
            # cross_entropy requires int64 class indices.
            loss = nn.functional.cross_entropy(logits, labels.long())
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

# ---------------------------------------------------------------------
def main():
    """Train a linear probe on frozen Clinical-ModernBERT and report test metrics.

    Reads train/dev/test CSVs from ``--split_dir`` whose rows point at note
    texts in ``--txt_dir``. Only the linear head is trained, with early stopping
    on dev loss. The best model is then evaluated on the test split. The model,
    tokenizer and ``test_metrics.json`` are written to ``--out_dir``.
    """
    # Local import: per-batch dynamic padding instead of padding every note to max_len.
    from transformers import DataCollatorWithPadding

    ap = argparse.ArgumentParser()
    ap.add_argument("--split_dir", required=True, type=Path)
    ap.add_argument("--txt_dir",   required=True, type=Path)
    ap.add_argument("--out_dir",   default=Path("checkpoints/modernbert_lp_4cls"), type=Path)
    ap.add_argument("--epochs",    default=10, type=int)
    ap.add_argument("--batch",     default=4,  type=int, help="per-device batch size")
    ap.add_argument("--max_len",   default=8192, type=int,
                    help="max tokens per note; longer notes are truncated")
    args = ap.parse_args()

    ds_train, ds_dev, ds_test = load_splits(args.split_dir, args.txt_dir)

    tokenizer = AutoTokenizer.from_pretrained("Simonlee711/Clinical_ModernBERT")

    def tok_fn(batch):
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        texts = [Path(p).read_text(encoding="utf-8", errors="ignore") for p in batch["txt_file"]]
        # Truncate only: the collator pads each batch to its own longest sequence,
        # instead of storing and encoding max_len pad tokens for every short note.
        enc = tokenizer(texts, max_length=args.max_len, truncation=True)
        enc["labels"] = batch["label"]
        return enc

    def tokenize(ds):
        # Drop every source column (including any pandas index column that
        # Dataset.from_pandas may have kept) so only model inputs + labels remain.
        return ds.map(tok_fn, batched=True, remove_columns=ds.column_names)

    ds_train, ds_dev, ds_test = (tokenize(ds) for ds in (ds_train, ds_dev, ds_test))

    model = ModernBERTLinearProbe()

    # ---------------- metrics ----------------
    hf_acc = evaluate.load("accuracy")
    hf_f1  = evaluate.load("f1")

    def compute(eval_pred):
        logits, labels = eval_pred
        preds  = np.argmax(logits, axis=1)
        probs  = torch.softmax(torch.tensor(logits), dim=1).cpu().numpy()
        labels = labels.astype(np.int64)
        one_hot = np.eye(N_LABELS, dtype=np.int32)[labels]
        # One-vs-rest AUC is undefined for a class that is absent from (or is the
        # only class in) this split and roc_auc_score would raise; macro-average
        # over the classes that can be scored instead.
        scorable = [c for c in range(N_LABELS) if 0 < one_hot[:, c].sum() < len(labels)]
        if scorable:
            auc = float(np.mean([roc_auc_score(one_hot[:, c], probs[:, c]) for c in scorable]))
        else:
            auc = float("nan")
        return {
            "accuracy"   : hf_acc.compute(predictions=preds, references=labels)["accuracy"],
            "macro_f1"   : hf_f1.compute(predictions=preds, references=labels, average="macro")["f1"],
            "weighted_f1": hf_f1.compute(predictions=preds, references=labels, average="weighted")["f1"],
            "auc"        : auc,
        }

    # ---------------- training args ----------------
    training_args = TrainingArguments(
        output_dir=str(args.out_dir),
        per_device_train_batch_size=args.batch,
        gradient_accumulation_steps=2,
        learning_rate=1e-3,              # high LR for tiny head
        weight_decay=0.0,
        num_train_epochs=args.epochs,
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        # Save at every evaluation: the best checkpoint (restored by
        # load_best_model_at_end) must exist on disk at the eval step where the
        # best dev loss was seen.
        save_steps=100,
        save_total_limit=2,              # bound disk usage; the best checkpoint is retained
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=False,
        seed=42,
        logging_steps=50,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_dev,
        data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
        compute_metrics=compute,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
    )

    trainer.train()
    metrics_test = trainer.evaluate(ds_test, metric_key_prefix="test")
    print(json.dumps(metrics_test, indent=2))

    args.out_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)
    (args.out_dir / "test_metrics.json").write_text(json.dumps(metrics_test, indent=2))

if __name__ == "__main__":
    main()
