{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ad390f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "import argparse, ast, json\n",
    "import inspect\n",
    "from pathlib import Path\n",
    "from typing import List, Dict, Any\n",
    "\n",
    "import torch\n",
    "from datasets import Dataset\n",
    "from transformers import (\n",
    "    BertTokenizerFast,\n",
    "    BertForTokenClassification,\n",
    "    DataCollatorForTokenClassification,\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    EarlyStoppingCallback,\n",
    ")\n",
    "\n",
    "from GinSign import BERT_Grounder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb74de51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths you will need to change: point data_dir at the data you want to train on\n",
    "# and output_dir at wherever the trained model should be saved.\n",
    "data_dir = Path('grounding_data_holdout_v2/predicate')  # must contain train/ and test/ folders of .jsonl files\n",
    "output_dir = Path('GinSign_Predicates_Holdout_V2')  # checkpoints, final model and tokenizer are written here\n",
    "\n",
    "# Training hyper-parameters\n",
    "model_name = 'bert-base-uncased'\n",
    "epochs = 3\n",
    "batch = 16  # per-device batch size for both training and evaluation\n",
    "lr = 5e-5\n",
    "eval_steps = 500  # evaluate and checkpoint every N training steps\n",
    "early_stopping = 3  # evaluations without sufficient improvement before training stops\n",
    "early_stopping_threshold = 1e-6  # smallest eval_loss improvement that counts as progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cae5254",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def _load_split(split_dir: Path) -> Dataset:\n",
    "    \"\"\"Load every ``*.jsonl`` file in ``split_dir`` into one ``Dataset``.\n",
    "\n",
    "    Files are read in sorted filename order, one JSON object per line, and\n",
    "    the resulting rows are concatenated. Blank lines (e.g. extra newlines at\n",
    "    the end of a file) are skipped.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if ``split_dir`` contains no ``.jsonl`` files.\n",
    "        json.JSONDecodeError: if a non-blank line is not valid JSON.\n",
    "    \"\"\"\n",
    "    files = sorted(split_dir.glob(\"*.jsonl\"))\n",
    "    if not files:\n",
    "        raise FileNotFoundError(f\"No .jsonl files in {split_dir}\")\n",
    "    rows: List[Dict[str, Any]] = []\n",
    "    for f in files:\n",
    "        # The context manager closes each handle promptly; explicit UTF-8 avoids\n",
    "        # depending on the platform's default encoding.\n",
    "        with f.open(encoding=\"utf-8\") as fh:\n",
    "            rows.extend(json.loads(line) for line in fh if line.strip())\n",
    "    return Dataset.from_list(rows)\n",
    "\n",
    "\n",
    "def _prepare(tokenizer, ds: Dataset):\n",
    "    \"\"\"Tokenize ``prefix + sentence`` word lists and align token-level labels.\n",
    "\n",
    "    For each row the ``prefix`` and ``sentence`` word lists are concatenated and\n",
    "    tokenized as pre-split words (truncated, unpadded). Only the first sub-token\n",
    "    of each prefix word receives that word's label from ``prefix_target``.\n",
    "    Continuation sub-tokens, special tokens and every sentence token get -100.\n",
    "\n",
    "    The map drops every column except ``sentence``, ``prefix`` and\n",
    "    ``prefix_target``. Those three are kept next to the tokenizer outputs and\n",
    "    ``labels``. NOTE(review): they are not model inputs, so this relies on the\n",
    "    Trainer discarding unused columns; confirm if that is ever disabled.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a row's ``prefix_target`` length differs from its\n",
    "            ``prefix`` length (labels would otherwise be misaligned).\n",
    "    \"\"\"\n",
    "    def convert(batch):\n",
    "        # Fail loudly instead of silently shifting labels onto the wrong words.\n",
    "        for i, (p, pt) in enumerate(zip(batch[\"prefix\"], batch[\"prefix_target\"])):\n",
    "            if len(p) != len(pt):\n",
    "                raise ValueError(\n",
    "                    f\"prefix_target has {len(pt)} labels but prefix has {len(p)} words (batch row {i})\"\n",
    "                )\n",
    "        seqs = [p + s for p, s in zip(batch[\"prefix\"], batch[\"sentence\"])]\n",
    "        lbls = [pt + [-100]*len(s) for pt, s in zip(batch[\"prefix_target\"], batch[\"sentence\"])]\n",
    "        toks = tokenizer(seqs, is_split_into_words=True, truncation=True, padding=False)\n",
    "        aligned = []\n",
    "        for row_id, lab in enumerate(lbls):\n",
    "            out = []\n",
    "            prev = None\n",
    "            for w in toks.word_ids(batch_index=row_id):\n",
    "                # Special tokens (w is None) and repeated sub-tokens of the same\n",
    "                # word are ignored by the loss; the first sub-token takes the label.\n",
    "                if w is None or w == prev:\n",
    "                    out.append(-100)\n",
    "                else:\n",
    "                    out.append(lab[w])\n",
    "                prev = w\n",
    "            aligned.append(out)\n",
    "        toks[\"labels\"] = aligned\n",
    "        return toks\n",
    "    return ds.map(convert, batched=True, remove_columns=[c for c in ds.column_names if c not in {\"sentence\",\"prefix\",\"prefix_target\"}])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ed66a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train(data_dir: Path, out_dir: Path, *, model_name=\"bert-base-uncased\", epochs=3, batch=8,\n",
    "          lr=5e-5, eval_steps=500, patience=3, threshold=1e-6):\n",
    "    \"\"\"Fine-tune a BERT token classifier on ``data_dir`` and save it to ``out_dir``.\n",
    "\n",
    "    Reads ``data_dir/train`` and ``data_dir/test`` (folders of ``.jsonl`` files).\n",
    "    The ``test`` split is also the evaluation set used for checkpoint selection\n",
    "    (lowest ``eval_loss``) and early stopping. Scores computed on it afterwards\n",
    "    are therefore optimistically biased.\n",
    "\n",
    "    Args:\n",
    "        data_dir: root folder holding the ``train`` and ``test`` splits.\n",
    "        out_dir: Trainer output/checkpoint directory; the final model and\n",
    "            tokenizer are also saved here.\n",
    "        model_name: model id or local path for both tokenizer and model.\n",
    "        epochs: number of training epochs.\n",
    "        batch: per-device batch size for training and evaluation.\n",
    "        lr: learning rate.\n",
    "        eval_steps: evaluate and checkpoint every this many steps; logging\n",
    "            happens every ``max(1, eval_steps // 5)`` steps.\n",
    "        patience: evaluations without sufficient improvement before stopping.\n",
    "        threshold: minimum ``eval_loss`` improvement that counts as progress.\n",
    "    \"\"\"\n",
    "    tok = BertTokenizerFast.from_pretrained(model_name)\n",
    "    tr_ds = _prepare(tok, _load_split(data_dir/\"train\"))\n",
    "    te_ds = _prepare(tok, _load_split(data_dir/\"test\"))\n",
    "\n",
    "    # Two output classes per token; prefix_target values are presumably 0/1 (TODO confirm).\n",
    "    model = BertForTokenClassification.from_pretrained(model_name, num_labels=2)\n",
    "\n",
    "    args = TrainingArguments(\n",
    "        output_dir=str(out_dir),\n",
    "        num_train_epochs=epochs,\n",
    "        per_device_train_batch_size=batch,\n",
    "        per_device_eval_batch_size=batch,\n",
    "        learning_rate=lr,\n",
    "        weight_decay=0.01,\n",
    "        eval_strategy=\"steps\",\n",
    "        eval_steps=eval_steps,\n",
    "        # Checkpoint on the same schedule as evaluation so the best model can be reloaded.\n",
    "        save_strategy=\"steps\",\n",
    "        save_steps=eval_steps,\n",
    "        logging_strategy=\"steps\",\n",
    "        logging_steps=max(1, eval_steps//5),\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_loss\",\n",
    "        greater_is_better=False,\n",
    "        save_total_limit=2,\n",
    "        report_to=[],\n",
    "    )\n",
    "\n",
    "    cb = EarlyStoppingCallback(\n",
    "        early_stopping_patience=patience,\n",
    "        early_stopping_threshold=threshold,\n",
    "    )\n",
    "\n",
    "    # Recent transformers releases renamed Trainer's `tokenizer` argument to\n",
    "    # `processing_class` and deprecated the old name; pass the tokenizer under\n",
    "    # whichever name the installed version accepts.\n",
    "    trainer_params = inspect.signature(Trainer.__init__).parameters\n",
    "    tok_kwarg = \"processing_class\" if \"processing_class\" in trainer_params else \"tokenizer\"\n",
    "\n",
    "    Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=tr_ds,\n",
    "        eval_dataset=te_ds,\n",
    "        data_collator=DataCollatorForTokenClassification(tok),\n",
    "        callbacks=[cb],\n",
    "        **{tok_kwarg: tok},\n",
    "    ).train()\n",
    "\n",
    "    # load_best_model_at_end=True: the best checkpoint (by eval_loss) is reloaded\n",
    "    # at the end of training, before these final saves.\n",
    "    model.save_pretrained(out_dir)\n",
    "    tok.save_pretrained(out_dir)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ae297e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Fine-tune and save the model. data_dir and output_dir must be defined before\n",
    "# running this cell; the hyper-parameters come from the configuration cell above.\n",
    "train(\n",
    "    data_dir,\n",
    "    output_dir,\n",
    "    model_name=model_name,\n",
    "    epochs=epochs,\n",
    "    batch=batch,\n",
    "    lr=lr,\n",
    "    eval_steps=eval_steps,\n",
    "    patience=early_stopping,\n",
    "    threshold=early_stopping_threshold,\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
