{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/16 [00:00<?, ?it/s]\n",
      "Downloading data: 100%|██████████| 20/20 [00:00<00:00, 130257.89files/s]\n",
      "Generating train split: 20 examples [00:00, 1863.43 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the raw REMI text files into a Hugging Face `datasets` object.\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the repository root (parent of this notebook's folder) importable so\n",
    "# `utils_common` resolves. `__file__` is undefined in a plain Jupyter kernel,\n",
    "# so fall back to the current working directory there.\n",
    "dirof = os.path.dirname\n",
    "try:\n",
    "    this_dir = dirof(os.path.abspath(__file__))\n",
    "except NameError:\n",
    "    this_dir = os.getcwd()\n",
    "sys.path.append(dirof(this_dir))\n",
    "\n",
    "from utils_common.utils import jpath, ls\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "data_root = '/data2/[anonymous]/Datasets/LAMD_v4/LAMD/REMI'\n",
    "# Debug-size limits; set both to None to load the full corpus.\n",
    "MAX_SUB_DIRS = 1\n",
    "MAX_FILES_PER_DIR = 20\n",
    "\n",
    "all_remi_fps = []\n",
    "sub_dirs = ls(data_root)\n",
    "for i, sub_dir in enumerate(tqdm(sub_dirs)):\n",
    "    if MAX_SUB_DIRS is not None and i >= MAX_SUB_DIRS:\n",
    "        break\n",
    "    sub_fp = jpath(data_root, sub_dir)\n",
    "    remi_fns = ls(sub_fp)\n",
    "    # Slicing with None keeps every file.\n",
    "    all_remi_fps.extend([jpath(sub_fp, fn) for fn in remi_fns][:MAX_FILES_PER_DIR])\n",
    "\n",
    "assert all_remi_fps, f\"no REMI files found under {data_root}\"\n",
    "dataset = load_dataset(\"text\", data_files={\"train\": all_remi_fps})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 19\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text'],\n",
       "        num_rows: 1\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hold out a small test split. A fixed seed keeps the split identical across\n",
    "# kernel restarts (same value as TrainingArguments(seed=42) used for training).\n",
    "import datasets\n",
    "dataset_splitted = dataset['train'].train_test_split(test_size=0.002, seed=42)\n",
    "dataset_splitted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Map: 100%|██████████| 19/19 [00:00<00:00, 47.85 examples/s]\n",
      "Map: 100%|██████████| 1/1 [00:00<00:00, 97.04 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# Tokenize the dataset into fixed-size windows of `context_length` tokens.\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/home/[anonymous]/work/[anonymous]/[anonymous]/dataset_preparation/test_tokenizer2\")\n",
    "context_length = 2048\n",
    "\n",
    "def tokenize(element):\n",
    "    \"\"\"Split each text into `context_length`-token windows; keep only full ones.\n",
    "\n",
    "    Args:\n",
    "        element: a batch (or Dataset) exposing a `\"text\"` column of strings.\n",
    "\n",
    "    Returns:\n",
    "        dict with `\"input_ids\"`: one list of token ids per full window.\n",
    "\n",
    "    No padding is requested on purpose: with padding=\\\"max_length\\\" the reported\n",
    "    `length` is the padded length, so every window would equal context_length\n",
    "    and the filter below would also keep the padded tail of each text.\n",
    "    \"\"\"\n",
    "    outputs = tokenizer(\n",
    "        element[\"text\"],\n",
    "        truncation=True,\n",
    "        max_length=context_length,\n",
    "        return_overflowing_tokens=True,  # one row per window, not per text\n",
    "        return_length=True,\n",
    "    )\n",
    "    input_batch = [\n",
    "        input_ids\n",
    "        for length, input_ids in zip(outputs[\"length\"], outputs[\"input_ids\"])\n",
    "        if length == context_length\n",
    "    ]\n",
    "    return {\"input_ids\": input_batch}\n",
    "\n",
    "# Row counts change (windows != texts), so the original columns must be dropped.\n",
    "tokenized_datasets = dataset_splitted.map(\n",
    "    tokenize, batched=True, remove_columns=dataset_splitted[\"train\"].column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one training window: its length and its first few token ids.\n",
    "# Reads the already-mapped dataset instead of re-tokenizing the whole split,\n",
    "# and avoids dumping all context_length ids into the notebook output.\n",
    "sample_ids = tokenized_datasets[\"train\"][1][\"input_ids\"]\n",
    "len(sample_ids), sample_ids[:32]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input_ids'],\n",
       "        num_rows: 133\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input_ids'],\n",
       "        num_rows: 3\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Row counts of the tokenized splits (one row per context window).\n",
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/[anonymous]/programs/miniconda3/envs/[anonymous]_hf/lib/python3.8/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT-2 size: 87.4M parameters\n"
     ]
    }
   ],
   "source": [
    "# Build a randomly initialised GPT-2-style model sized for the REMI vocabulary,\n",
    "# plus the causal-LM data collator.\n",
    "from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, AutoModelForCausalLM\n",
    "from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, set_seed\n",
    "import torch\n",
    "\n",
    "# Seed before the weights are initialised: Trainer only seeds (from\n",
    "# TrainingArguments.seed) when it is constructed, which happens later.\n",
    "set_seed(42)\n",
    "\n",
    "config = AutoConfig.from_pretrained(\n",
    "    \"gpt2\",\n",
    "    vocab_size=len(tokenizer),\n",
    "    n_ctx=context_length,\n",
    "    n_positions=context_length,\n",
    "    bos_token_id=tokenizer.bos_token_id,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    "    n_embd=768,\n",
    "    n_head=16,\n",
    "    n_layer=12,  # alt: 24\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    ")\n",
    "# Round-trip through disk and reload with torch_dtype=bfloat16 so the\n",
    "# parameters end up in bf16 (checked via model.dtype further down).\n",
    "model = AutoModelForCausalLM.from_config(config)\n",
    "model.save_pretrained(\"test_model\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\"test_model\", torch_dtype=torch.bfloat16)\n",
    "model_size = sum(t.numel() for t in model.parameters())\n",
    "print(f\"GPT-2 size: {model_size/1000**2:.1f}M parameters\")\n",
    "\n",
    "# Causal-LM collation (mlm=False). NOTE(review): with pad aliased to EOS the\n",
    "# collator also masks EOS positions out of the labels - confirm that is intended.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2Config {\n",
       "  \"_name_or_path\": \"gpt2\",\n",
       "  \"activation_function\": \"gelu_new\",\n",
       "  \"architectures\": [\n",
       "    \"GPT2LMHeadModel\"\n",
       "  ],\n",
       "  \"attn_pdrop\": 0.1,\n",
       "  \"bos_token_id\": 2,\n",
       "  \"embd_pdrop\": 0.1,\n",
       "  \"eos_token_id\": 1,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"layer_norm_epsilon\": 1e-05,\n",
       "  \"model_type\": \"gpt2\",\n",
       "  \"n_ctx\": 2048,\n",
       "  \"n_embd\": 768,\n",
       "  \"n_head\": 16,\n",
       "  \"n_inner\": null,\n",
       "  \"n_layer\": 12,\n",
       "  \"n_positions\": 2048,\n",
       "  \"reorder_and_upcast_attn\": false,\n",
       "  \"resid_pdrop\": 0.1,\n",
       "  \"scale_attn_by_inverse_layer_idx\": false,\n",
       "  \"scale_attn_weights\": true,\n",
       "  \"summary_activation\": null,\n",
       "  \"summary_first_dropout\": 0.1,\n",
       "  \"summary_proj_to_labels\": true,\n",
       "  \"summary_type\": \"cls_index\",\n",
       "  \"summary_use_proj\": true,\n",
       "  \"task_specific_params\": {\n",
       "    \"text-generation\": {\n",
       "      \"do_sample\": true,\n",
       "      \"max_length\": 50\n",
       "    }\n",
       "  },\n",
       "  \"torch_dtype\": \"bfloat16\",\n",
       "  \"transformers_version\": \"4.40.0.dev0\",\n",
       "  \"use_cache\": true,\n",
       "  \"vocab_size\": 989\n",
       "}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the generated GPT-2 configuration.\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.bfloat16"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Confirm the reloaded model's parameters are in bfloat16.\n",
    "model.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "# Training configuration, grouped by concern.\n",
    "args = TrainingArguments(\n",
    "    output_dir=\"[anonymous]\",\n",
    "    seed=42,\n",
    "    # Batching: effective batch = 3 per device x 8 accumulation steps x devices.\n",
    "    per_device_train_batch_size=3,\n",
    "    per_device_eval_batch_size=3,\n",
    "    gradient_accumulation_steps=8,\n",
    "    num_train_epochs=1,\n",
    "    # Optimiser and learning-rate schedule.\n",
    "    learning_rate=5e-4,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    warmup_steps=1000,  # NOTE(review): longer than this debug run's total steps\n",
    "    weight_decay=0.1,\n",
    "    bf16=True,\n",
    "    # Evaluation, logging and checkpointing.\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=1,\n",
    "    logging_steps=5,\n",
    "    save_steps=5000,\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/[anonymous]/programs/miniconda3/envs/[anonymous]_hf/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1' max='1' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1/1 00:00, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>6.687500</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1, training_loss=6.98828125, metrics={'train_runtime': 4.5235, 'train_samples_per_second': 29.402, 'train_steps_per_second': 0.221, 'total_flos': 100336140288000.0, 'train_loss': 6.98828125, 'epoch': 0.67})"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run training with the Trainer configured above (evaluates every eval_steps).\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "events.out.tfevents.1718859678.smc-gpu3.126680.0:   0%|          | 0.00/4.86k [00:00<?, ?B/s]\n",
      "\u001b[A\n",
      "\n",
      "\n",
      "\u001b[A\u001b[A\u001b[A\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718859732.smc-gpu3.126680.1: 100%|██████████| 5.07k/5.07k [00:00<00:00, 18.4kB/s]\n",
      "events.out.tfevents.1718860236.smc-gpu3.126680.3: 100%|██████████| 5.49k/5.49k [00:00<00:00, 18.7kB/s]\n",
      "events.out.tfevents.1718860116.smc-gpu3.126680.2: 100%|██████████| 5.07k/5.07k [00:00<00:00, 16.6kB/s]\n",
      "events.out.tfevents.1718859678.smc-gpu3.126680.0: 100%|██████████| 4.86k/4.86k [00:00<00:00, 14.5kB/s]\n",
      "events.out.tfevents.1718860332.smc-gpu3.126680.4:   0%|          | 0.00/4.87k [00:00<?, ?B/s]\n",
      "\u001b[A\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718860332.smc-gpu3.126680.4: 100%|██████████| 4.87k/4.87k [00:00<00:00, 17.5kB/s]\n",
      "events.out.tfevents.1718860528.smc-gpu3.132607.0: 100%|██████████| 4.87k/4.87k [00:00<00:00, 18.2kB/s]\n",
      "events.out.tfevents.1718860560.smc-gpu3.132607.1: 100%|██████████| 4.87k/4.87k [00:00<00:00, 18.1kB/s]\n",
      "events.out.tfevents.1718860613.smc-gpu3.133223.0:   0%|          | 0.00/7.56k [00:00<?, ?B/s]\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718860400.smc-gpu3.131587.0: 100%|██████████| 5.28k/5.28k [00:00<00:00, 5.87kB/s]\n",
      "events.out.tfevents.1718860613.smc-gpu3.133223.0: 100%|██████████| 7.56k/7.56k [00:00<00:00, 26.7kB/s]\n",
      "events.out.tfevents.1718860997.smc-gpu3.133223.1: 100%|██████████| 4.87k/4.87k [00:00<00:00, 18.2kB/s]\n",
      "events.out.tfevents.1718861041.smc-gpu3.136973.0: 100%|██████████| 4.87k/4.87k [00:00<00:00, 17.3kB/s]\n",
      "events.out.tfevents.1718861061.smc-gpu3.136973.1:   0%|          | 0.00/6.94k [00:00<?, ?B/s]\n",
      "\u001b[A\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718861061.smc-gpu3.136973.1: 100%|██████████| 6.94k/6.94k [00:00<00:00, 28.2kB/s]\n",
      "events.out.tfevents.1718861672.smc-gpu3.136973.2: 100%|██████████| 4.87k/4.87k [00:00<00:00, 16.7kB/s]\n",
      "events.out.tfevents.1718862701.smc-gpu3.148769.0: 100%|██████████| 4.87k/4.87k [00:00<00:00, 17.4kB/s]\n",
      "events.out.tfevents.1718861707.smc-gpu3.139838.0: 100%|██████████| 12.8k/12.8k [00:00<00:00, 37.2kB/s]\n",
      "events.out.tfevents.1718862781.smc-gpu3.149973.0:   0%|          | 0.00/4.87k [00:00<?, ?B/s]\n",
      "\u001b[A\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718862781.smc-gpu3.149973.0: 100%|██████████| 4.87k/4.87k [00:00<00:00, 17.6kB/s]\n",
      "events.out.tfevents.1718863626.smc-gpu3.149973.1: 100%|██████████| 7.76k/7.76k [00:00<00:00, 26.4kB/s]\n",
      "events.out.tfevents.1718864270.smc-gpu3.149973.3: 100%|██████████| 5.28k/5.28k [00:00<00:00, 19.7kB/s]\n",
      "events.out.tfevents.1718864221.smc-gpu3.149973.2: 100%|██████████| 5.90k/5.90k [00:00<00:00, 18.3kB/s]\n",
      "events.out.tfevents.1718864318.smc-gpu3.149973.4:   0%|          | 0.00/4.87k [00:00<?, ?B/s]\n",
      "\u001b[A\n",
      "\n",
      "\u001b[A\u001b[A\n",
      "\n",
      "\n",
      "\n",
      "events.out.tfevents.1718864318.smc-gpu3.149973.4: 100%|██████████| 4.87k/4.87k [00:00<00:00, 18.2kB/s]\n",
      "events.out.tfevents.1718864374.smc-gpu3.161446.0: 100%|██████████| 4.87k/4.87k [00:00<00:00, 19.0kB/s]\n",
      "events.out.tfevents.1718864441.smc-gpu3.161972.0: 100%|██████████| 5.90k/5.90k [00:00<00:00, 23.5kB/s]\n",
      "events.out.tfevents.1718875248.smc-gpu3.205648.0: 100%|██████████| 20.6k/20.6k [00:00<00:00, 64.7kB/s]\n",
      "events.out.tfevents.1718877306.smc-gpu3.205648.1:   0%|          | 0.00/5.21k [00:00<?, ?B/s]\n",
      "events.out.tfevents.1718877306.smc-gpu3.205648.1: 100%|██████████| 5.21k/5.21k [00:00<00:00, 16.8kB/s]\n",
      "training_args.bin: 100%|██████████| 4.47k/4.47k [00:00<00:00, 15.4kB/s]\n",
      "model.safetensors: 100%|██████████| 175M/175M [00:13<00:00, 12.5MB/s]\n",
      "\n",
      "\n",
      "\n",
      "Upload 26 LFS files: 100%|██████████| 26/26 [00:14<00:00,  1.76it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/[anonymous]/[anonymous]/commit/149c9bae577603b8d33903e4337f816b71950ab6', commit_message='End of training', commit_description='', oid='149c9bae577603b8d33903e4337f816b71950ab6', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "# Load GPT-2 in fp16 with Flash Attention 2. Flash Attention runs on the GPU only;\n",
    "# transformers warns when an FA2 model is left on CPU, so move it to CUDA right\n",
    "# after loading.\n",
    "model1 = AutoModelForCausalLM.from_pretrained(\n",
    "    \"gpt2\",\n",
    "    torch_dtype=torch.float16,\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    ").to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "# Release GPU memory held by the training objects. References must be dropped\n",
    "# BEFORE emptying the CUDA cache: empty_cache() only returns cached blocks that no\n",
    "# live tensor uses. Names that do not exist are skipped, so the cell is idempotent\n",
    "# and cannot fail with NameError on a fresh or partially-run kernel.\n",
    "for _name in (\"trainer\", \"model\"):\n",
    "    globals().pop(_name, None)\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "[anonymous]_hf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
