{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only the first GPU (index 0) to this kernel. This has to run before\n",
    "# anything initializes CUDA, otherwise the restriction is not applied.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using [MeZO Runner](../example/mezo_runner/) on Supported Tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Run the MeZO example script from its own directory, then return to the directory\n",
    "# this notebook was started from. Restoring the saved path in `finally` keeps later\n",
    "# cells (which rely on relative paths such as \"../\") working even if the run is\n",
    "# interrupted, and avoids hard-coding the notebook's folder name.\n",
    "notebook_dir = os.getcwd()\n",
    "print(\"Current working directory:\", notebook_dir)\n",
    "os.chdir('../example/mezo_runner/')\n",
    "print(\"New working directory:\", os.getcwd())\n",
    "try:\n",
    "    !MODEL=facebook/opt-2.7b TASK=SST2 MODE=ft LR=1e-7 EPS=1e-3 STEPS=20000 EVAL_STEPS=4000 bash mezo.sh\n",
    "finally:\n",
    "    os.chdir(notebook_dir)\n",
    "    print(\"New working directory:\", os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the Hugging Face Trainer (ZOTrainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "from transformers import (\n",
    "    AutoTokenizer, \n",
    "    TrainingArguments,\n",
    "    DataCollatorForLanguageModeling\n",
    ")\n",
    "from datasets import load_dataset\n",
    "from zo2 import (\n",
    "    ZOConfig,\n",
    "    zo_hf_init,\n",
    ")\n",
    "from zo2.trainer.hf_transformers.trainer import ZOTrainer\n",
    "from zo2.trainer.hf_trl.sft_trainer import ZOSFTTrainer\n",
    "from zo2.utils import seed_everything"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "\n",
    "# --- Model / ZO method ---\n",
    "model_name = \"facebook/opt-2.7b\"\n",
    "# \"zo2\" is passed as ZOConfig(zo2=True); any other value moves the model to `working_device`\n",
    "zo_method = \"zo2\"\n",
    "eval_mode = False\n",
    "verbose = True\n",
    "\n",
    "# --- Optimization ---\n",
    "max_steps = 300\n",
    "learning_rate = 1e-7\n",
    "weight_decay = 1e-1\n",
    "zo_eps = 1e-3  # forwarded as ZOConfig(eps=...)\n",
    "seed = 42\n",
    "\n",
    "# --- Devices ---\n",
    "offloading_device = \"cpu\"\n",
    "working_device = \"cuda:0\"\n",
    "\n",
    "# --- Data limits / generation (not referenced by the training cells below) ---\n",
    "max_train_data = None\n",
    "max_eval_data = None\n",
    "use_cache = True\n",
    "max_new_tokens = 50\n",
    "temperature = 1.0\n",
    "\n",
    "seed_everything(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zeroth-order optimizer configuration shared by model init and training\n",
    "use_zo2 = zo_method == \"zo2\"\n",
    "zo_config = ZOConfig(\n",
    "    method=\"mezo-sgd\",\n",
    "    zo2=use_zo2,\n",
    "    lr=learning_rate,\n",
    "    weight_decay=weight_decay,\n",
    "    eps=zo_eps,\n",
    "    offloading_device=offloading_device,\n",
    "    working_device=working_device,\n",
    ")\n",
    "\n",
    "# Load the model under zo_hf_init. The model class is imported inside the context;\n",
    "# presumably zo_hf_init swaps in ZO-aware transformers classes while it is active\n",
    "# (TODO confirm against the zo2 docs), so do not hoist this import.\n",
    "with zo_hf_init(zo_config):\n",
    "    from transformers import OPTForCausalLM\n",
    "    model = OPTForCausalLM.from_pretrained(model_name)\n",
    "    model.zo_init(zo_config)\n",
    "\n",
    "# Without zo2 the whole model is placed on the working device\n",
    "if not use_zo2:\n",
    "    model = model.to(working_device)\n",
    "print(f\"Check if zo2 init correctly: {hasattr(model, 'zo_training')}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare dataset\n",
    "dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')\n",
    "\n",
    "# Tokenize each raw line of text\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Tokenizers without a configured limit report a huge sentinel `model_max_length`\n",
    "# (~1e30). Used directly, every chunk would be longer than the corpus and\n",
    "# `group_texts` would drop all data, so cap it as the HF `run_clm.py` example does.\n",
    "max_block_size = 1024\n",
    "block_size = min(tokenizer.model_max_length, max_block_size)\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"])\n",
    "tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=[\"text\"])\n",
    "def group_texts(examples):\n",
    "    \"\"\"Concatenate a batch of tokenized texts and split it into `block_size` chunks.\n",
    "\n",
    "    Args:\n",
    "        examples: batched mapping of column name -> list of token lists.\n",
    "    Returns:\n",
    "        Mapping with the same columns split into full `block_size` chunks, plus a\n",
    "        `labels` column copied from `input_ids`. The trailing remainder shorter\n",
    "        than `block_size` is dropped.\n",
    "    \"\"\"\n",
    "    # Flatten each column in linear time (sum(lists, []) is quadratic).\n",
    "    concatenated_examples = {k: [tok for seq in examples[k] for tok in seq] for k in examples.keys()}\n",
    "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
    "    # Drop the small remainder; padding could be added instead if the model supported it.\n",
    "    total_length = (total_length // block_size) * block_size\n",
    "    # Split into chunks of block_size.\n",
    "    result = {\n",
    "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
    "        for k, t in concatenated_examples.items()\n",
    "    }\n",
    "    result[\"labels\"] = result[\"input_ids\"].copy()\n",
    "    return result\n",
    "lm_datasets = tokenized_datasets.map(\n",
    "    group_texts,\n",
    "    batched=True,\n",
    "    batch_size=1000,\n",
    "    num_proc=4,\n",
    ")\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainer init\n",
    "training_args = TrainingArguments(\n",
    "    \"test-trainer\",\n",
    "    max_steps=max_steps,\n",
    "    save_strategy=\"no\",\n",
    "    logging_steps=10,\n",
    ")\n",
    "\n",
    "# Train on the fixed-length chunks produced by `group_texts`. The raw per-line\n",
    "# `tokenized_datasets` has variable-length rows (wikitext contains many blank lines),\n",
    "# which wastes compute on padding and can leave batches with no target tokens (NaN loss).\n",
    "trainer = ZOTrainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=lm_datasets[\"train\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# 'ZOTrainer' provides the capability to register pre-hooks and post-hooks during zo_step\n",
    "def drop_invalid_data(model, inputs, loss):\n",
    "    \"\"\"Post-hook: drop a step whose loss or projected gradient is NaN.\n",
    "\n",
    "    Resets `model.opt.projected_grad` to 0 so the step does not update the\n",
    "    parameters, and returns the (model, inputs, loss) triple unchanged.\n",
    "    \"\"\"\n",
    "    def _has_nan(value):\n",
    "        # Handles tensors and plain Python numbers (NaN is the only value != itself).\n",
    "        if isinstance(value, torch.Tensor):\n",
    "            return bool(torch.isnan(value).any())\n",
    "        return value != value\n",
    "    if _has_nan(loss) or _has_nan(model.opt.projected_grad):\n",
    "        tqdm.write(\"'loss': {} or 'projected_grad': {} is nan. Drop this step.\".format(\n",
    "            loss, model.opt.projected_grad\n",
    "        ))\n",
    "        model.opt.projected_grad = 0  # Reset projected_grad to prevent parameter updates\n",
    "    return model, inputs, loss\n",
    "trainer.register_zo2_training_step_post_hook(drop_invalid_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run zeroth-order fine-tuning for `max_steps` steps, logging every 10 steps (see training_args)\n",
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mezo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
