{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade pip\n",
    "%pip install openicl\n",
    "# Restart the kernel after the installation is completed"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Accelerating OpenICL with 🤗 Accelerate: Distributed Data Parallel and Model Parallel"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In OpenICL, we use 🤗 [Accelerate](https://github.com/huggingface/accelerate) to implement Distributed Data Parallel (DDP) and Model Parallel. 🤗 Accelerate is a library that lets the same PyTorch code run across any distributed configuration by adding just a few lines of code. To quickly set up 🤗 Accelerate on your machine(s), run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "accelerate config"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more details on 🤗 Accelerate, see the [documentation](https://huggingface.co/docs/accelerate/index)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3-1 Distributed Data Parallel"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Distributed Data Parallel (DDP) implements data parallelism at the module level and can run across multiple machines. The recommended way to use DDP is to spawn one process for each model replica, where a model replica can span multiple devices. Using DDP in OpenICL is straightforward: after completing the relevant settings through `accelerate config`, pass the `Accelerator` instance to the `Retriever` and the `Inferencer`. The following are code and script examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_sst2_ddp.py\n",
    "# Example adapted from tutorial 1-4-1\n",
    "\n",
    "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n",
    "from accelerate import Accelerator\n",
    "\n",
    "# Create the Accelerator instance shared by the retriever and the inferencer\n",
    "accelerator = Accelerator()\n",
    "\n",
    "# Define a DatasetReader, loading the dataset from Hugging Face.\n",
    "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n",
    "\n",
    "# SST-2 Template Example\n",
    "template = PromptTemplate(template={\n",
    "                                        0: '</E>Positive Movie Review: </text>',\n",
    "                                        1: '</E>Negative Movie Review: </text>'\n",
    "                                    },\n",
    "                          column_token_map={'text' : '</text>'},\n",
    "                          ice_token='</E>'\n",
    "           )\n",
    "\n",
    "# TopK Retriever (receives the Accelerator instance)\n",
    "retriever = TopkRetriever(data, ice_num=8, index_split='train', test_split='test', accelerator=accelerator)\n",
    "\n",
    "# Define an Inferencer (receives the same Accelerator instance)\n",
    "inferencer = PPLInferencer(model_name='distilgpt2', accelerator=accelerator)\n",
    "\n",
    "# Inference\n",
    "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='ddp_sst2')\n",
    "\n",
    "# print(predictions)\n",
    "# See the results at ./icl_inference_output/ddp_sst2.json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run_sst2.ddp.sh\n",
    "# Replace `${your_gpu_num}` and `${your_port_id}` with the number of GPUs and the main process port, respectively\n",
    "accelerate launch --num_processes ${your_gpu_num} --main_process_port ${your_port_id} test_sst2_ddp.py"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3-2 Model Parallel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
