{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import set_seed, AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit\n",
    "\n",
    "set_seed(42)\n",
    "\n",
    "model_name = \"google/flan-t5-base\"\n",
    "\n",
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=2,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    "    prompt_tuning_init_text=\"classify the following into either positive or negative, or entailment, neutral or contradiction:\",\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "model = model.cuda()\n",
    "\n",
    "\n",
    "def send_to_device(batch):\n",
    "    for i in batch:\n",
    "        batch[i] = batch[i].cuda()\n",
    "    return batch"
   ],
   "id": "cea05cca0bbb662"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_sst2(split: str):\n",
    "    examples = load_dataset(\"sst2\")[split]\n",
    "    result_examples = []\n",
    "    for example in examples:\n",
    "        result_examples.append({})\n",
    "\n",
    "        result_examples[-1][\"input\"] = example[\"sentence\"].strip() + \"</s>\"\n",
    "        result_examples[-1][\"output\"] = (\n",
    "            f\"positive{tokenizer.eos_token}\" if example[\"label\"] == 1 else f\"negative{tokenizer.eos_token}\"\n",
    "        )\n",
    "        result_examples[-1][\"task_id\"] = 0\n",
    "\n",
    "    return result_examples\n",
    "\n",
    "\n",
    "def get_mnli(split: str):\n",
    "    examples = load_dataset(\"multi_nli\")[split]\n",
    "    result_examples = []\n",
    "    for example in examples:\n",
    "        result_examples.append({})\n",
    "\n",
    "        result_examples[-1][\"input\"] = example[\"premise\"].strip() + \" \" + example[\"hypothesis\"].strip() + \"</s>\"\n",
    "\n",
    "        if example[\"label\"] == 0:\n",
    "            result_examples[-1][\"output\"] = f\"entailment{tokenizer.eos_token}\"\n",
    "        elif example[\"label\"] == 1:\n",
    "            result_examples[-1][\"output\"] = f\"neutral{tokenizer.eos_token}\"\n",
    "        else:\n",
    "            result_examples[-1][\"output\"] = f\"contradiction{tokenizer.eos_token}\"\n",
    "\n",
    "        result_examples[-1][\"task_id\"] = 1\n",
    "\n",
    "    return result_examples"
   ],
   "id": "8d40fdf7b6b790b7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from typing import Tuple\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    def __init__(self, split: str, mode: str = \"source\") -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        if split == \"train\":\n",
    "            if mode == \"source\":\n",
    "                self.examples = get_sst2(split) + get_mnli(split)\n",
    "            elif mode == \"target\":\n",
    "                self.examples = get_sst2(split)\n",
    "        if split == \"val\":\n",
    "            self.examples = get_sst2(\"validation\")\n",
    "        if split == \"test\":\n",
    "            self.examples = get_sst2(\"validation\")\n",
    "\n",
    "    def __getitem__(self, index) -> dict:\n",
    "        return self.examples[index]\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, index) -> dict:\n",
    "        return self.examples[index]\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.examples)\n",
    "\n",
    "\n",
    "def collate_fn(batch: dict) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    input = [i[\"input\"] for i in batch]\n",
    "    input = tokenizer(input, add_special_tokens=False, return_tensors=\"pt\", padding=True)\n",
    "\n",
    "    output = [i[\"output\"] for i in batch]\n",
    "    output = tokenizer(output, add_special_tokens=False, return_tensors=\"pt\", padding=True).input_ids\n",
    "    output[output == tokenizer.pad_token_id] = -100\n",
    "\n",
    "    task_ids = [i[\"task_id\"] for i in batch]\n",
    "    task_ids = torch.tensor(task_ids)\n",
    "\n",
    "    return {\n",
    "        \"input_ids\": input.input_ids,\n",
    "        \"attention_mask\": input.attention_mask,\n",
    "        \"labels\": output,\n",
    "        \"task_ids\": task_ids,\n",
    "    }\n",
    "\n",
    "\n",
    "train = DataLoader(MyDataset(\"train\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n",
    "val = DataLoader(MyDataset(\"val\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n",
    "test = DataLoader(MyDataset(\"test\"), shuffle=False, batch_size=8, collate_fn=collate_fn)"
   ],
   "id": "b36940dbd4182c5"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## source training",
   "id": "916537140e8ba001"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from torch.optim.adamw import AdamW\n",
    "from transformers import get_cosine_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import f1_score"
   ],
   "id": "7a1b90ecb53d74ff"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "POSITIVE_TOKEN_ID = tokenizer(\" positive\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "NEGATIVE_TOKEN_ID = tokenizer(\" negative\", add_special_tokens=False)[\"input_ids\"][0]\n",
    "\n",
    "\n",
    "def classify(batch):\n",
    "    batch = send_to_device(batch)\n",
    "    # we pass labels here since we need to generate and peft doesn't support generation yet.\n",
    "    # No clue how to get around this\n",
    "    scores = model(**batch).logits\n",
    "    preds = []\n",
    "    for i in range(scores.shape[0]):\n",
    "        if scores[i, 0, POSITIVE_TOKEN_ID] > scores[i, 0, NEGATIVE_TOKEN_ID]:\n",
    "            preds.append(POSITIVE_TOKEN_ID)\n",
    "        else:\n",
    "            preds.append(NEGATIVE_TOKEN_ID)\n",
    "    return preds\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def evaluate(model, data):\n",
    "    loss = 0\n",
    "    preds = []\n",
    "    golds = []\n",
    "\n",
    "    for batch in tqdm(data):\n",
    "        batch = send_to_device(batch)\n",
    "        loss += model(**batch).loss\n",
    "        golds.extend(batch[\"labels\"][:, 0].tolist())\n",
    "        preds.extend(classify(batch))\n",
    "\n",
    "    return loss / len(val), f1_score(golds, preds, pos_label=POSITIVE_TOKEN_ID)\n",
    "\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n",
    "\n",
    "n = 1000\n",
    "step = 0\n",
    "train_ = tqdm(train)\n",
    "\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "before source training\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0:\n",
    "        val_loss, f1 = evaluate(model, val)\n",
    "        print(\n",
    "            f\"\"\"\n",
    "step = {step}\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    "        )\n",
    "        model.save_pretrained(f\"checkpoints_source/{step}\")\n",
    "\n",
    "    step += 1\n",
    "    batch = send_to_device(batch)\n",
    "    loss = model(**batch).loss\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ],
   "id": "e842d3edd5abffad"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## target training",
   "id": "c2f1c1a438717f6a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "train = DataLoader(MyDataset(\"train\", \"target\"), shuffle=True, batch_size=8, collate_fn=collate_fn)\n",
    "val = DataLoader(MyDataset(\"val\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)\n",
    "test = DataLoader(MyDataset(\"test\", \"target\"), shuffle=False, batch_size=8, collate_fn=collate_fn)"
   ],
   "id": "bdbf40c8cf383b7d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "#### create a fresh model",
   "id": "6069ed2e71435807"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "peft_config = MultitaskPromptTuningConfig(\n",
    "    tokenizer_name_or_path=model_name,\n",
    "    num_tasks=1,\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=MultitaskPromptTuningInit.EXACT_SOURCE_TASK,\n",
    "    prompt_tuning_init_state_dict_path=\"checkpoints_source/50000/adapter_model.bin\",\n",
    "    num_virtual_tokens=50,\n",
    "    num_transformer_submodules=1,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "model = model.cuda()"
   ],
   "id": "52f0030815944442"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "scheduler = get_cosine_schedule_with_warmup(optimizer, 200, len(train))\n",
    "\n",
    "n = 1000\n",
    "step = 0\n",
    "train_ = tqdm(train)\n",
    "\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "before target training\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "for batch in train_:\n",
    "    if step % n == 0:\n",
    "        val_loss, f1 = evaluate(model, val)\n",
    "        print(\n",
    "            f\"\"\"\n",
    "step = {step}\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    "        )\n",
    "        model.save_pretrained(f\"checkpoints_target/{step}\")\n",
    "\n",
    "    step += 1\n",
    "    batch = send_to_device(batch)\n",
    "    loss = model(**batch).loss\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    scheduler.step()\n",
    "    train_.set_postfix(train_loss=loss)"
   ],
   "id": "d3b4d6449110418e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# load last checkpoint for now\n",
    "from peft import set_peft_model_state_dict\n",
    "\n",
    "sd_6000 = torch.load(\"checkpoints_target/6000/adapter_model.bin\")\n",
    "set_peft_model_state_dict(model, sd_6000)\n",
    "\n",
    "# evaluate val\n",
    "val_loss, f1 = evaluate(model, val)\n",
    "print(\n",
    "    f\"\"\"\n",
    "final\n",
    "val loss = {val_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")\n",
    "\n",
    "# evaluate test\n",
    "test_loss, f1 = evaluate(model, test)\n",
    "print(\n",
    "    f\"\"\"\n",
    "final\n",
    "test loss = {test_loss}\n",
    "f1 = {f1}\"\"\"\n",
    ")"
   ],
   "id": "874d5c907c9553d8"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
