{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Set both variables before torch / transformers are imported in a later cell.\n",
    "# CUDA_VISIBLE_DEVICES=0 exposes only the first GPU to this kernel.\n",
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "# Turn off the tokenizers library's internal parallelism.\n",
    "%env TOKENIZERS_PARALLELISM=false"
   ],
   "id": "2531a783d3ec75cd"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Initialize PolyModel",
   "id": "d95eba880b0fa094"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoTokenizer,\n",
    "    default_data_collator,\n",
    "    Seq2SeqTrainingArguments,\n",
    "    Seq2SeqTrainer,\n",
    ")\n",
    "from datasets import load_dataset, concatenate_datasets\n",
    "from peft import PolyConfig, get_peft_model, TaskType, PeftModel, PeftConfig\n",
    "\n",
    "# base seq2seq checkpoint that receives the Poly adapter\n",
    "model_name_or_path = \"google/flan-t5-xl\"\n",
    "\n",
    "# Poly hyper-parameters, forwarded to PolyConfig further down\n",
    "r = 8  # rank of lora in poly\n",
    "n_tasks = 4  # number of tasks; one per entry of TASK2ID (boolq, multirc, rte, wic)\n",
    "n_skills = 2  # number of skills (loras)\n",
    "n_splits = 4  # number of heads\n",
    "\n",
    "# training hyper-parameters, forwarded to Seq2SeqTrainingArguments\n",
    "batch_size = 8  # used as both the per-device train and eval batch size\n",
    "lr = 5e-5\n",
    "num_epochs = 8"
   ],
   "id": "a93846a303e9185f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# flan-t5 is a built-in transformers architecture, so `trust_remote_code` is not needed;\n",
    "# leaving it off avoids executing Python code shipped inside a Hub repository and matches\n",
    "# how the base model is reloaded in the inference section.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)"
   ],
   "id": "623f645675525dfa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# describe the Poly adapter using the hyper-parameters from the configuration cell\n",
    "peft_config = PolyConfig(\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    poly_type=\"poly\",\n",
    "    r=r,\n",
    "    n_tasks=n_tasks,\n",
    "    n_skills=n_skills,\n",
    "    n_splits=n_splits,\n",
    ")\n",
    "\n",
    "# wrap the base model with the adapter and print how many parameters are trainable\n",
    "model = get_peft_model(base_model, peft_config)\n",
    "model.print_trainable_parameters()"
   ],
   "id": "60f8f30ce6b62b2b"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Prepare datasets\n",
    "\n",
    "For this example, we use four `SuperGLUE` benchmark datasets: `boolq`, `multirc`, `rte`, and `wic`. From each of them we sample 1,000 training examples and 100 evaluation examples."
   ],
   "id": "7790086810dc9c08"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def build_task_dataset(task_name, make_input, answers):\n",
    "    \"\"\"Load one SuperGLUE task and recast it as a lettered multiple-choice prompt.\n",
    "\n",
    "    Args:\n",
    "        task_name: SuperGLUE config name; also written to the `task_name` column.\n",
    "        make_input: callable that turns one raw example into its prompt string.\n",
    "        answers: answer letters indexed by the integer label (`answers[label]`).\n",
    "\n",
    "    Returns:\n",
    "        The object returned by `load_dataset`, mapped so that only the `input`,\n",
    "        `output` and `task_name` columns remain.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        load_dataset(\"super_glue\", task_name)\n",
    "        .map(\n",
    "            lambda x: {\n",
    "                \"input\": make_input(x),\n",
    "                \"output\": answers[int(x[\"label\"])],\n",
    "                \"task_name\": task_name,\n",
    "            }\n",
    "        )\n",
    "        .select_columns([\"input\", \"output\", \"task_name\"])\n",
    "    )\n",
    "\n",
    "\n",
    "# boolq: label 0 - False, 1 - True\n",
    "boolq_dataset = build_task_dataset(\n",
    "    \"boolq\",\n",
    "    lambda x: f\"{x['passage']}\\nQuestion: {x['question']}\\nA. Yes\\nB. No\\nAnswer:\",\n",
    "    [\"B\", \"A\"],\n",
    ")\n",
    "print(\"boolq example: \")\n",
    "print(boolq_dataset[\"train\"][0])\n",
    "\n",
    "# multirc: label 0 - False, 1 - True\n",
    "multirc_dataset = build_task_dataset(\n",
    "    \"multirc\",\n",
    "    lambda x: (\n",
    "        f\"{x['paragraph']}\\nQuestion: {x['question']}\\nAnswer: {x['answer']}\\nIs it\"\n",
    "        \" true?\\nA. Yes\\nB. No\\nAnswer:\"\n",
    "    ),\n",
    "    [\"B\", \"A\"],\n",
    ")\n",
    "print(\"multirc example: \")\n",
    "print(multirc_dataset[\"train\"][0])\n",
    "\n",
    "# rte: label 0 - entailment, 1 - not_entailment\n",
    "rte_dataset = build_task_dataset(\n",
    "    \"rte\",\n",
    "    lambda x: (\n",
    "        f\"{x['premise']}\\n{x['hypothesis']}\\nIs the sentence below entailed by the\"\n",
    "        \" sentence above?\\nA. Yes\\nB. No\\nAnswer:\"\n",
    "    ),\n",
    "    [\"A\", \"B\"],\n",
    ")\n",
    "print(\"rte example: \")\n",
    "print(rte_dataset[\"train\"][0])\n",
    "\n",
    "# wic: label 0 - False, 1 - True\n",
    "wic_dataset = build_task_dataset(\n",
    "    \"wic\",\n",
    "    lambda x: (\n",
    "        f\"Sentence 1: {x['sentence1']}\\nSentence 2: {x['sentence2']}\\nAre '{x['word']}'\"\n",
    "        \" in the above two sentences the same?\\nA. Yes\\nB. No\\nAnswer:\"\n",
    "    ),\n",
    "    [\"B\", \"A\"],\n",
    ")\n",
    "print(\"wic example: \")\n",
    "print(wic_dataset[\"train\"][0])"
   ],
   "id": "db9272b7916c0343"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# map every task name to the integer id stored in the `task_ids` feature\n",
    "TASK2ID = {name: idx for idx, name in enumerate([\"boolq\", \"multirc\", \"rte\", \"wic\"])}\n",
    "\n",
    "\n",
    "def tokenize(examples):\n",
    "    \"\"\"Tokenize a batch of prompts and answers and attach each row's task id.\n",
    "\n",
    "    Reads the `input`, `output` and `task_name` columns of `examples` and returns the\n",
    "    tokenizer output for the prompts with `labels` and `task_ids` added.\n",
    "    \"\"\"\n",
    "    features = tokenizer(\n",
    "        examples[\"input\"], max_length=512, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    answer_ids = tokenizer(\n",
    "        examples[\"output\"], max_length=2, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
    "    )[\"input_ids\"]\n",
    "    # padding positions in the targets become -100, the placeholder compute_metrics filters out\n",
    "    features[\"labels\"] = answer_ids.masked_fill(answer_ids == tokenizer.pad_token_id, -100)\n",
    "    # one single-element row per example\n",
    "    features[\"task_ids\"] = torch.tensor(\n",
    "        [[TASK2ID[name]] for name in examples[\"task_name\"]], dtype=torch.long\n",
    "    )\n",
    "    return features"
   ],
   "id": "fdaec718a4ec52e0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_superglue_dataset(\n",
    "        split=\"train\",\n",
    "        n_samples=500,\n",
    "        seed=None,\n",
    "):\n",
    "    \"\"\"Build a tokenized multi-task dataset with `n_samples` rows from each task.\n",
    "\n",
    "    Args:\n",
    "        split: split name looked up in each per-task dataset.\n",
    "        n_samples: number of shuffled examples kept per task.\n",
    "        seed: optional shuffle seed. `None` (the default) keeps the previous unseeded\n",
    "            behaviour; pass an integer to make the subsample reproducible.\n",
    "\n",
    "    Returns:\n",
    "        The concatenated per-task subsets after `tokenize`, with the raw `input`,\n",
    "        `output` and `task_name` columns removed.\n",
    "    \"\"\"\n",
    "    task_datasets = (boolq_dataset, multirc_dataset, rte_dataset, wic_dataset)\n",
    "    ds = concatenate_datasets(\n",
    "        [task[split].shuffle(seed=seed).select(range(n_samples)) for task in task_datasets]\n",
    "    )\n",
    "    ds = ds.map(\n",
    "        tokenize,\n",
    "        batched=True,\n",
    "        remove_columns=[\"input\", \"output\", \"task_name\"],\n",
    "        # always re-tokenize instead of reusing a cached result\n",
    "        load_from_cache_file=False,\n",
    "    )\n",
    "    return ds"
   ],
   "id": "ac849181a2d16f1d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As a toy example, we select only 1,000 examples from each sub-dataset for training and 100 from each for evaluation.",
   "id": "ac901f7a8fe46ccc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# SuperGLUE ships its `test` split without gold labels (every label is -1), which the\n",
    "# label-to-letter lookup would silently turn into one constant answer per task and make\n",
    "# the reported accuracy meaningless. Evaluate on the labelled `validation` split instead.\n",
    "superglue_train_dataset = get_superglue_dataset(split=\"train\", n_samples=1000)\n",
    "superglue_eval_dataset = get_superglue_dataset(split=\"validation\", n_samples=100)"
   ],
   "id": "1981c9bf3004cbd6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Train and evaluate",
   "id": "93cccd8a71df6b03"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# training and evaluation\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Return the exact-match accuracy of the decoded predictions against the decoded labels.\"\"\"\n",
    "    pred_ids, label_ids = eval_preds\n",
    "\n",
    "    def decode(batch):\n",
    "        # -100 entries are placeholders rather than token ids, so drop them before decoding\n",
    "        kept = [[tok for tok in seq if tok != -100] for seq in batch]\n",
    "        return tokenizer.batch_decode(kept, skip_special_tokens=True)\n",
    "\n",
    "    pairs = list(zip(decode(pred_ids), decode(label_ids)))\n",
    "    hits = sum(1 for pred, true in pairs if pred.strip() == true.strip())\n",
    "    return {\"accuracy\": hits / len(pairs)}\n",
    "\n",
    "\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=\"output\",\n",
    "    # optimisation\n",
    "    learning_rate=lr,\n",
    "    num_train_epochs=num_epochs,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    # evaluate and log once per epoch; no checkpoints, no reporting integrations\n",
    "    eval_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_strategy=\"no\",\n",
    "    report_to=[],\n",
    "    # score generated sequences, capped at length 2\n",
    "    predict_with_generate=True,\n",
    "    generation_max_length=2,\n",
    "    # presumably required so the extra `task_ids` column reaches the model - confirm\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=default_data_collator,\n",
    "    train_dataset=superglue_train_dataset,\n",
    "    eval_dataset=superglue_eval_dataset,\n",
    "    # NOTE(review): newer transformers releases appear to rename this argument to\n",
    "    # `processing_class` - confirm against the installed version\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "trainer.train()"
   ],
   "id": "6dae0ffce495f307"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# saving model\n",
    "# `model_name_or_path` comes from the configuration cell; it is deliberately not redefined\n",
    "# here so the adapter directory name always matches the base model that was trained.\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ],
   "id": "ccfa784c01b4eb50"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "# list the contents of the directory passed to save_pretrained above\n!ls -lh $peft_model_id",
   "id": "1a22336013ddbde6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Load and infer",
   "id": "4ba2407be1325dba"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "# use the first CUDA device when one is available, otherwise fall back to the CPU\ndevice = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"",
   "id": "ad2af3b8c1c93079"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# rebuild the adapter directory name that was used when saving\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "# the saved adapter config names the base checkpoint to reload\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "reloaded_base = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# attach the saved adapter, move the model to the target device and switch to eval mode\n",
    "model = PeftModel.from_pretrained(reloaded_base, peft_model_id).to(device).eval()"
   ],
   "id": "7beafcc9c5cc38e7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "i = 5\n",
    "# fetch the i-th validation row once and read both text fields from it\n",
    "example = rte_dataset[\"validation\"][i]\n",
    "\n",
    "inputs = tokenizer(example[\"input\"], return_tensors=\"pt\")\n",
    "inputs[\"task_ids\"] = torch.LongTensor([TASK2ID[\"rte\"]])\n",
    "inputs = {name: tensor.to(device) for name, tensor in inputs.items()}\n",
    "\n",
    "print(example[\"input\"])\n",
    "print(example[\"output\"])\n",
    "print(inputs)\n",
    "\n",
    "# gradients are not needed for generation\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**inputs, max_new_tokens=2)\n",
    "\n",
    "print(outputs[0])\n",
    "print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])"
   ],
   "id": "b5f489c0be69a51a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "a0fbd30ecd45c79a"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
