{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "992ea5bf-f420-44e3-9c56-ab04c086bf05",
                    "showTitle": false,
                    "title": ""
                },
                "id": "vticii3WR_HN"
            },
            "source": [
                "# Spark Dataframe to MDS\n",
                "In this tutorial, we demonstrate how to use Streaming's Spark converter to turn a Spark dataframe into an MDS dataset that `StreamingDataset` can read. You can optionally pass a preprocessing job, such as a tokenizer, to the converter, which is useful when materializing the intermediate dataframe is time-consuming or requires extra development."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "ed551b02-c3a7-47dc-8aa2-8e6e501cf5e3",
                    "showTitle": false,
                    "title": ""
                },
                "id": "EV5xY06KR_HO"
            },
            "source": [
                "## Tutorial Covers\n",
                "1. Installation of libraries\n",
                "2. [Basic use-case] Convert a Spark dataframe to MDS format\n",
                "3. [Advanced use-case] Tokenize a Spark dataframe and convert the result to MDS format"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "93f83c00-600f-4605-972c-5f0eb4a4152d",
                    "showTitle": false,
                    "title": ""
                },
                "id": "gyESiU8KR_HP"
            },
            "source": [
                "## Setup\n",
                "Let\u2019s start by making sure the right packages are installed and imported. We need the `mosaicml-streaming` package, along with `pyspark`, `datasets`, and `transformers` for the Spark and tokenization steps."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "a3547314-ce67-4617-84b0-fafe366f82e4",
                    "showTitle": false,
                    "title": ""
                },
                "id": "QE2DHOK_R_HP"
            },
            "outputs": [],
            "source": [
                "%pip install --upgrade fsspec datasets transformers"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "115c2bc7-94d6-4ada-926c-fc5bdfd6c29c",
                    "showTitle": false,
                    "title": ""
                },
                "id": "PeEMBLaSR_HP"
            },
            "outputs": [],
            "source": [
                "%pip install mosaicml-streaming"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "zETa4qH0TnPE"
            },
            "outputs": [],
            "source": [
                "%pip install pyspark==3.4.1"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "zjlPAIBfhSbg"
            },
            "outputs": [],
            "source": [
                "import os\n",
                "import shutil\n",
                "from typing import Any, Sequence, Dict, Iterable, Optional\n",
                "from pyspark.sql import SparkSession\n",
                "import pandas as pd\n",
                "import numpy as np\n",
                "from tempfile import mkdtemp\n",
                "import datasets as hf_datasets\n",
                "from transformers import AutoTokenizer, PreTrainedTokenizerBase"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "G1I2CtGpRk40"
            },
            "source": [
                "We\u2019ll be using Streaming\u2019s `dataframe_to_mds()` function, which converts a Spark dataframe into the Streaming MDS format."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "uzYHe6yYRzyV"
            },
            "outputs": [],
            "source": [
                "from streaming.base.converters import dataframe_to_mds"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "42e9ffbc-52a7-479d-b3a5-608d044d1f6a",
                    "showTitle": false,
                    "title": ""
                },
                "id": "Un4G3VdgR_HQ"
            },
            "source": [
                "## [Basic use-case] Convert spark dataframe to MDS format\n",
                "**Steps:**\n",
                "1. Create a synthetic NLP dataset.\n",
                "2. Store the dataset as a parquet file.\n",
                "3. Load the parquet file as a Spark dataframe.\n",
                "4. Convert the Spark dataframe to MDS format.\n",
                "5. Load the MDS dataset and look at the output."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "90184dc8-cb85-4921-bc8e-e1525c6ee212",
                    "showTitle": false,
                    "title": ""
                },
                "id": "oKA2TTJ-R_HQ"
            },
            "source": [
                "### Create a Synthetic NLP dataset\n",
                "\n",
                "In this tutorial, we create a synthetic number-saying dataset, i.e. one that converts numbers from digits to words. For example, the number `123` is spelled as `one hundred twenty three`. The numbers are drawn at random, each with a randomly chosen positive or negative sign.\n",
                "\n",
                "Let\u2019s define a utility class that generates this synthetic number-saying dataset."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "2e58eb05-4271-4ea4-b7ce-2e5466a1405b",
                    "showTitle": false,
                    "title": ""
                },
                "id": "SMirmZEqR_HQ"
            },
            "outputs": [],
            "source": [
                "class NumberAndSayDataset:\n",
                "    \"\"\"Generate a synthetic number-saying dataset.\n",
                "\n",
                "    Converts numbers from digits to words. For example, the number 123 is spelled as\n",
                "    `one hundred twenty three`. The numbers are generated randomly, with a random\n",
                "    positive/negative sign and a magnitude of just under 100 million at most.\n",
                "\n",
                "    Args:\n",
                "        num_samples (int): Number of samples. Defaults to 100.\n",
                "        column_names (list[str]): Names of the feature and target columns. Defaults to\n",
                "            ['number', 'words'].\n",
                "        seed (int): Seed value for deterministic randomness. Defaults to 987.\n",
                "    \"\"\"\n",
                "\n",
                "    ones = (\n",
                "        'zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +\n",
                "        'fifteen sixteen seventeen eighteen nineteen').split()\n",
                "\n",
                "    tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()\n",
                "\n",
                "    def __init__(self,\n",
                "                 num_samples: int = 100,\n",
                "                 column_names: list[str] = ['number', 'words'],\n",
                "                 seed: int = 987) -> None:\n",
                "        self.num_samples = num_samples\n",
                "        self.column_encodings = ['int', 'str']\n",
                "        self.column_sizes = [8, None]\n",
                "        self.column_names = column_names\n",
                "        self._index = 0\n",
                "        self.seed = seed\n",
                "\n",
                "    def __len__(self) -> int:\n",
                "        return self.num_samples\n",
                "\n",
                "    def _say(self, i: int) -> list[str]:\n",
                "        if i < 0:\n",
                "            return ['negative'] + self._say(-i)\n",
                "        elif i <= 19:\n",
                "            return [self.ones[i]]\n",
                "        elif i < 100:\n",
                "            return [self.tens[i // 10 - 2]] + ([self.ones[i % 10]] if i % 10 else [])\n",
                "        elif i < 1_000:\n",
                "            return [self.ones[i // 100], 'hundred'] + (self._say(i % 100) if i % 100 else [])\n",
                "        elif i < 1_000_000:\n",
                "            return self._say(i // 1_000) + ['thousand'\n",
                "                                           ] + (self._say(i % 1_000) if i % 1_000 else [])\n",
                "        elif i < 1_000_000_000:\n",
                "            return self._say(\n",
                "                i // 1_000_000) + ['million'] + (self._say(i % 1_000_000) if i % 1_000_000 else [])\n",
                "        else:\n",
                "            assert False\n",
                "\n",
                "    def _get_number(self) -> int:\n",
                "        sign = (np.random.random() < 0.8) * 2 - 1\n",
                "        mag = 10**np.random.uniform(1, 4) - 10\n",
                "        return sign * int(mag**2)\n",
                "\n",
                "    def __iter__(self):\n",
                "        return self\n",
                "\n",
                "    def __next__(self) -> dict[str, Any]:\n",
                "        if self._index >= self.num_samples:\n",
                "            raise StopIteration\n",
                "        number = self._get_number()\n",
                "        words = ' '.join(self._say(number))\n",
                "        self._index += 1\n",
                "        return {\n",
                "            self.column_names[0]: number,\n",
                "            self.column_names[1]: words,\n",
                "        }\n",
                "\n",
                "    @property\n",
                "    def seed(self) -> int:\n",
                "        return self._seed\n",
                "\n",
                "    @seed.setter\n",
                "    def seed(self, value: int) -> None:\n",
                "        self._seed = value  # pyright: ignore\n",
                "        np.random.seed(self._seed)\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "c4tvDe8hTeMS"
            },
            "source": [
                "### Store the dataset as a parquet file"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "89707251-35c0-4c17-841d-2cc026e6f96f",
                    "showTitle": false,
                    "title": ""
                },
                "id": "ysofCNb3R_HR"
            },
            "outputs": [],
            "source": [
                "# Create a temporary directory\n",
                "local_dir = mkdtemp()\n",
                "\n",
                "# Materialize the generated records into a pandas dataframe\n",
                "syn_dataset = NumberAndSayDataset()\n",
                "df = pd.DataFrame(list(syn_dataset))\n",
                "df.to_parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "e9e12fcc-fa78-4635-9e62-fb67ca5f520f",
                    "showTitle": false,
                    "title": ""
                },
                "id": "4hvJgJGxR_HR"
            },
            "source": [
                "### Load the parquet file as spark dataframe"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "10e4ZrH3he0q"
            },
            "outputs": [],
            "source": [
                "spark = SparkSession.builder.getOrCreate()\n",
                "pdf = spark.read.parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "9a43d639-c0a8-438a-9c56-f748b6ca79ad",
                    "showTitle": false,
                    "title": ""
                },
                "id": "8JfeOLqWR_HS"
            },
            "source": [
                "Take a peek at the spark dataframe"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "iGkjgXtDhk6g"
            },
            "outputs": [],
            "source": [
                "pdf.show(5, truncate=False)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "2a2468b8-9538-42e6-b0c6-937a3de92210",
                    "showTitle": false,
                    "title": ""
                },
                "id": "pt5MYIrmR_HS"
            },
            "source": [
                "### Convert the spark dataframe to MDS format"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "aad931d4-8b97-4d5b-9c40-0f35e715c7b4",
                    "showTitle": false,
                    "title": ""
                },
                "id": "zRghnT00R_HS"
            },
            "outputs": [],
            "source": [
                "# Empty the MDS output directory\n",
                "out_path = os.path.join(local_dir, 'mds')\n",
                "shutil.rmtree(out_path, ignore_errors=True)\n",
                "\n",
                "# Specify the mandatory MDSWriter arguments `out` and `columns`.\n",
                "mds_kwargs = {'out': out_path, 'columns': {'number': 'int64', 'words': 'str'}}\n",
                "\n",
                "# Convert the dataset to MDS format. The dataframe is split into 4 partitions, one per worker,\n",
                "# and the `index.json` files of the 4 sub-directories are merged into one in the parent directory.\n",
                "dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "327543b6-b2c8-4c60-bff3-bb4e5a980a28",
                    "showTitle": false,
                    "title": ""
                },
                "id": "MdKpnmkhR_HS"
            },
            "source": [
                "Let's check the file structure of the output MDS dataset. It contains four sub-directories and one `index.json` file. The `index.json` file holds the metadata for all four sub-directories."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "268c137e-6e38-4f31-8e51-0adfe3c82c44",
                    "showTitle": false,
                    "title": ""
                },
                "id": "CDbPsmUeR_HS"
            },
            "outputs": [],
            "source": [
                "%ls {out_path}"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "6d6aa896-9d5a-4763-aaf4-197f9170ec27",
                    "showTitle": false,
                    "title": ""
                },
                "id": "S6j3pr_IR_HS"
            },
            "source": [
                "### Load the MDS dataset using StreamingDataset and look at the output"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "ad4359ac-1652-48bd-b4db-ea034e180e01",
                    "showTitle": false,
                    "title": ""
                },
                "id": "qfpr2Df3R_HS"
            },
            "outputs": [],
            "source": [
                "from torch.utils.data import DataLoader\n",
                "import streaming\n",
                "from streaming import StreamingDataset\n",
                "\n",
                "# clean stale shared memory if any\n",
                "streaming.base.util.clean_stale_shared_memory()\n",
                "\n",
                "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n",
                "\n",
                "dataloader = DataLoader(dataset, batch_size=2, num_workers=1)\n",
                "\n",
                "for i, data in enumerate(dataloader):\n",
                "    # Display only the first 10 batches\n",
                "    if i == 10:\n",
                "        break\n",
                "    print(data)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "84e7df6b-b890-45ad-93de-00903a59c058",
                    "showTitle": false,
                    "title": ""
                },
                "id": "Uo1jLj1dR_HS"
            },
            "source": [
                "## [Advanced use-case] convert spark dataframe into tokenized format and convert to MDS format\n",
                "**Steps:**\n",
                "1. [Same as above] Create a synthetic NLP dataset.\n",
                "2. [Same as above] Store the dataset as a parquet file.\n",
                "3. [Same as above] Load the parquet file as a Spark dataframe.\n",
                "4. Create a user defined function which modifies the dataframe.\n",
                "5. Convert the modified data into MDS format.\n",
                "6. Load the MDS dataset and look at the output."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "57ffd58d-118b-485b-9f5c-d42ecc6f81ad",
                    "showTitle": false,
                    "title": ""
                },
                "id": "geRPrzMhR_HS"
            },
            "source": [
                "### Create a user defined function which modifies the dataframe\n",
                "\n",
                "The user defined function must be a generator (iterable) function that takes a pandas dataframe and yields dictionaries, each with the column name as `key` and that column's value as `value`. For example, in this tutorial, the `key` is `tokens` and the `value` is the tokenized output in bytes. When you supply such a function, you are fully responsible for providing a matching `columns` argument. In the case below, it should be\n",
                "\n",
                "`columns={'tokens': 'bytes'}`\n",
                "\n",
                "where `tokens` is the key yielded by the user defined function, and `bytes` is the encoding of the field, so that MDS chooses the proper encoding method."
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "P1P7BsXsaDCy"
            },
            "source": [
                "\n",
                "Take a peek at the spark dataframe"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "mU5qfafGzKsQ"
            },
            "outputs": [],
            "source": [
                "pdf.show(5, truncate=False)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "JMeyEPoSaGgH"
            },
            "source": [
                "### Convert the spark dataframe to MDS format\n",
                "This time we supply the user defined iterable function and its arguments. For demonstration purposes, the user defined tokenization function `pandas_processing_fn` is heavily simplified. Practical applications may need more involved preprocessing steps. For dataset concatenation and more processing examples, see [Mosaic's LLM Foundry](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/data/data.py)."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "H7Z5kjNie4ZR"
            },
            "outputs": [],
            "source": [
                "from typing import Dict, Iterable\n",
                "\n",
                "import datasets as hf_datasets\n",
                "import numpy as np\n",
                "import pandas as pd\n",
                "from transformers import AutoTokenizer\n",
                "\n",
                "\n",
                "def pandas_processing_fn(df: pd.DataFrame, **args) -> Iterable[Dict[str, bytes]]:\n",
                "    \"\"\"Tokenize the `words` column and yield fixed-length token samples.\n",
                "\n",
                "    Parameters:\n",
                "    -----------\n",
                "    df : pandas.DataFrame\n",
                "        The input pandas DataFrame that needs to be processed.\n",
                "\n",
                "    **args : keyword arguments\n",
                "        Tokenization settings. This function reads `split`, `tokenizer`,\n",
                "        `concat_tokens` and, optionally, `no_wrap`.\n",
                "\n",
                "    Returns:\n",
                "    --------\n",
                "    iterable obj\n",
                "        Dictionaries of the form {'tokens': bytes}.\n",
                "    \"\"\"\n",
                "    hf_dataset = hf_datasets.Dataset.from_pandas(df=df, split=args['split'])\n",
                "    tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])\n",
                "    # we will enforce length, so suppress warnings about sequences too long for the model\n",
                "    tokenizer.model_max_length = int(1e30)\n",
                "    max_length = args['concat_tokens']\n",
                "    no_wrap = args.get('no_wrap', False)\n",
                "\n",
                "    # Concatenate token ids across rows and emit one sample per `max_length` tokens.\n",
                "    buffer = []\n",
                "    for sample in hf_dataset:\n",
                "        encoded = tokenizer(sample['words'], truncation=False, padding=False)\n",
                "        buffer = buffer + encoded['input_ids']\n",
                "        while len(buffer) >= max_length:\n",
                "            concat_sample = buffer[:max_length]\n",
                "            # Carry leftover tokens into the next sample unless `no_wrap` is set.\n",
                "            buffer = [] if no_wrap else buffer[max_length:]\n",
                "            yield {\n",
                "                # convert to bytes (int64 token ids) to store in MDS binary format\n",
                "                'tokens': np.asarray(concat_sample, dtype=np.int64).tobytes()\n",
                "            }"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "b0b3cfd1-9753-44ca-9b03-55f29c60fb56",
                    "showTitle": false,
                    "title": ""
                },
                "id": "R9Xxr0CDR_HT"
            },
            "outputs": [],
            "source": [
                "# Empty the MDS output directory\n",
                "out_path = os.path.join(local_dir, 'mds')\n",
                "shutil.rmtree(out_path, ignore_errors=True)\n",
                "\n",
                "# Provide the MDS keyword args. Ensure the `columns` field matches the output of the iterable function (the tokenizer in this example).\n",
                "mds_kwargs = {'out': out_path, 'columns': {'tokens': 'bytes'}}\n",
                "\n",
                "# Arguments forwarded to the user defined function (keys it does not read are ignored)\n",
                "udf_kwargs = {\n",
                "    'concat_tokens': 4,\n",
                "    'tokenizer': 'EleutherAI/gpt-neox-20b',\n",
                "    'eos_text': '<|endoftext|>',\n",
                "    'compression': 'zstd',\n",
                "    'split': 'train',\n",
                "    'no_wrap': False,\n",
                "    'bos_text': '',\n",
                "}\n",
                "\n",
                "# Convert the dataset to MDS format. Each sample is fetched from the dataframe, tokenized, and then written as MDS.\n",
                "# The dataframe is split into 4 partitions, one per worker, and the `index.json` files of the 4 sub-directories are merged into one in the parent directory.\n",
                "dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs, udf_iterable=pandas_processing_fn, udf_kwargs=udf_kwargs)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "8c3ea1b0-cab4-4d0b-b4a2-ed30dae39959",
                    "showTitle": false,
                    "title": ""
                },
                "id": "Vady_sp9R_HT"
            },
            "source": [
                "Let's check the file structure of the output MDS dataset. It contains four sub-directories and one `index.json` file. The `index.json` file holds the metadata for all four sub-directories."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {
                        "byteLimit": 2048000,
                        "rowLimit": 10000
                    },
                    "inputWidgets": {},
                    "nuid": "a071fe82-8d1e-4dd4-9e34-8a0dcbee32e0",
                    "showTitle": false,
                    "title": ""
                },
                "id": "VUP5VsILR_HT"
            },
            "outputs": [],
            "source": [
                "%ls {out_path}"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "WVR9ld1vl5VX"
            },
            "outputs": [],
            "source": [
                "%cat {out_path +'/index.json'}"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "f84d61f1-9f09-484c-97cf-1f5ee1e8a214",
                    "showTitle": false,
                    "title": ""
                },
                "id": "q7r3o-mKR_HT"
            },
            "source": [
                "### Load the MDS dataset using StreamingDataset and look at the output"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "application/vnd.databricks.v1+cell": {
                    "cellMetadata": {},
                    "inputWidgets": {},
                    "nuid": "ee73abb8-2261-484d-93d5-722b67090ed6",
                    "showTitle": false,
                    "title": ""
                },
                "id": "gQSmFIq6R_HT"
            },
            "outputs": [],
            "source": [
                "from torch.utils.data import DataLoader\n",
                "import streaming\n",
                "from streaming import StreamingDataset\n",
                "\n",
                "# clean stale shared memory if any\n",
                "streaming.base.util.clean_stale_shared_memory()\n",
                "\n",
                "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n",
                "\n",
                "dataloader = DataLoader(dataset, batch_size=2, num_workers=1)\n",
                "\n",
                "for i, data in enumerate(dataloader):\n",
                "    # Display only the first 10 batches\n",
                "    if i == 10:\n",
                "        break\n",
                "    print(data)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "xcVwdCL_bcg8"
            },
            "source": [
                "## Cleanup\n",
                "\n",
                "That's it. No need to hang on to the files created by the tutorial..."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "Wa_ZGkBckq-b"
            },
            "outputs": [],
            "source": [
                "shutil.rmtree(out_path, ignore_errors=True)\n",
                "shutil.rmtree(local_dir, ignore_errors=True)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "uia53VVabf5P"
            },
            "source": [
                "\n",
                "## What next?\n",
                "\n",
                "You've now seen how to convert a Spark dataframe to MDS format and how to load the resulting MDS dataset for model training or for your own use-case.\n",
                "\n",
                "To keep learning about Streaming, explore our other examples!"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {
                "id": "N0LodDjmbimn"
            },
            "source": [
                "## Come get involved with MosaicML!\n",
                "\n",
                "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
                "\n",
                "### [Star Streaming on GitHub](https://github.com/mosaicml/streaming)\n",
                "\n",
                "Help make others aware of our work by [starring Streaming on GitHub](https://github.com/mosaicml/streaming).\n",
                "\n",
                "### [Join the MosaicML Slack](https://dub.sh/mcomm)\n",
                "\n",
                "Head on over to the [MosaicML Slack](https://dub.sh/mcomm) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
                "\n",
                "### Contribute to Streaming\n",
                "\n",
                "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/streaming/issues) or make a [pull request](https://github.com/mosaicml/streaming/pulls)!"
            ]
        }
    ],
    "metadata": {
        "application/vnd.databricks.v1+notebook": {
            "dashboards": [],
            "language": "python",
            "notebookMetadata": {
                "mostRecentlyExecutedCommandWithImplicitDF": {
                    "commandId": -1,
                    "dataframes": [
                        "_sqldf"
                    ]
                },
                "pythonIndentUnit": 2
            },
            "notebookName": "SPARK DataFrame",
            "widgets": {}
        },
        "colab": {
            "provenance": []
        },
        "kernelspec": {
            "display_name": "Python 3",
            "name": "python3"
        },
        "language_info": {
            "name": "python"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 0
}
