{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext autoreload\n",
    "import transformers\n",
    "from torch import nn\n",
    "import torch\n",
    "from be_great.multihead_models import *\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import AutoTokenizer\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import LinearLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Distil GPT2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPT2MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretrained DistilGPT2 baseline; the bare name on the last line renders its module tree.\n",
    "dgpt2 = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    'distilgpt2'\n",
    ")\n",
    "dgpt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MultiLayer(\n",
       "          (layers): ModuleList(\n",
       "            (0-2): 3 x GPT2MLP(\n",
       "              (c_fc): Conv1D()\n",
       "              (c_proj): Conv1D()\n",
       "              (act): NewGELUActivation()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): MultiLayer(\n",
       "    (layers): ModuleList(\n",
       "      (0-2): 3 x Linear(in_features=768, out_features=50257, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the baseline in the mixture-of-experts model; the repr below shows the result.\n",
    "num_experts = 3\n",
    "dgpt2copy = MOEModelForCausalLM(\n",
    "    dgpt2,\n",
    "    num_experts=num_experts,\n",
    "    multihead=True,\n",
    ")\n",
    "dgpt2copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the freshly built MOE model must reproduce the baseline exactly,\n",
    "# both for a single forward pass and for default generation.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
    "instring = 'hello'\n",
    "toks = tokenizer(instring, return_tensors='pt')\n",
    "\n",
    "# Inference only: no_grad() skips building an autograd graph, and calling the\n",
    "# modules themselves (rather than .forward) goes through nn.Module.__call__ so hooks run.\n",
    "with torch.no_grad():\n",
    "    outs = dgpt2(**toks)\n",
    "    outsmoe = dgpt2copy(**toks)\n",
    "    gen_moe = dgpt2copy.generate(**toks)\n",
    "    gen_base = dgpt2.generate(**toks)\n",
    "\n",
    "assert outs.logits.equal(outsmoe[0]), 'MOE logits differ from baseline logits'\n",
    "assert gen_moe.equal(gen_base), 'MOE generation differs from baseline generation'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['hello The U.S. Department of Justice has been investigating the death of a man who was shot']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
    "instring = 'hello'\n",
    "toks = tokenizer(instring, return_tensors='pt')\n",
    "\n",
    "# Without an explicit pad_token_id, generate() falls back to eos_token_id and warns\n",
    "# on stderr at every call; passing it makes the same choice explicitly and quietly.\n",
    "generated = dgpt2.generate(**toks, pad_token_id=tokenizer.eos_token_id)\n",
    "tokenizer.batch_decode(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:15<00:00,  7.57s/it]\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaSdpaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): LlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): LlamaMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm()\n",
       "        (post_attention_layernorm): LlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Llama-2 7B from a local checkpoint directory, loaded in bfloat16.\n",
    "llama = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    '/zoo/llama2/llama2-7b-hf',\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "deep copied model\n",
      "added MOE MLPs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): LlamaSdpaAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): LlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): MOEMLP(\n",
       "          (mlps): ModuleList(\n",
       "            (0): LlamaMLP(\n",
       "              (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
       "              (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
       "              (act_fn): SiLU()\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (input_layernorm): LlamaRMSNorm()\n",
       "        (post_attention_layernorm): LlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): LlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the MOE counterpart of the Llama model through the from_other constructor;\n",
    "# the bare name on the last line renders the resulting module tree.\n",
    "llamacopy = MOEModelForCausalLM.from_other(llama)\n",
    "llamacopy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first assert passed\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llama generation complete\n",
      "llamacopy generation complete\n",
      "second assert passed\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: the MOE copy must reproduce the original Llama model exactly.\n",
    "# (AutoTokenizer is already imported in the setup cell at the top of the notebook.)\n",
    "tokenizerllama = AutoTokenizer.from_pretrained(\"/zoo/llama2/llama2-7b-hf\")\n",
    "instring = 'hello'\n",
    "toks = tokenizerllama(instring, return_tensors='pt')\n",
    "\n",
    "# Inference only: no_grad() avoids keeping autograd state for either model.\n",
    "with torch.no_grad():\n",
    "    # Forward-pass comparison with both models on the CPU. Calling the modules\n",
    "    # (rather than .forward) goes through nn.Module.__call__ so hooks run.\n",
    "    llama.cpu()\n",
    "    llamacopy.cpu()\n",
    "    outs = llama(**toks)\n",
    "    outsmoe = llamacopy(**toks)\n",
    "    assert outs.logits.equal(outsmoe[0]), 'MOE logits differ from Llama logits'\n",
    "    print('first assert passed')\n",
    "\n",
    "    # Generation comparison on the GPU; move the prompt tensors once and reuse them.\n",
    "    prompt_ids = toks['input_ids'].cuda()\n",
    "    prompt_mask = toks['attention_mask'].cuda()\n",
    "    outs = llama.cuda().generate(input_ids=prompt_ids, attention_mask=prompt_mask)\n",
    "    print('llama generation complete')\n",
    "    outsmoe = llamacopy.cuda().generate(input_ids=prompt_ids, attention_mask=prompt_mask)\n",
    "    print('llamacopy generation complete')\n",
    "    assert outsmoe.equal(outs), 'MOE generation differs from Llama generation'\n",
    "    print('second assert passed')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The MOE DistilGPT2 is the model being trained (keep the tokenizer name below in sync).\n",
    "model = dgpt2copy\n",
    "\n",
    "# Prefer the GPU when one is visible, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Switch to training mode, then place the weights on the chosen device.\n",
    "model.train()\n",
    "model.to(device)\n",
    "\n",
    "# Reuse the EOS token as the padding token for this tokenizer.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data loading\n",
    "# Read the CSV into a DataFrame (point file_path at your local copy of the file).\n",
    "file_path = '/hdd3/sonia/data/adult.csv'\n",
    "data = pd.read_csv(filepath_or_buffer=file_path)\n",
    "\n",
    "# Preprocess the data: Convert each row to a string\n",
    "def row_to_string(row):\n",
    "    return \", \".join([f\"{col} is {val}\" for col, val in row.items()]) + \".\"\n",
    "def row_to_sentences(row):\n",
    "    return '. '.join([str(col).strip() + \" is \" + str(val).strip() for col, val in zip(row.index, row.values)])\n",
    "def row_to_col_sentences(row):\n",
    "    return [str(col).strip() + \" is \" + str(val).strip() + '.' for col, val in zip(row.index, row.values)]\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    def __init__(self, texts, tokenizer, num_experts, max_col_length=10):\n",
    "        self.texts = texts\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_col_length = max_col_length\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.texts)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        text = self.texts[idx][:num_experts] # ['age is 39', 'workclass is State-gov', ...]\n",
    "        tokenized_text = self.tokenizer(text, truncation=True, max_length=self.max_col_length, padding='max_length', return_tensors=\"pt\")\n",
    "        return tokenized_text.input_ids.squeeze(), tokenized_text.attention_mask.squeeze()\n",
    "\n",
    "\n",
    "# One list of per-column sentences per CSV row, served one row per batch in shuffled order.\n",
    "text_data = data.apply(row_to_col_sentences, axis=1).tolist()\n",
    "dataset = TextDataset(texts=text_data, tokenizer=tokenizer, num_experts=num_experts)\n",
    "dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "num_epochs = 3\n",
    "\n",
    "# The optimizer is built from `model` -- the network whose loss is backpropagated\n",
    "# below -- so that optimizer.step() updates the weights that are being trained.\n",
    "# NOTE(review): lr=5e-2 is unusually large for AdamW fine-tuning -- confirm it is intended.\n",
    "optimizer = AdamW(model.parameters(), lr=5e-2)\n",
    "num_training_steps = len(dataloader) * num_epochs\n",
    "# With LinearLR's default factors the rate ramps from lr/3 up to lr over total_iters steps.\n",
    "lr_scheduler = LinearLR(optimizer, total_iters=num_training_steps)\n",
    "\n",
    "# Every step feeds the same single-BOS prompt, so tokenize it and move it to the device once.\n",
    "ins = tokenizer(tokenizer.bos_token, return_tensors='pt')\n",
    "prompt_ids = ins['input_ids'].to(device)\n",
    "prompt_mask = ins['attention_mask'].to(device)\n",
    "\n",
    "losses = []\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in tqdm(dataloader):\n",
    "        # NOTE(review): labels_mask is never used, so padded label positions are not\n",
    "        # excluded here -- confirm whether the model masks them itself.\n",
    "        labels, labels_mask = batch\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Keyword arguments keep the attention mask from being bound to whatever\n",
    "        # the second positional parameter of the model's forward happens to be.\n",
    "        outputs = model(input_ids=prompt_ids, attention_mask=prompt_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "        losses.append(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fd1944c6f50>]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2FElEQVR4nO3deXxU1cH/8e9MVggzbCEkhASRVaBsAdEYRNGgRisiGhfaIqJo6AJaN/wpYgvBthJxA3wUUajUB2vdqnkaoWgRghiqQtgUEYEAgbAlQMhCzu8PyJAxATMx5ID38369zsvkzrl3zswFz5dzzz3XJckIAADAErftBgAAAGcjjAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwKth2A2qrTZs2Kioqst0MAAAQAI/Ho+3bt5+yzlkRRtq0aaO8vDzbzQAAAHUQGxt7ykByVoSRyhGR2NhYRkcAADhLeDwe5eXl/WDffVaEkUpFRUWEEQAAfmKYwAoAAKwijAAAAKvqFEbS0tK0adMmFRcXKycnR0lJSSetO2fOHBljqpXc3Nw6NxoAAPx0BBxGUlNTNX36dE2ZMkV9+vTRkiVLlJmZqbi4uBrrjxs3TtHR0b7Stm1b7dmzR2+88caPbjwAAPhpMIGU5cuXmxkzZvhtW7t2rUlPT6/V/kOHDjVHjx418fHxtX5Pj8djjDHG4/EE1FYKhUKhUCj2Sm3774BGRkJCQpSQkKCsrCy/7VlZWUpMTKzVMUaPHq2FCxdqy5YtJ60TGhoqj8fjVwAAwE9TQGEkMjJSwcHBys/P99uen5+v6OjoH9w/OjpaV111lV566aVT1pswYYIKCwt9hQXPAAD46arTBFZjjN/vLper2raa3Hbbbdq/f7/efvvtU9abOnWqvF6vr8TGxtalmQAA4CwQ0KJnBQUFKi8vrzYKEhUVVW20pCa333675s2bp7KyslPWKy0tVWlpaSBNAwAAZ6mARkbKysq0cuVKJScn+21PTk7WsmXLTrnvoEGD1KlTJ82ePTvwVgIAgJ+0gGbGpqammpKSEjNq1CjTtWtXk5GRYYqKinx3x6Snp5tXX3212n5z58412dnZp3U2LoVCoVAolDOn1Lb/DvjZNAsWLFDLli01ceJExcTEKDc3VykpKb67Y2JiYhQfH++3j9fr1fDhwzVu3LhA3w4AAPzEuXQslZzRPB6PCgsL5fV66/VBeQk/v0ptu3XR6kUfa1PO5/V2XAAAUPv+29HPpjkv6QJd/IubFNulk+2mAADgWI4OI77bkV0uuw0BAMDBCCOSXG7CCAAAthBGJLlEGAEAwBZHhxH5rtIQRgAAsMXRYcSYimM/EEYAALDG4WGEOSMAANjm6DDCZRoAAOxzdBg58aRhwggAALY4O4xUHJszwsgIAAD2ODuMiDkjAADY5ugw4nsqDyMjAABY4+gw4rubhjACAIA1jg4j8q3ACgAAbHF0GOFBeQAA2EcYkeRyO/prAADAKmf3wswZAQDAOkeHkROXaey2AwAAJ3N2GKmonMBKGgEAwBZnhxExZwQAANuc3Qv75oxYbgcAAA7m6DDCrb0AANjn7DDCnBEAAKxzdBgR64wAAGCdo3vhygmsDIwAAGCPo8MIi54BAGCfo8NI5fxVJrACAGCPs8NIRYUkJrACAGCTs8OIb9EzwggAALY4Ooz45q9ymQYAAGscHUZY9AwAAPucHUYq54wQRgAAsMbRYYRbewEAsM/RYaTyzl4u0wAAYI+zwwgjIwAAWOfsMHJ8zgjLjAAAYI+jw0jldRq3y9lfAwAANjm6F+bWXgAA7HN0GKkcGmHOCAAA9jg6jJwYGbHbDgAAnMzZYaSCkREAAGxzdBg5seiZs78GAABsqlMvnJaWpk2bNqm4uFg5OTlKSko6Zf3Q0FBNnjxZmzdv1pEjR7Rx40aNGjWqTg2uT1ymAQDAvuBAd0hNTdX06dM1
duxYLV26VHfddZcyMzPVrVs3bd26tcZ9FixYoNatW2v06NHauHGjoqKiFBwc8FvXO8MEVgAAzggmkLJ8+XIzY8YMv21r16416enpNda/4oorzL59+0zz5s0Dep+qxePxGGOM8Xg8dT5GTWXgL24y01ZnmxFPTKrX41IoFAqFQql9/x3QZZqQkBAlJCQoKyvLb3tWVpYSExNr3Ofaa69VTk6OHnjgAW3btk0bNmzQX/7yF4WHh5/0fUJDQ+XxePzKacFy8AAAWBfQtZLIyEgFBwcrPz/fb3t+fr6io6Nr3Ofcc89VUlKSjhw5omHDhikyMlIzZsxQixYtNHr06Br3mTBhgiZNmhRI0+qERc8AALCvThNYfZ34cS6Xq9o23xu43TLGaMSIEfrss8+UmZmpe++9V7fddttJR0emTp0qr9frK7GxsXVpZi0wMgIAgG0BjYwUFBSovLy82ihIVFRUtdGSSjt27FBeXp4KCwt929atWye32622bdtq48aN1fYpLS1VaWlpIE2rk8p1RhgZAQDAnoBGRsrKyrRy5UolJyf7bU9OTtayZctq3Gfp0qVq06aNIiIifNs6d+6so0ePatu2bXVocv0xzBkBAMC6gC/TZGRk6I477tCoUaPUtWtXZWRkKD4+XrNmzZIkpaen69VXX/XVnz9/vvbs2aM5c+bovPPO08CBA/WXv/xFL7/8so4cOVJ/n6QOCCMAANgX8GIfCxYsUMuWLTVx4kTFxMQoNzdXKSkp2rJliyQpJiZG8fHxvvqHDh1ScnKynn32WeXk5GjPnj1asGCBHnnkkfr7FHVVOc2FMAIAgDV1Wnls5syZmjlzZo2v1bSy6oYNGzRkyJC6vNVpZUyFJEZGAACwydEPZfFdpnETRgAAsMXRYcS36BkPpwEAwBpHhxHDnBEAAKxzdBhhOXgAAOxzdBipnMDKVRoAAOxxeBg59l+X29FfAwAAVjm6FzZMYAUAwDpHh5ETc0YstwMAAAdzdBg5MWeENAIAgC0ODyPH/sucEQAA7HF2L8ycEQAArHN0GDG+oRG77QAAwMkII2LRMwAAbCKMiDkjAADY5OxemMs0AABY5+gwwqJnAADY5+gwwoPyAACwz9Fh5MTdNIQRAABsIYxIcrkJIwAA2OLoMMKiZwAA2OfoMFJ5lYbLNAAA2OPsMFJx7EF5TGAFAMAeZ4cRMWcEAADbHB1GxN00AABY5+gwcmIBVsIIAAC2ODuMMGcEAADrHB1GxJwRAACsc3QY4dZeAADsc3QY4dk0AADY5+gw4ns2DQAAsMbZYYQJrAAAWOfsMHL8vy63o78GAACscnYvzJwRAACsc3QYYc4IAAD2OTuMVDAyAgCAbY4OI77LNMwZAQDAGkf3wpVP7eXRNAAA2OPsMMIEVgAArHN2GKnw3dxrtR0AADiZo8PIiTkjhBEAAGxxdBjhMg0AAPY5OoxUrsFKGAEAwJ46hZG0tDRt2rRJxcXFysnJUVJS0knrDho0SMaYaqVLly51bnR98S16RhgBAMCagMNIamqqpk+frilTpqhPnz5asmSJMjMzFRcXd8r9OnfurOjoaF/5+uuv69zo+sKiZwAA2BdwGLn33ns1e/ZszZ49W+vXr9c999yjrVu3Ki0t7ZT77dq1S/n5+b5ScfyJuTYxZwQAAPsCCiMhISFKSEhQVlaW3/asrCwlJiaect/PP/9c27dv18KFC3XJJZecsm5oaKg8Ho9fOT24TAMAgG0BhZHIyEgFBwcrPz/fb3t+fr6io6Nr3GfHjh268847NXz4cF1//fXasGGDFi1apIEDB570fSZMmKDCwkJfycvLC6SZtXZiyghhBAAAW4LrstP3n3brcrlO+gTcr776Sl999ZXv9+XLlysuLk733XeflixZUuM+U6dOVUZGhu93j8dzWgKJOQMuFQEA4HQBjYwUFBSovLy82ihIVFRUtdGSU1m+fLk6dep00tdLS0tVVFTkV06L4wHKzYPyAACwJqBe
uKysTCtXrlRycrLf9uTkZC1btqzWx+nTp4927NgRyFufFtzaCwCAfQFfpsnIyNC8efOUk5Oj7OxsjRkzRvHx8Zo1a5YkKT09XbGxsRo5cqQkady4cdq8ebPWrFmj0NBQ/eIXv9ANN9yg66+/vn4/yY/AnBEAAOwJOIwsWLBALVu21MSJExUTE6Pc3FylpKRoy5YtkqSYmBjFx8f76oeGhurJJ59UbGysiouLtWbNGqWkpCgzM7P+PkUd+W4vJosAAGCNS777W89cHo9HhYWF8nq99Tp/JDK+rSa8/4aKiw7qkcTkH94BAADUWm37b0fP3OTWXgAA7HN0GDmRRuw2AwAAJ3N0GDHm2JwRRkYAALDH4WGk8tk0jv4aAACwytm9MHNGAACwztFhxDBnBAAA6xwdRuS7TEMaAQDAFkeHkQqWgwcAwDpHhxFGRgAAsM/RYcQQRgAAsM7RYURcpgEAwDpHh5HKkRG329FfAwAAVjm6F/bd2gsAAKxxdBhRlTDCvBEAAOxwdBjxGxkhjAAAYAVh5DhGRgAAsIMwchxhBAAAOxwdRsRlGgAArHN0GGFkBAAA+xwdRlT1zl7CCAAAVjg6jBhT4fuZLAIAgB0ODyNVL9M4+qsAAMAaR/fA/vNXGRoBAMAGR4cR/zRirxkAADiZo8OIqag6Z4Q0AgCADc4OI1Vup3Hx5F4AAKxwdg/s99BeRkYAALDB0WHE/24aiw0BAMDBnB1GmDMCAIB1jg4jVTFnBAAAO+iBKzEwAgCAFY4PIxXHL9W4SCMAAFjh+DDiW/iMOSMAAFjh+DBiKo6FESawAgBgB2Hk+GIjLjdhBAAAGxwfRiov0zBnBAAAOxwfRgxzRgAAsIowwpwRAACsIowY5owAAGCT48OI72l5jIwAAGCF48OIYQIrAABWEUYqGBkBAMCmOoWRtLQ0bdq0ScXFxcrJyVFSUlKt9ktMTFRZWZk+//zzurzt6cGcEQAArAo4jKSmpmr69OmaMmWK+vTpoyVLligzM1NxcXGn3M/r9Wru3LlatGhRnRt7OvgWPWNkBAAAKwIOI/fee69mz56t2bNna/369brnnnu0detWpaWlnXK/F154QfPnz1d2dnadG3ta+K7SEEYAALAhoDASEhKihIQEZWVl+W3PyspSYmLiSfe77bbb1KFDBz3++OO1ep/Q0FB5PB6/crqY40/tBQAAdgQURiIjIxUcHKz8/Hy/7fn5+YqOjq5xn44dO+qJJ57QiBEjdPTo0Vq9z4QJE1RYWOgreXl5gTQzICfWGXH8XF4AAKyoUw/sW0L9OJfLVW2bJLndbs2fP1+PPfaYvv7661off+rUqfJ6vb4SGxtbl2bWii+McJkGAAArggOpXFBQoPLy8mqjIFFRUdVGSyTJ4/Gof//+6tOnj5577jlJxwKK2+1WWVmZhgwZosWLF1fbr7S0VKWlpYE07ccjjAAAYEVAYaSsrEwrV65UcnKy3n77bd/25ORkvfPOO9XqFxYWqkePHn7bxo4dq8GDB+uGG27Qt99+W7dW16MTIyOWGwIAgEMFFEYkKSMjQ/PmzVNOTo6ys7M1ZswYxcfHa9asWZKk9PR0xcbGauTIkTLGaM2aNX7779q1S0eOHKm23ZYTE1hJIwAA2BBwGFmwYIFatmypiRMnKiYmRrm5uUpJSdGWLVskSTExMYqPj6/3hp42x6e6uJnACgCAFS75uuMzl8fjUWFhobxer4qKiur12I8ufEfNWkcpI3Wk8tZ9Va/HBgDAyWrbfzMcwN00AABY5fgwUsGcEQAArHJ8GPEtB8+cEQAArHB8D8ytvQAA2OX4MFLlSXl2mwEAgEM5PoyYCiawAgBgE2HEd5nG8V8FAABW0AMzZwQAAKscH0Z8TxsmjQAAYIXjw0gl5owAAGCH48OIb9EzwggAAFY4PoywHDwAAHY5PowYwggAAFY5PoxUIowAAGCH48OIYc4IAABWEUa4TAMAgFWODyMnHk1DGAEAwAbH
hxEWPQMAwC7CiDk2Z4QsAgCAHYQRHpQHAIBV9MDHr9IwNAIAgB2ODyPcTQMAgF2ODyMnloO33A4AABzK8WGEu2kAALCLMMIEVgAArHJ8D1y5HLzLzcgIAAA2EEaYwAoAgFWEEd/IiOO/CgAArHB8D2wqGBkBAMAmwohhZAQAAJsc3wNXXqZxE0YAALDC8T1wBXNGAACwyvE9sG/OCLf2AgBgBWGkcmSERc8AALDC8T3wiQmsjIwAAGCD48NIxfHLNG53kOWWAADgTI4PIywHDwCAXYQR7qYBAMAqx/fAJyawMjICAIANhJHKB+UxMgIAgBWO74ErWIEVAACrHN8DM2cEAAC76tQDp6WladOmTSouLlZOTo6SkpJOWveiiy7SJ598ooKCAh0+fFjr1q3T+PHj69reescKrAAA2BUc6A6pqamaPn26xo4dq6VLl+quu+5SZmamunXrpq1bt1arf+jQIT333HNatWqVDh06pKSkJL3wwgs6dOiQXnzxxXr5ED8GK7ACAGCfCaQsX77czJgxw2/b2rVrTXp6eq2P8eabb5q5c+fWur7H4zHGGOPxeAJqa23KDY89aKatzjaX3Tmy3o9NoVAoFIqTS23774CGA0JCQpSQkKCsrCy/7VlZWUpMTKzVMXr37q3ExER9/PHHJ60TGhoqj8fjV06XE5dpGBkBAMCGgHrgyMhIBQcHKz8/3297fn6+oqOjT7nv1q1bdeTIEeXk5Oj555/X7NmzT1p3woQJKiws9JW8vLxAmhkQw900AABYVaceuHJtjkoul6vatu8bOHCg+vXrp7vvvlvjx4/XzTfffNK6U6dOldfr9ZXY2Ni6NLNWuJsGAAC7AprAWlBQoPLy8mqjIFFRUdVGS75v8+bNkqTc3Fy1bt1akyZN0uuvv15j3dLSUpWWlgbStDqr4Nk0AABYFdBwQFlZmVauXKnk5GS/7cnJyVq2bFmtj+NyuRQWFhbIW582vhVYuZsGAAArAr61NyMjQ/PmzVNOTo6ys7M1ZswYxcfHa9asWZKk9PR0xcbGauTIkZKksWPHasuWLVq/fr0kKSkpSffdd5+effbZevwYdXdizggjIwAA2BBwGFmwYIFatmypiRMnKiYmRrm5uUpJSdGWLVskSTExMYqPj/fVd7vdmjp1qtq3b6/y8nJ98803euihh/TCCy/U36f4EU7cTRNkuSUAADiTS8fu8T2jeTweFRYWyuv1qqioqF6PffX4NA0e/St99Op8vffkmTFaAwDAT0Ft+2/HT5SoYJ0RAACscnwPbEzlcvDMGQEAwAbCyPGRERY9AwDADsf3wCx6BgCAXY7vgSsIIwAAWOX4HtiwAisAAFYRRirnjLACKwAAVji+B/bdTcNlGgAArHB8D2xYZwQAAKsc3wMzZwQAALscH0YqfA/Kc/xXAQCAFY7vgStHRsIiIiy3BAAAZ3J8GIn/WTdJUvdLkiy3BAAAZ3J8GOl5+aW2mwAAgKM5PowAAAC7HB9GjDG2mwAAgKM5Poy4XNzSCwCATY4PIwAAwC7Hh5G8DV/ZbgIAAI7m+DDyf8+9aLsJAAA4muPDyJGDhyRJe7Ztt9wSAACcyfFhhGfTAABgF2HE8GwaAABscnwPbCqOrTPicjn+qwAAwArH98CVIyNcpgEAwA7CSOXICJdpAACwwvE9sG8CKyuxAgBgBWHEMDICAIBNju+BK7i1FwAAqxwfRribBgAAuxzfA3M3DQAAdhFGfBNYHf9VAABgheN74MrLNKzACgCAHY7vgblMAwCAXYQRJrACAGCV43tgntoLAIBdhJHji565g4IstwQAAGcijBwfGZFYEh4AABscH0bKSkp9P4eEh1lsCQAAzuT4MFJaXKyKo0clSeFNmlhuDQAAzuP4MCJJR8vLJUlBwcGWWwIAgPMQRiTfyAiTWAEAaHh1CiNpaWnatGmTiouLlZOTo6SkpJPWHTZsmLKysrRr1y4dOHBAy5Yt05AhQ+rc4NOhovxY
GHEFkc0AAGhoAfe+qampmj59uqZMmaI+ffpoyZIlyszMVFxcXI31L774Yn344YdKSUlRQkKCFi9erPfee0+9e/f+sW2vNxXH76gJYmQEAAArTCBl+fLlZsaMGX7b1q5da9LT02t9jNzcXPPoo4/Wur7H4zHGGOPxeAJqa23LpI/eN9NWZ5voTh1Oy/EpFAqFQnFiqW3/HdDISEhIiBISEpSVleW3PSsrS4mJibU6hsvlksfj0d69e09aJzQ0VB6Px6+cTpWXaRgZAQCg4QUURiIjIxUcHKz8/Hy/7fn5+YqOjq7VMX7/+98rIiJCCxYsOGmdCRMmqLCw0Ffy8vICaWbAKiqYwAoAgC11mrFZuYR6JZfLVW1bTW6++WZNmjRJN910k3bv3n3SelOnTpXX6/WV2NjYujSz1irvpmECKwAADS+ghTUKCgpUXl5ebRQkKiqq2mjJ96Wmpmr27Nm68cYbtWjRolPWLS0tVWlp6Snr1Ccu0wAAYE9AQwFlZWVauXKlkpOT/bYnJydr2bJlJ93v5ptv1iuvvKJbb71VH3zwQd1aehpV3k3jIowAANDgAl5yNCMjQ/PmzVNOTo6ys7M1ZswYxcfHa9asWZKk9PR0xcbGauTIkZKOBZG5c+dq3LhxWr58uVq3bi1JKi4uVmFhYT1+lLqrvEzDyAgAAA0v4DCyYMECtWzZUhMnTlRMTIxyc3OVkpKiLVu2SJJiYmIUHx/vq3/XXXcpJCREM2bM0IwZM3zbX3nlFY0aNaoePsKPxwqsAADYU6eHscycOVMzZ86s8bXvB4xLL720Lm/RoJjACgCAPfS+YgIrAAA2EUZUZQKrmzACAEBDI4yoypyRYMIIAAANjTAiLtMAAGATYUQnloNnAisAAA2P3lesMwIAgE2EEUkVR5nACgCALYQRMYEVAACbCCOqEkbcfB0AADQ0el+xHDwAADYRRsRlGgAAbCKMqOplGsIIAAANjTCiE3fTcJkGAICGRxipIqZzB9tNAADAcQgjkqLOiZckRTRrarklAAA4D2FE0rdfrJIkFe3Za7klAAA4D2FE0sE9+yRJnS8833JLAABwHsKIpPZ9e0mSvJEtLbcEAADnIYxIahYdZbsJAAA4FmFEJ27tBQAADY8wIslUEEYAALCFMKITK7ACAICGRxiRVMHICAAA1hBGJBnmjAAAYA1hRFLpkSO2mwAAgGMRRiT9+6W5tpsAAIBjEUZ0Yhn4IwcPWW4JAADOQxiRZGQkSS63y3JLAABwHsKIqq4zQhgBAKChEUYkHR8YUVjjRnbbAQCAAxFG5L8CqzeqlcWWAADgPIQRnZgzIkmdL+hvsSUAADgPYeR7KipYGh4AgIZEGJEU2ujEXJGKcsIIAAANiTAiqWDLNt/PVYMJAAA4/QgjOjaBdX/+LknSTX942HJrAABwFsLIcc1aR9luAgAAjkQYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABW1SmMpKWladOmTSouLlZOTo6SkpJOWjc6Olqvvfaa1q9fr6NHj+qpp56qc2NPp3f+/LTv55DwMIstAQDAWQIOI6mpqZo+fbqmTJmiPn36aMmSJcrMzFRcXFyN9cPCwrR7925NmTJFX3755Y9u8Onyyfw3fD+3Pvccew0BAMBhAg4j9957r2bPnq3Zs2dr/fr1uueee7R161alpaXVWP+7777T+PHjNW/ePB04cOBHN/h0qTh6Yhn4GyY+aLElAAA4S0BhJCQkRAkJCcrKyvLbnpWVpcTExHprVGhoqDwej19pSN5WkQ36fgAAOFlAYSQyMlLBwcHKz8/3256fn6/o6Oh6a9SECRNUWFjoK3l5efV27NpoGtWqQd8PAAAnq9MEVmOM3+8ul6vath9j6tSp8nq9vhIbG1tvx66tyPi2Df6eAAA4UUBhpKCgQOXl5dVGQaKioqqNlvwYpaWlKioq8isNrf/Qqxv8PQEAcKKAwkhZWZlWrlyp5ORkv+3JyclatmxZ
vTbMtn5DU2w3AQBqrcfgi9Wf/2/hLBUc6A4ZGRmaN2+ecnJylJ2drTFjxig+Pl6zZs2SJKWnpys2NlYjR4707dOrVy9JUpMmTdSqVSv16tVLpaWlWrduXT19jPrXrHWUItvFqeC7rbabAgA/aNTTf5IkbVzxX+3bsdNya4DABBxGFixYoJYtW2rixImKiYlRbm6uUlJStGXLFklSTEyM4uPj/fb54osvfD/369dPI0aM0ObNm9W+ffsf1/p69sn8N5R0642+33slD9ail1612CIACExE86aEEZx1Ag4jkjRz5kzNnDmzxtdGjRpVbZvL5arL2zS4dZ9k+4WRlHF3a//OfK385/9ZbBUABOLs+P8tUBXPpqnCHK2otu3WqY9ZaAkA1F7Vf/CdLf/4A6oijFRRUVE9jEhS90sHqk2XTurQv28DtwgAaqFqGHETRnD2qdNlmp+qqkvCV3X7M3/2/fznobcof9PmBmoRAPwwvwDCyAjOQoyMVPHdqjU/WCe2W5dq2zoN6Kfzr7vmdDQJAH6QS1ymwdmNMFJFeUmJplw1/JR1ul50gWI6d5AkNY+Jlsvt1t0vPaub/vj/1KZLp4ZoJgD4q3qZhgmsOAtxmeZ79m7brvt7J+kvX3xS4+sJ11yphGuu9P2+69vvfD8PGH6t3kqfdtrbCABVudxV/l3JyAjOQoyM1KDi6FHd3yepVnWj2rfz/Zx0yw080wZAg6uaP5jAirMRIyMnUVFe82TWHzL0wfFyud1aPHueWsTG6LN3PpAkRbaL09D7x+loebkO7tunN//w55M+XPBUDx7seH6CeiZfqn9mPKfS4iN1aiOAnxZu7cXZjjByCpOHDFNo40aqOHpUD733v7Xap9vFF0mSzku6UJJ08+RHa6y3a9N3ylv/lfZs2ab9+bskSU1aNFfy3ber79VDNGfcQzqQv1t7tm7z2y9t9nOSpOLCImU++0KdPheAnxaXi8s0OLsRRk6h6pLKM0aN1dg5M+rt2EMfGOf3e/Ybb+vCG6/z/f7r4+/13K/u0refr5Ikxf+sm+/1qpeDIuPbKrrjucr993/UJXGA9ufvUv4332rA9T9X0ohUzf71fdq/s/6eqlzJHRR00tuhATQgvzt7CSM4+7gk1Xw94Azi8XhUWFgor9eroqIia+24YuwdOrC7QAOu/7nie3T74R3qye7NW9TqnPgfrljFHy6/VhMXvnts/++2qlW7OEnS3u079J+5r+vg3n3qc1Wyul86UIvnvKbQRuHat32HSo+UaNemzdq27iu5XFJE82bqc+XlWvr6myovK1Pp4WJ5oyJ12/QnFN+jm+Y//Afl/vtjyUhHy8tVXlrqa0O4p4l6X3GZwps00UevvFZ/X0gDaN+np9p06aSlr79puylnJJfbrb5XX6Hvvlytgi3bfngHnFbhTSI0JXuhJGnm6N9o44qVllsEHFPb/pswUkcut1vxP+um3/31RdtNaRAVFRVyu2s333n/zny9++Sz+tWTk33bnh811jfaI0kv//Z+RcbH6eDefVr/Sbaat4nRtrXrFdqokVwul0oOH9b5w36uvlcP0Vvp07Tr2+9882ja9+2lX02borefeEqbv1ilqHPa6etPcyRJIeFheuKzjyRJD/a7RI2bejV29nNavegjvT/9xPOUQsLD1DKurXZ+/Y2kY+ezY/++2rZug867OFEjpk6SJL0wZpy+yl5xys/rDgpSWERjhYSFqXB3Qa2+o+9r0qK5IuPjtG3dBvW8fJA2LFuhQ/v211g3LKKxEq65UqsWLtbBPfsCfq+g4GA1i26tVufEacOyFTIVFQpr3FjhnggdyN9dq2P0v+5q3fzHRyRJ/549V1mzXlYjj0eH9u3X0fLyavVdbrcaez06tP+Ab1twWJiaRUed8snYrTu0V7PWUdqw7NPqx6wytyooOFieyJanZQTwbNDI69HkpVm+36dek+rYJ45HtW+nwoI9OlJ00HZTIMJIg+l3bYpumVLzvBDU
3cu/vV+3P/uXWtf/R/o0ffqP9/SnnI98296fPlMDR9wob6tISdKqDxerZ/KlKjtSopDwMF+9hwdcppsnP6KeyZdWO+6HL8zR2v8sVcf+CSovLVVoo3CFhIUp6dYbFd4kQpJUsGWb32Wz1x56TP99/1jHcPPkR9V/aIq+Wfm5Fv3Pqwr3NNEv/vS43EFBeufPT8sdFKSSw4eV8ru71bip1++95933iL7+NEe9r7xcg0f/Up9nLtSKt97TuL/NVnjEsff+w+XXKq57N329/DOVHD7sF8aeHnGHyo6UaMdXGyUdG+XytorUr+fMUCOvp8bv8ekRd+jcPr2U889MHdq7/1iIaOpVcEiIrn/kfjVu6lV4kwjt+Gqj+l59RY3HeG3CJG1c8V95I1sqODRUTVo014DhP1e3iy/SB8/MUsI1V2rxnL/qhokPKjgkRDNu/7WMMbr8zpF69y/PaOfGTb5LgNNWZ0uSXhn/kFYv+ljBoaFq17O7Uh9/WJHxbTXrjt/q609zNPaVGeqQ0EfP/vIubc1dq+GPPqBvcj7Xof37VbR7j5q3idHAX6Rq6etvav2SbPVMvlT7duzU9g1f66rf3qXP3vlAe7fl6XfzX1Jk3LHjhoSHyxij9UuWVZtQ3rZbV/W5KllZM2erbbcuiuvRzTf618jr0S1TJirnvUytyvq3Gjf1qv/Qq/V55od+YbVJi+aK7niuNq5YqfMGJqr3VZfrrakZKjtSogtuGKr1nyxXyaFD8kS20I6vvvHtF94kQu6gILmDg1R2pEQlhw6rkderyUv/5auzbskyvTT2977fb578iNxBQZo/4XFJUkznDircVaBD+w8oKCREHc9P0KaVn0uSRjzxuNZ8tESfvf2+32cODg31G/msVPnnKiQ8TPE9ztPSv70pY4wimjVVRYVRcWGhvK0iNXbODH365jtaPOfY9+QODlKT5s1VuLtALpdLoY0aqe81V+hoWbm++L8PfRP0Yzp3VOGu3XIHB6vk0KFTTtyP7niu7n/r2PEnnD9YpcXFfq936N9X7fv01KIXX5UxRu7gIEU0a6aigj3qNihJzaKjdHDvPq1e+JGax8Zo/878Gm9miO54rkoOH9a+7Sd/OnKPwYPkiWyh7AVv+ba1Oide+7bvVHlpqdr36amyklIdLS/XbdOn6oOnZ8lUVOicPj313pPPyhx/PEn7Pj3VNKqVvvjXopO+16kEh4bqaFnZSW+KqHS6LrsTRiyI7dpZCddepUG/vFlv/2m6rntwvO0mwZKKigq9/shk3Zo+scHe86vsFep84fk1vlZy+LDCGjdusLY41YH83WraupXv97x1Xyn2vM6+33PezdTRsjJ9typXQx+8R2GNG53W9hTuLlCTFs3lDgqSJH29PEedLujne33nN98qukP7Gvd9NOkKHTl4SBVHj+rSUSN0zb2/UUVFhZ4cNkIRLZpr5LQpatKieY37Tr/5do1//WVJ0vYNX/stCDln3IPK/fd/dMeMaTpvYKK2rl2vuG5dA/5sW9es0+xf36dJHx0LTR+9Ol+XjLzVr07J4WL933P/o8Gjf6nS4iNq2baNJGlv3g7t+Pobdb+k5iUcct7NVL9rr5IkPfHzmxTdob3WfPyJmrWO0hVj7/S99v70mQoJC1WH/n214q1/augD49S4qdfv79uOr79RTKcO1dp1qnN/YNduhYSFaX/+LrXp3FHSsXP53rRnVVx0SLdNn6rgkBCVHSlR9t/f1qIXX1Xfq6/Q0AfG6dM339WA4ddqf/4uNWsd5XfcV8Y/pMSbrte6T7LV9aILFBbRWG8/MV2/++v/qKykRI8NSlHZkZJaff+1RRg5A7SMa6ug4CDt+vY7NfJ61KFfH639eKkvfVb+qw8AANte/u39WvNRzQt+1lVt+2/upjmNqt6WW1xYpNx//8fv9QcTBqnH4Iv1zWf/VXBYqJq1jtIt6RP13Ze56jrwQr0/fabWL8lW90uSNHBEqla8/b4+e/ufKtqz
9/hlhcF+6fqvDz6miGZNdcENQ6slcQAATiUoNNTaezMy8hPkcrt1Tq8eSpv9vHZv2aqMG0fqaFmZ7/UuF12gfj+/Uqs+XKzbpj8hSVr2v/9QcdFBtenSUaXFR1Ry+LASrrlS/30/S3//45/1qycn+4Y0N3+xWs/flqawiAjd/syfdG5C72pt+OJfi9T7isu0+cvVOqfXz3zbf2gi7PeHuX/Iof0HFNGsaa3rAwBqNu++R+o8N+VkuEwDhYSH6WhZ+SknJfUYfLH278zXtrUbTnksd1CQzk3ora2561Ry+LDfay6XS90vvViFu3erbbeuWvH2+yovKfF7vZHXo5JDh/3utPjFn/+gPlcl64NnZmnRi6/6HbP1ueeokderO55/Uo28HhXuLtCT1/9CUeeeo7SXnlNQSLCKC4v05PBf6tx+vTVi6iR9Mv8NvTU1Q5LU+cLzdbS8XHnrNujIwUPyRrXSPa+/rO0bNiqux3n6z1//V+uXLFN5Wbnu/8dfVVFRoek3jZLL7VZYRGN9+98v1bipVy3atlH73j110S03qGjPHnkjI/WPKU+qx+CL1a5XD80c/Rv1vPxStTonToN+dYskaf6Ex3Xr1Mf8Ps8X/1qk71blKjH1em1bu17NWkepfd9e+uRvf9fGFSu1/pNsJd1yg6659zeSpP934eUyFUYX3nidEm8erhfT7vEtvFdx9KhvDoB0bHG+YQ//vsbr39+vW3K4WCFhoSrYss3vUQZV7dy4SdEdz5Uk/ev5F9WkZQttW7NeA4Zfq6j27bQ3b4eaRUepSYvm2v3dVh3cs1dlpaXqfEF/bVi6XO16/0z/mDKtxvkyk68YpqKCvbolfaJ6X3FZje//Q3Z9+512fvOtel5+SZ32r6uF//OKLh9zW70ec+3HS/WPKU/KE9lC4+bPrtMxapob0FBWvP3PBn9i+YLH0nXV7+6Wp2WLBn3fM8nqRR/rZ5cNqvfjzn/4D1r5Xma9HjOQ/tuc6cXj8RhjjPF4PNbbQjlzijs4yPdz09atrLenauly0QVm6IPjTf+hKSa8SUSNdVxud7Vt5/TuacI9TU56zM4Xnn/K9+10QX8z4YM3TId+fXzbul+SZGK7dq723q3PPcfXtrDGjY0kExQSUqfP63K7TYu2bfy2xXbtbC6/a5QJCg4+6X7hTSJMdKcOxtOyhXEHB5kWbduYRz9820xbnW3CGjc2tz/zZ/OnlR+b4Y/cbzr062Paduty0uOljEsz01Zn+/b1teO8zia647n+7XW5/P4MDR79S/PQe/9rOl/Yv/qfs6Agv9/7XZtipq3ONpfdOdLEntfZXPfQPcbTsoVp5PX61Rs4ItVMW51tojt1MI28XhP/s26+c36y7zm0Ubhp262rGfX0Eybp1htMeJMI43K5fJ9n7JwZZtrqbPPQPxeYHoMHVdvf07KFueLXd5pm0a1NRLOmJq5HN7/P3CdliOl3bYrpe/UQ06ZLJ9/30MjrNdfc82vTukN7ExQcbPpfd7X5+X2/NcMffcBcdPNw8/u/zzUTPnjDXHbHSONyu02Xiy7w/TkNCQ8zwWFhJqp9OxPdqYMJCQ8zXZMuMMMfud94o078vbzyt2PMxEXv+s7RuL/N9v3csm2siWjezFc3sl2cCW0Ubq659zdm2upsc8ltI8zv/vqiGfSrW0755zD2vM7m/rfnm0ey3jLN20SbmM4dzcARqebOmU+Z0c8/aVqdE+97z/Gvv2xiu3Y2A4Zfazr072umrc42j3/8gel79RAz6pk/md5XXOb73lu2jTWR7eLMbdOfMNNWZ5u+11xhzun1M9+xHs78u2nZNtZcMfYOc+fMp0xkfFvTuOmJPw+V3+8Tn33k+3vibRVp+l49xFxz729My7i21T5L6qQJZtrqbNOibRvTuKnX3DJlopm2Ott0vyTJr965/fqYxxb/0/S7NsX3nt6oVub+t14z01Znm55DBpuQ8DDfZwkJDzOjnvmTSf/036bVOfHm/Ouu
MV0SB5iWbWNNSHhYvfz/z+/PZC37b0ZGAJxRgkJCFBwaopJDh32/V73MeCoRzZuddH2W+tRQ7/NTFRIeJlNh1P+6q7Vh6XLtzdtx0rphjRtXG439MVxut8KbNFFxYeGPPlbbbl20b/tOv/Vz4I/LNAAAwKra9t+1W1ITAADgNCGMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArAq23YBAeDwe200AAAC1VNt++6wII5UfJi8vz3JLAABAoDwej4qKik76ukuSabjm1F2bNm1O+UHqwuPxKC8vT7GxsfV+bASGc3Fm4DycOTgXZwbOw4/n8Xi0ffv2U9Y5K0ZGJP3gB/kxioqK+EN2huBcnBk4D2cOzsWZgfNQd7X53pjACgAArCKMAAAAqxwdRkpKSjRp0iSVlJTYborjcS7ODJyHMwfn4szAeWgYZ80EVgAA8NPk6JERAABgH2EEAABYRRgBAABWEUYAAIBVjg4jaWlp2rRpk4qLi5WTk6OkpCTbTTqrDRw4UO+++67y8vJkjNHQoUOr1XnssceUl5enw4cPa/HixerWrZvf66GhoXrmmWe0e/duHTx4UO+8845iY2P96jRr1kxz587V/v37tX//fs2dO1dNmzY9rZ/tbPLQQw9pxYoVKiwsVH5+vt566y117ty5Wj3Oxel1991368svv9SBAwd04MABLVu2TFdeeaVfHc5Bw3vooYdkjNFTTz3lt51zYZ9xYklNTTUlJSVm9OjRpmvXruapp54yRUVFJi4uznrbztZy5ZVXmj/+8Y9m2LBhxhhjhg4d6vf6Aw88YA4cOGCGDRtmunfvbv72t7+ZvLw806RJE1+dGTNmmK1bt5rLLrvM9O7d2yxatMh8/vnnxu12++p88MEHZtWqVeaCCy4wF1xwgVm1apV59913rX/+M6VkZmaakSNHmm7dupmePXua9957z2zevNk0btyYc9GA5ZprrjFXXXWV6dSpk+nUqZOZPHmyKSkpMd26deMcWCr9+vUzmzZtMl988YV56qmnfNs5F2dEsd4AK2X58uVmxowZftvWrl1r0tPTrbftp1BqCiPbt283DzzwgO/30NBQs2/fPjNmzBgjyXi9XlNSUmJSU1N9dWJiYkx5ebkZMmSIkWS6du1qjDHm/PPP99UZMGCAMcaYzp07W//cZ2KJjIw0xhgzcOBAzoXlsmfPHnP77bdzDiyUiIgIs2HDBnPZZZeZxYsX+4URzoX94sjLNCEhIUpISFBWVpbf9qysLCUmJlpq1U9b+/btFRMT4/edl5aW6uOPP/Z95wkJCQoNDfWrs2PHDuXm5vrqXHjhhdq/f79WrFjhq/Ppp59q//79nLuTqBwm3rt3ryTOhQ1ut1s33XSTIiIilJ2dzTmw4Pnnn9f777+vRYsW+W3nXJwZzpoH5dWnyMhIBQcHKz8/3297fn6+oqOjLbXqp63ye63pO2/Xrp2vTklJifbv31+tTuX+0dHR2rVrV7Xj79q1i3N3EhkZGVqyZInWrFkjiXPRkHr06KHs7GyFh4fr4MGDGjZsmNatW6cLL7xQEuegodx0003q27ev+vfvX+01/j6cGRwZRioZY/x+d7lc1bahftXlO/9+nZrqc+5q9txzz6lnz541Ts7mXJx+GzZsUO/evdWsWTMNHz5cr776qgYNGuR7nXNw+rVt21ZPP/20hgwZcsol3TkXdjnyMk1BQYHKy8urpdWoqKhq6Rj1Y+fOnZJ0yu98586dCgsLU7NmzU5Zp3Xr1tWO36pVK87d9zzzzDO69tprdemllyovL8+3nXPRcMrKyvTNN99o5cqVevjhh/Xll19q3LhxnIMGlJCQoNatW2vlypUqKytTWVmZLrnkEv3ud79TWVmZ73viXNhnfeKKjbJ8+XLz/PPP+21bs2YNE1jrqZxsAuv999/v+z0kJKTGSWI33nijr050
dHSNk8T69+/vq3P++eczSex75dlnnzXbtm0zHTt2rPF1zoWdsnDhQjNnzhzOQQOWJk2amO7du/uVFStWmLlz55ru3btzLs6cYr0BVkrlrb2jRo0yXbt2NRkZGaaoqMjEx8dbb9vZWiIiIkyvXr1Mr169jDHGjB8/3vTq1ct3u/QDDzxg9u3bZ6677jrTvXt389prr9V4+9yWLVvM4MGDTe/evc3ChQtrvH3uiy++MAMGDDADBgwwX375JbfPVSnPP/+82bdvn7n44otN69atfSU8PNxXh3Nx+suUKVNMUlKSadeunenRo4eZPHmyKS8vN5dffjnnwHL5/t00nIszolhvgLWSlpZmvv32W3PkyBGTk5Pjd+sjJfAyaNAgU5PKfwlKMo899pjZvn27KS4uNh999JHvXyaVJSwszDzzzDOmoKDAHDp0yLz77rumbdu2fnWaN29u5s2bZw4cOGAOHDhg5s2bZ5o2bWr9858p5WRGjhzpV49zcXrLSy+95Pv/S35+vvnwww99QYRzYLd8P4xwLuwX1/EfAAAArHDkBFYAAHDmIIwAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACw6v8DQCb+Pj3j1CQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MOEMLP(\n",
       "          (mlps): ModuleList(\n",
       "            (0-2): 3 x GPT2MLP(\n",
       "              (c_fc): Conv1D()\n",
       "              (c_proj): Conv1D()\n",
       "              (act): NewGELUActivation()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bare expression: echoes the model so the notebook renders its module tree.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# autoreload bug"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext autoreload\n",
    "import transformers\n",
    "from torch import nn\n",
    "import torch\n",
    "from be_great.multihead_models import MOEModelForCausalLM\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import AutoTokenizer\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import LinearLR\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm \n",
    "\n",
    "# Display the value of every top-level expression in a cell, not only the last one.\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "InteractiveShell.ast_node_interactivity = \"all\"\n",
    "\n",
    "# Pretrained distilgpt2 weights and the matching tokenizer.\n",
    "dgpt2 = transformers.AutoModelForCausalLM.from_pretrained('distilgpt2')\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
    "# NOTE(review): pad_token is assigned from eos_token *before* eos_token is replaced by\n",
    "# '<EOS>' below, so padding presumably keeps the tokenizer's original EOS string - confirm this is intended.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "# Register the custom sequence markers; '<EOS>' is what row_to_col_sentences appends to each column.\n",
    "special_tokens_dict = {\"bos_token\": \"<BOS>\", 'eos_token': '<EOS>'}\n",
    "# add_special_tokens returns the number of tokens added; embeddings must be resized to len(tokenizer) afterwards.\n",
    "num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)\n",
    "\n",
    "# Row -> text helpers: serialise one table row into natural-language text.\n",
    "def row_to_string(row):\n",
    "    \"\"\"Render a row as one sentence of comma-separated 'col is val' clauses ending in a period.\n",
    "\n",
    "    `row` only needs an `.items()` method yielding (column, value) pairs.\n",
    "    \"\"\"\n",
    "    clauses = []\n",
    "    for column, value in row.items():\n",
    "        clauses.append(f\"{column} is {value}\")\n",
    "    return \", \".join(clauses) + \".\"\n",
    "def row_to_sentences(row):\n",
    "    \"\"\"Join 'col is val' clauses with '. ' (no trailing period).\n",
    "\n",
    "    Column names and values are stringified and stripped of surrounding\n",
    "    whitespace; `row` must expose `.index` and `.values`.\n",
    "    \"\"\"\n",
    "    clauses = []\n",
    "    for column, value in zip(row.index, row.values):\n",
    "        clauses.append(f\"{str(column).strip()} is {str(value).strip()}\")\n",
    "    return '. '.join(clauses)\n",
    "def row_to_col_sentences(row):\n",
    "    \"\"\"Return one 'col is val.<EOS>' string per column of `row`.\n",
    "\n",
    "    Column names and values are stringified and stripped of surrounding\n",
    "    whitespace; `row` must expose `.index` and `.values`.\n",
    "    \"\"\"\n",
    "    sentences = []\n",
    "    for column, value in zip(row.index, row.values):\n",
    "        sentences.append(f\"{str(column).strip()} is {str(value).strip()}.<EOS>\")\n",
    "    return sentences\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    \"\"\"Tokenises pre-rendered table rows for language-model training.\n",
    "\n",
    "    Args:\n",
    "        texts: one entry per row, each a list of per-column sentences\n",
    "            (e.g. the lists returned by `row_to_col_sentences`).\n",
    "        tokenizer: callable tokenizer exposing `bos_token` and `bos_token_id`.\n",
    "        cols: None to use every column, else a list of desired column names.\n",
    "        max_col_length: `max_length` passed to the tokenizer for each column\n",
    "            sentence in MoE format.\n",
    "        do_moe_format: True  -> items are (BOS prompt, per-column token ids);\n",
    "            False -> items are (token ids, attention mask) of the whole row.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, texts, tokenizer, cols=None, max_col_length=10, do_moe_format=True):\n",
    "        self.texts = texts\n",
    "        self.tokenizer = tokenizer\n",
    "        self.cols = cols # \"None\" for all cols, else a list of desired cols' names\n",
    "        self.max_col_length = max_col_length\n",
    "        self.do_moe_format = do_moe_format\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.texts)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.cols is None:\n",
    "            # Use the sentences handed to the constructor rather than re-deriving\n",
    "            # them from the notebook-global `data` frame on every access.\n",
    "            text = list(self.texts[idx])\n",
    "        else:\n",
    "            # NOTE(review): column subsetting still reads the notebook-global `data`,\n",
    "            # because `texts` carries no column names to filter on.\n",
    "            text = row_to_col_sentences(data[self.cols].iloc[idx]) # ['age is 39.<EOS>', 'workclass is State-gov.<EOS>', ...]\n",
    "        if self.do_moe_format:\n",
    "            tokenized_text = self.tokenizer(text, truncation=True, max_length=self.max_col_length, padding='max_length', return_tensors=\"pt\")\n",
    "            # Length-1 tensor holding the BOS id; the DataLoader adds the batch dimension.\n",
    "            prompt = torch.full((1,), self.tokenizer.bos_token_id)\n",
    "            return prompt, tokenized_text.input_ids.squeeze()\n",
    "        # Plain-LM format: one BOS-prefixed string for the whole row.\n",
    "        # Uses self.tokenizer (not the notebook-global `tokenizer`) so the dataset\n",
    "        # always tokenises with the tokenizer it was constructed with.\n",
    "        text = self.tokenizer.bos_token + ''.join(text)\n",
    "        tokenized_text = self.tokenizer(text, truncation=True, padding='longest', return_tensors='pt')\n",
    "        return tokenized_text.input_ids.squeeze(), tokenized_text.attention_mask.squeeze()\n",
    "            \n",
    "            \n",
    "# Load the table and pre-render every row as a list of per-column sentences.\n",
    "file_path = 'adult.csv'  # relative path, resolved against the working directory; update if the CSV lives elsewhere\n",
    "data = pd.read_csv(file_path)\n",
    "# text_data[i] is the list row_to_col_sentences returns for row i of `data`.\n",
    "text_data = data.apply(row_to_col_sentences, axis=1).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Fine-tune the multihead MOEModelForCausalLM wrapper around distilgpt2.\n",
    "num_experts = 15\n",
    "num_epochs = 1\n",
    "ckpt_every = 500  # save a checkpoint and refresh the loss plot every N optimisation steps\n",
    "ckpt_dir = './ckpts/moe/dgpt2/adult-1col'\n",
    "os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create missing directories\n",
    "\n",
    "dgpt2.resize_token_embeddings(len(tokenizer))\n",
    "dgpt2copy = MOEModelForCausalLM(dgpt2, num_experts=num_experts, multihead=True)\n",
    "model = dgpt2copy # don't forget to change tokenizer name and optimizer too\n",
    "\n",
    "model.train()\n",
    "model.set_train_mode()\n",
    "# Move the model to the device (GPU if available)\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "dataset = TextDataset(text_data, tokenizer)\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "# Set up the optimizer (no learning-rate scheduler is used in this run)\n",
    "optimizer = AdamW(model.parameters(), lr=5e-6)\n",
    "\n",
    "lossesmoe = []  # one scalar loss per optimisation step\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        prompt, labels = batch\n",
    "        prompt = prompt.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        outputs = model.multicol_forward(input_ids=prompt, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        lossesmoe.append(loss.item())\n",
    "        if len(lossesmoe) % ckpt_every == 0:\n",
    "            torch.save(model.state_dict(), os.path.join(ckpt_dir, f'debug{len(lossesmoe)}.pt'))\n",
    "            # Close the previous figure before redrawing so figures do not pile up.\n",
    "            plt.close()\n",
    "            plt.plot(lossesmoe)\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPT2MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPT2MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "Embedding(50259, 768)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 48842/48842 [16:34<00:00, 49.11it/s]\n",
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n",
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/transformers/generation/utils.py:1132: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<|endoftext|>age is 24.<EOS>workclass is Private.<EOS>fnlwgt is 171099']\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f530031ba10>]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGdCAYAAAAIbpn/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgcElEQVR4nO3de3SU1f3v8c+EZLhlAhwxFwLSVEEMFBAItwOmQkO17ZLS/gotR6UXm4orthysKYm1ES+l1XIxCj979acUbF3tomK1bZAD1BouBZWrtVmCgpNkBElIQmIGdJ8/WsYMCdc+k8k3vl9r7bWYZzbPbDa2efNkJo9PkhMAAIBBCfFeAAAAwMUiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGBWYrwXcLH69eun+vr6eC8DAABcgEAgoMrKSs/OZzJk+vXrp2AwGO9lAACAi5CZmelZzJgMmVNXYjIzM7kqAwCAEYFAQMFg0NOv3SZD5pT6+npCBgCAjzDe7AsAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEypxlzw2c0aHxOvJcBAADOg+m7X3stfdDl+soDd0uS7vjEhDivBgAAnMsFX5GZPHmy1q5dq2AwKOecpk+f3mpOSUmJgsGgGhsbtWHDBmVnZ0c97/f7VVpaqsOHD6uhoUHPPPOMMjMzL/5P4ZHeaZfGewkAAOACXHDI9OzZUzt37lRBQUGbzxcWFmr+/PkqKChQTk6OqqurtW7dOiUnJ0fmLFu2TDNmzNCXv/xlTZo0ScnJyfrjH/+ohAS+0wUAAC6Mu9jhnHPTp0+POlZZWekKCwsjj/1+v6upqXH5+flOkktJSXHNzc1u5syZkTkZGRnu5MmTbtq0aef1uoFAwDnnXCAQuOi1tzWGTBrvFu/e7Bbv3uzpeRkMBoPBYMTm67enl0CysrKUkZGhsrKyyLFwOKxNmzZp4sSJkqTRo0fL7/dHzamqqtKePXsic07n9/sVCASiBgAAgKchk56eLkkKhUJRx0OhUOS59PR0NTc3q7a29oxzTldUVKS6urrICAaDXi4bAAAYFZM3pTjnoh77fL5Wx053tjmLFi1SSkpKZHSENwYDAID48zRkqqurJanVlZXU1NTIVZrq6mp17dpVvXv3PuOc04XDYdXX10cNAAAAT0PmwIEDqqqqUl5eXuRYUlKScnNzVV5eLknasWOHwuFw1Jz09HQNGzYsMgcAAOB8XPAPxOvZs6euuOKKyOOsrCyNGDFCR48e1aFDh7Rs2TIVFxeroqJCFRUVKi4uVmNjo1avXi1Jqqur0y9/+UstXrxY7777ro4ePaqf/OQn2r17t1544QXv/mQAAOAj4YI+5pSbm+va8vjjj0fmlJSUuMrKStfU1OQ2btzohg4dGnWOrl27utLSUnfkyBF3/Phxt3btWte/f/+4fnxL4uPXDAaDwWDEcsTi67fv378wJRAIqK6uTikpKZ6+X2bIpPH65n8vlcQtCgAA8Fosvn7zo3QBAIBZhAwAADCLkGnhHD/qBgAAdDCEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCJmWuP01AACmEDIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYh04LjppEAAJhCyAAAALMIGQAAYBYhAwAAzCJkAACAWYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYRMFG4aCQCAJYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACA
WYQMAAAwi5ABAABmETIAAMAsQgYAAJjlech06dJF9913n/bv36/Gxka98cYbuvvuu+Xz+aLmlZSUKBgMqrGxURs2bFB2drbXSwEAAJ2c5yHzve99T7feeqsKCgp01VVXqbCwUHfeeaduv/32yJzCwkLNnz9fBQUFysnJUXV1tdatW6fk5GSvl3NBHDe/BgDAFM9DZsKECXrmmWf0/PPP66233tLvf/97lZWVacyYMZE58+bN0wMPPKA1a9Zo7969mjNnjnr06KHZs2d7vRwAANCJeR4yf/vb3zR16lQNGjRIkjR8+HBNmjRJzz//vCQpKytLGRkZKisri/yecDisTZs2aeLEiV4vBwAAdGKJXp/wxz/+sXr16qV//OMfev/999WlSxfddddd+s1vfiNJSk9PlySFQqGo3xcKhTRw4MA2z+n3+9W1a9fI40Ag4PWyAQCAQZ5fkZk1a5ZuvPFGzZ49W6NGjdKcOXP03e9+VzfffHPUPHfaG1J8Pl+rY6cUFRWprq4uMoLBoNfLBgAABnkeMg899JB+9KMf6be//a327NmjX//611q6dKmKiookSdXV1ZI+vDJzSmpqaqurNKcsWrRIKSkpkZGZmen1sgEAgEGeh0yPHj30wQcfRB17//33lZDwr5c6cOCAqqqqlJeXF3k+KSlJubm5Ki8vb/Oc4XBY9fX1UQMAAMDz98g8++yzuuuuu3Tw4EHt3btXV199tebPn69f/epXkTnLli1TcXGxKioqVFFRoeLiYjU2Nmr16tVeLwcAAHRinofM7bffrvvuu08rVqxQamqqKisr9dOf/lT33ntvZM6DDz6o7t27a8WKFerTp4+2bt2qadOmqaGhwevlAACATswnydyPgQsEAqqrq1NKSoqn32YaND5Ht/68VJJ0xycmeHZeAAAQm6/f3GsJAACYRcgAAACzCBkAAGAWIdMSd40EAMAUQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYRMC467XwMAYAohAwAAzCJkAACAWYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYQMAAAwi5ABAABmETItcNNIAABsIWQAAIBZhAwAADCLkAEAAGYRMgAAwCxCBgAAmEXIAAAAswgZAABgFiEDAADMImQAAIBZhEwLPp8v3ksAAAAXgJABAABmETIAAMAsQqYFbhoJAIAthAwAADCLkAEAAGYRMgAAwCxCBgAAmEXIAAAAswgZAABgFiEDAADMImQAAIBZhAwAADCLkAEAAGYRMgAAwCxCBgAAmBWTkOnXr59WrlypI0eO6Pjx43rllVc0atSoqDklJSUKBoNqbGzUhg0blJ2dHYulAACATszzkOndu7deeuklnThxQtdff72ys7N1xx13qLa2NjKnsLBQ8+fPV0FBgXJyclRdXa1169YpOTnZ6+VcGO5+DQCAOc7LsWjRIvfXv/71rHMqKytdYWFh5LHf73c1NTUuPz//vF4jEAg455wLBAKerv3yMVe7xbs3u8W7N3t6XgaDwWAwGLH5+u35FZkbbrhB27dv19NPP61QKKSXX35Zt9xyS+T5rKwsZWRkqKysLHIsHA5r06ZNmjhxYpvn9Pv9CgQCUQMAAMDzkPn4xz+uuXPnqqKiQp/+9Kf12GOPqbS0VDfddJMkKT09XZIUCoWifl8oFIo8d7qioiLV1dVFRjAY9HrZAADAIM9DJiEhQS+//LLuuusuvfrqq/rZz36mn//855o7d27UPHfa+1F8Pl+rY6csWrRIKSkpkZGZmen1sgEAgEGeh0xVVZX27dsXdey1117TZZddJkmqrq6WpFZXX1JTU1tdpTklHA6rvr4+agAAAHgeMi+99JKuvPLKqGODBw/WW2+9JUk6cOCAqqqqlJeXF3k+KSlJubm5Ki8v93o5
AACgE0v0+oRLly5VeXm5ioqK9PTTT2vs2LHKz89Xfn5+ZM6yZctUXFysiooKVVRUqLi4WI2NjVq9erXXywEAAJ2c5x+v+uxnP+t27drlmpqa3L59+9wtt9zSak5JSYmrrKx0TU1NbuPGjW7o0KFx/fiWxMevGQwGg8GI5YjF12/fv39hSiAQUF1dnVJSUjx9v8zlY67WbY+vkCTd8YkJnp0XAADE5us391oCAABmETIAAMAsQgYAAJhFyLRg7s1CAAB8xBEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIdOS47aRAABYQsgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQacFx92sAAEwhZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEKmJW4aCQCAKYQMAAAwi5ABAABmETIAAMAsQgYAAJhFyAAAALMIGQAAYBYhAwAAzCJkAACAWYQMAAAwi5ABAABmETIAAMCsmIfMggUL5JzT0qVLo46XlJQoGAyqsbFRGzZsUHZ2dqyXAgAAOpmYhsyYMWOUn5+vnTt3Rh0vLCzU/PnzVVBQoJycHFVXV2vdunVKTk6O5XLOiXtGAgBgS8xCpmfPnlq1apW++c1vqqamJuq5efPm6YEHHtCaNWu0d+9ezZkzRz169NDs2bNjtRwAANAJxSxkli9frueee07r16+POp6VlaWMjAyVlZVFjoXDYW3atEkTJ05s81x+v1+BQCBqAAAAJMbipLNmzdKoUaOUk5PT6rn09HRJUigUijoeCoU0cODANs9XVFSke+65x/N1AgAA2zy/ItO/f389/PDDuvHGG9Xc3HzGee60N6T4fL5Wx05ZtGiRUlJSIiMzM9PTNQMAAJs8vyIzevRopaWlaceOHR++SGKirrnmGhUUFOjKK6+U9K8rM9XV1ZE5qampra7SnBIOhxUOh71eKgAAMM7zKzLr16/XsGHDNHLkyMj4+9//rlWrVmnkyJHav3+/qqqqlJeXF/k9SUlJys3NVXl5udfLAQAAnZjnV2QaGhq0d+/eqGPHjx/Xu+++Gzm+bNkyFRcXq6KiQhUVFSouLlZjY6NWr17t9XIAAEAnFpM3+57Lgw8+qO7du2vFihXq06ePtm7dqmnTpqmhoSEeywEAAEa1S8hce+21rY4tXLhQCxcubI+XBwAAnRT3WgIAAGYRMgAAwCxCBgAAmEXIAAAAswiZlrj9NQAAphAyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQOQOfzxfvJQAAgHMgZAAAgFmEDAAAMIuQAQAAZhEyAADALEKmBSduGgkAgCWEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZFpoqquP9xIAAMAFIGRaaDha8+ED7n4NAECHR8gAAACzCBkAAGAWIQMAAMwiZAAAgFmETAuOm18DAGAKIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIXMGPu5+DQBAh0fIAAAAswgZAABgFiEDAADMImSicNdIAAAsIWQAAIBZhAwAADCLkAEAAGYRMgAAwCxCBgAAmEXIAAAAswgZAABgFiEDAADMImQAAIBZhMyZcPNrAAA6PM9DZsGCBdq2bZvq6uoUCoW0Zs0aDR48uNW8kpISBYNBNTY2asOGDcrOzvZ6KQAAoJPzPGRyc3O1fPlyjR8/Xnl5eUpMTFRZWZl69OgRmVNYWKj58+eroKBAOTk5qq6u1rp165ScnOz1cgAAQCfnYjn69u3rnHNu8uTJkWOV
lZWusLAw8tjv97uamhqXn59/XucMBALOOecCgYCna+0WSHaLd292i3dvdgmJXWK6LwwGg8FgfNRGLL5+x/w9Mr169ZIkHT16VJKUlZWljIwMlZWVReaEw2Ft2rRJEydOjPVyAABAJ5IY6xdYsmSJXnzxRe3du1eSlJ6eLkkKhUJR80KhkAYOHNjmOfx+v7p27Rp5HAgEYrRaAABgSUyvyDz66KMaPny4vvKVr7R6zjkX9djn87U6dkpRUZHq6uoiIxgMxmS9AADAlpiFTGlpqW644QZde+21UeFRXV0t6cMrM6ekpqa2ukpzyqJFi5SSkhIZmZmZsVo2AAAwJCYh88gjj+gLX/iCpkyZojfffDPquQMHDqiqqkp5eXmRY0lJScrNzVV5eXmb5wuHw6qvr48aAAAAnr9HZvny5Zo9e7amT5+u+vp6paWlSZKOHTum9957T5K0bNkyFRcXq6KiQhUVFSouLlZjY6NWr17t9XIAAEAn5nnI3HbbbZKkTZs2RR3/6le/qieeeEKS9OCDD6p79+5asWKF+vTpo61bt2ratGlqaGjwejkAAKAT8zxkfL7z+9n+Cxcu1MKFC71+eQAA8BHCvZYAAIBZhAwAADCLkDkDH7e/BgCgwyNkAACAWYQMAAAwi5Bp6Qy3SAAAAB0TIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEzJn4uPs1AAAdHSEDAADMImRacNw0EgAAUwgZAABgFiEDAADMImQAAIBZhAwAADCLkAEAAGYRMgAAwCxCBgAAmEXIAAAAswgZAABgFiEDAADMImQAAIBZhMwZ+Lj7NQAAHR4h0xL3jAQAwBRCBgAAmEXIAAAAswgZAABgFiEDAADMImQAAIBZhAwAADCLkAEAAGYRMgAAwCxCBgAAmEXIAAAAswgZAABgFiEDAADMImRacO7Du0b6Erj7NQAAHR0h08L7J09Gft0lMTGOKwEAAOeDkGnhA0IGAABTCJkWnHP64P33JUkJhAwAAB0eIXOaU99e4ooMAAAdHyFzmg+vyHSJ80oAAMC5EDKn6dqjhyQpcMn/ivNKAADAuRAyZ5Az/bPxXgIAADgHQuYM/vG3LfFeAgAAOAdC5jS1oXckSQldYrs1w6dN0bzf/EqX9M+M6esAANCZETKn6Z2WKkm66aH7Yvo6cxY/oAFDr9LMe4tj+joAAHRmhMwZJHTpooInHlP3lBSlpF6qYVNy5fN5f9uCHikBz88JALDls/PmavL/mRnvZZgU15CZO3eu9u/fr6amJm3fvl2TJk2K53JayRo1Qve/9Bd975mn9LWHf6SxMz4nSerZp7cS/f6oub6EBGUMvvycH9vukpQUueojSf2uHKRrbvryGef7Es78V9Qr7VKN/6/pSvT7lej36/IxV0e9fqLfr8whg8+6ngtxai0d6WfsZF09XJlDBiuxa9d2f+0hkyco5dK+7f660r9C+4J/jwc/UuCyT2Qr7fKs//g8HYG/e3ddOXHcRe1L95SAuiX3jMGqOq/+2UP0v7/yX/FeRoeU9vGPaco3btbnF/zfeC/FJJ8kd85ZMTBz5kytXLlSt912m1566SV961vf0i233KLs7GwdOnTorL83EAiorq5OKSkpqq+v93Rdi3dv9vR8AM7s8FuHdOnAAWed01TfoO6B5HZa0UfP+ydOqkuSt/84qTvyrlL6XuLpOS/G8Zpa9ezTW5J04r1m1VaHdOnHLvuPznnglV3Kunr4Wed88MEHSjjLP0Ivxournm73Kzb3fuoGHQsd9vScsfj6HbeQ2bJli15++WXddtttkWP79u3TH/7wBxUXn/19I7EMmR9u/X/q2qO7p+cEAMCiO6+epA9Ovu/Z+WLx9Tsu31pKSkrS6NGjVVZWFnW8rKxMEydObDXf7/crEAhEjVgpHjclZucGAMASLyMmVuISMn379lViYqJCoVDU8VAopPT0
9Fbzi4qKVFdXFxnBYDCm67vjExNUPG6qnl38aExfBwCAjuqRm74V7yWcl7i+a9O56O9q+Xy+VsckadGiRVqyZEnkcSAQiHnMNDc2auP/rNLG/1kV09cBAAAXLy4hc+TIEZ08ebLV1ZfU1NRWV2kkKRwOKxwOt9fyAACAEXH51tKJEye0Y8cO5eXlRR3Py8tTeXl5PJYEAAAMitu3lpYsWaKVK1dq+/bt2rx5s/Lz83XZZZfpsccei9eSAACAMXELmaefflqXXHKJfvCDHygjI0N79uzRZz7zGR08eDBeSwIAAMbE7efI/Cdi+XNkAABAbHSanyMDAADgBUIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZsX17tf/qUAgEO8lAACA8xSLr9smQ+bURgSDwTivBAAAXKhAIODZT/Y1eYsCSerXr19Mbk8QCAQUDAaVmZnJ7Q/aAfvd/tjz9sV+ty/2u/1d6J4HAgFVVlZ69vomr8hI8nQT2lJfX8//CNoR+93+2PP2xX63L/a7/Z3vnnv998KbfQEAgFmEDAAAMIuQOU1zc7PuueceNTc3x3spHwnsd/tjz9sX+92+2O/2F+89N/tmXwAAAK7IAAAAswgZAABgFiEDAADMImQAAIBZhEwLc+fO1f79+9XU1KTt27dr0qRJ8V5ShzR58mStXbtWwWBQzjlNnz691ZySkhIFg0E1NjZqw4YNys7Ojnre7/ertLRUhw8fVkNDg5555hllZmZGzendu7eefPJJ1dbWqra2Vk8++aR69eoVNWfAgAFau3atGhoadPjwYT388MNKSkry/g8dJwsWLNC2bdtUV1enUCikNWvWaPDgwa3msd/eufXWW7Vz504dO3ZMx44dU3l5ua677rqoOex37CxYsEDOOS1dujTqOHvujZKSEjnnokZVVVWrOdb22jHkZs6c6Zqbm903vvENN2TIELd06VJXX1/vBgwYEPe1dbRx3XXXufvuu8/NmDHDOefc9OnTo54vLCx0x44dczNmzHBDhw51Tz31lAsGgy45OTkyZ8WKFe7QoUNu6tSpbuTIkW79+vXulVdecQkJCZE5zz//vNu1a5cbP368Gz9+vNu1a5dbu3Zt5PmEhAS3a9cut379ejdy5Eg3depU9/bbb7vS0tK475FX409/+pObM2eOy87OdsOHD3fPPvuse/PNN12PHj3Y7xiNz33uc+766693gwYNcoMGDXL333+/a25udtnZ2ex3jMeYMWPc/v373auvvuqWLl3Kf+MxGCUlJW737t0uLS0tMvr27Wt9r+O/sR1hbNmyxa1YsSLq2L59+9wPf/jDuK+tI4+2QqaystIVFhZGHvv9fldTU+Py8/OdJJeSkuKam5vdzJkzI3MyMjLcyZMn3bRp05wkN2TIEOecc2PHjo3MGTdunHPOucGDBzvpX0F18uRJl5GREZkza9Ys19TU5AKBQNz3Jhajb9++zjnnJk+ezH6343j33Xfd17/+dfY7hqNnz57u9ddfd1OnTnUbNmyIChn23LtRUlLiXnnllTM+b3Gv+daSpKSkJI0ePVplZWVRx8vKyjRx4sQ4rcqmrKwsZWRkRO1lOBzWpk2bIns5evRo+f3+qDlVVVXas2dPZM6ECRNUW1urbdu2ReZs3bpVtbW1UXP27NkTdVn0L3/5i7p166bRo0fH9M8ZL6cuzR49elQS+x1rCQkJmjVrlnr27KnNmzez3zG0fPlyPffcc1q/fn3Ucfbce4MGDVIwGNT+/fv11FNPKSsrS5LdvTZ700gv9e3bV4mJiQqFQlHHQ6GQ0tPT47Qqm07tV1t7OXDgwMic5uZm1dbWtppz6venp6frnXfeaXX+d955J2rO6a9TW1ur5ubmTvv3tmTJEr344ovau3evJPY7VoYNG6bNmzerW7duamho0IwZM/Taa69pwoQJkthvr82aNUujRo1STk5Oq+f4b9xbW7du1c0336x//vOfSktL
0/e//32Vl5dr6NChZveakGnBORf12OfztTqG83Mxe3n6nLbmX8yczuLRRx/V8OHD23wTOvvtrddff10jR45U79699cUvflFPPPGEcnNzI8+z397p37+/Hn74YU2bNu2sP+KePffGn//858iv9+zZo82bN+uNN97QnDlztGXLFkn29ppvLUk6cuSITp482aoCU1NTWxUjzq66ulqSzrqX1dXV6tq1q3r37n3WOWlpaa3Of+mll0bNOf11evfuLb/f3+n+3kpLS3XDDTfo2muvVTAYjBxnv2PjxIkTeuONN7Rjxw4VFxdr586d+s53vsN+x8Do0aOVlpamHTt26MSJEzpx4oQ++clP6tvf/rZOnDgR+bOy57HR2Nio3bt3a9CgQab/+477m486wtiyZYtbvnx51LG9e/fyZt9zjDO92ffOO++MPE5KSmrzzWJf+tKXInPS09PbfLNYTk5OZM7YsWPbfLNYenp6ZM7MmTM71RvzJLlHHnnEvf322+6KK65o83n2O/bjhRdecI8//jj7HYORnJzshg4dGjW2bdvmnnzySTd06FD2PMbD7/e7Q4cOubvvvtvyXsd/IzvCOPXx66997WtuyJAhbsmSJa6+vt5ddtllcV9bRxs9e/Z0I0aMcCNGjHDOOTdv3jw3YsSIyEfVCwsLXU1Njfv85z/vhg4d6latWtXmx/cOHjzopkyZ4kaOHOleeOGFNj++9+qrr7px48a5cePGuZ07d7b58b1169a5kSNHuilTpriDBw92qo9KLl++3NXU1Lhrrrkm6uOS3bp1i8xhv70dDzzwgJs0aZIbOHCgGzZsmLv//vvdyZMn3ac+9Sn2u53G6Z9aYs+9Gw899JC75ppr3Mc+9jE3duxYt3btWnfs2LHI1zqjex3/je0oY+7cue7AgQPuvffec9u3b4/6iCvjw5Gbm+vacupfrNK/PuJXWVnpmpqa3MaNGyP/sjo1unbt6kpLS92RI0fc8ePH3dq1a13//v2j5vTp08etXLnSHTt2zB07dsytXLnS9erVK2rOgAED3LPPPuuOHz/ujhw54kpLS53f74/7Hnk1zmTOnDlR89hv78YvfvGLyP8PhEIht27dukjEsN/tM04PGfbcu3Hq58I0Nze7t99+2/3ud79zV111lem99v37FwAAAObwZl8AAGAWIQMAAMwiZAAAgFmEDAAAMIuQAQAAZhEyAADALEIGAACYRcgAAACzCBkAAGAWIQMAAMwiZAAAgFmEDAAAMOv/A1M4mBpD3zZtAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fine-tune plain distilgpt2 on whole-row strings (non-MoE baseline).\n",
    "model = dgpt2 # don't forget to change tokenizer name and optimizer too\n",
    "model.train()\n",
    "\n",
    "# Move the model to the device (GPU if available)\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "dataset = TextDataset(text_data, tokenizer, None, do_moe_format=False)\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "# Set up the optimizer (the LinearLR scheduler is imported but not used in this run)\n",
    "optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "\n",
    "num_epochs = 1\n",
    "losses = []  # one scalar loss per optimisation step\n",
    "for epoch in range(num_epochs):\n",
    "    for batch in tqdm(dataloader):\n",
    "        input_ids, attention_mask = batch\n",
    "        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)\n",
    "\n",
    "        # Causal-LM objective: the inputs double as the labels.\n",
    "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)\n",
    "        loss = outputs.loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        losses.append(loss.item())\n",
    "\n",
    "# Switch to eval mode before sampling so dropout does not perturb the generated text.\n",
    "# Assigned to `_` to keep the model repr out of the cell output.\n",
    "_ = model.eval()\n",
    "print(tokenizer.batch_decode(model.generate()))\n",
    "\n",
    "# `plt` comes from the setup cell's imports; no re-import needed here.\n",
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(50259, 768)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50259, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MultiLayer(\n",
       "          (layers): ModuleList(\n",
       "            (0-14): 15 x GPT2MLP(\n",
       "              (c_fc): Conv1D()\n",
       "              (c_proj): Conv1D()\n",
       "              (act): NewGELUActivation()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): MultiLayer(\n",
       "    (layers): ModuleList(\n",
       "      (0-14): 15 x Linear(in_features=768, out_features=50259, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50259, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MultiLayer(\n",
       "          (layers): ModuleList(\n",
       "            (0-14): 15 x GPT2MLP(\n",
       "              (c_fc): Conv1D()\n",
       "              (c_proj): Conv1D()\n",
       "              (act): NewGELUActivation()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): MultiLayer(\n",
       "    (layers): ModuleList(\n",
       "      (0-14): 15 x Linear(in_features=768, out_features=50259, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the multihead MoE wrapper and restore trained weights for generation.\n",
    "num_experts = 15\n",
    "dgpt2.resize_token_embeddings(len(tokenizer))\n",
    "dgpt2copy = MOEModelForCausalLM(dgpt2, num_experts=num_experts, multihead=True)\n",
    "model = dgpt2copy # don't forget to change tokenizer name and optimizer too\n",
    "# NOTE(review): machine-specific absolute path; point this at your own checkpoint.\n",
    "ckpt_path = '/hdd3/sonia/be_great/ckpts/adult-tiny/plain/moe/8-19(18:37)/3000.pt'\n",
    "# Pick the device explicitly and remap the checkpoint onto it, so the cell also\n",
    "# works on CPU-only machines or when the weights were saved from a different GPU.\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.load_state_dict(torch.load(ckpt_path, map_location=device))\n",
    "model.to(device)\n",
    "model.eval()\n",
    "column_names_tokens = tokenizer(list(data.columns)).input_ids\n",
    "# Reuse num_experts instead of a second hard-coded 15 so the two cannot drift apart.\n",
    "model.set_generation_mode(token_heads=list(range(num_experts)), column_names_tokens=column_names_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'age is 39.<EOS>workclass is Private.<EOS>fnlwgt.<EOS>education is Some-college.<EOS>education-num is 14.<EOS>marital-status is Never-married.<EOS>occupation is Craft-repair.<EOS>relationship is Husband.<EOS>race is White.<EOS>sex is Male.<EOS>capital-gain-capital is 0.<EOS>capital-loss-capital is 0.<EOS>hours-per-week is 40.<EOS>native-country is United-States.<EOS>income is <=50K.<EOS>'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample one sequence, slice off the leading BOS token, and show the decoded text.\n",
    "generated_ids = model.generate(\n",
    "    do_sample=True,\n",
    "    num_beams=1,\n",
    "    max_length=200,\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    ")\n",
    "toks = generated_ids[..., 1:]\n",
    "tokenizer.batch_decode(toks)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'<|endoftext|>age is 39.<EOS>workclass is Private.<EOS>fnlwgt is 166824.<EOS>education is HS-grad.<EOS>education-num is 9.<EOS>marital-status is Married-civ-spouse.<EOS>occupation is Craft-repair.<EOS>relationship is Husband.<EOS>race is White.<EOS>sex is Male.<EOS>capital-gain is 0.<EOS>capital-loss is 0.<EOS>hours-per-week is 40.<EOS>native-country is United-States.<EOS>income is >50K.<EOS>income is <=50K.<EOS>income is <=50K.<EOS>income is'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Generate with only max_length set (130 tokens) and show the first decoded sequence.\n",
    "default_ids = model.generate(max_length=130)\n",
    "tokenizer.batch_decode(default_ids)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>education-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>sex</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>State-gov</td>\n",
       "      <td>77516</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>2174</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>50</td>\n",
       "      <td>Self-emp-not-inc</td>\n",
       "      <td>83311</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>215646</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>53</td>\n",
       "      <td>Private</td>\n",
       "      <td>234721</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Handlers-cleaners</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28</td>\n",
       "      <td>Private</td>\n",
       "      <td>338409</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Wife</td>\n",
       "      <td>Black</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>Cuba</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48837</th>\n",
       "      <td>39</td>\n",
       "      <td>Private</td>\n",
       "      <td>215419</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Not-in-family</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>36</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48838</th>\n",
       "      <td>64</td>\n",
       "      <td>?</td>\n",
       "      <td>321403</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Widowed</td>\n",
       "      <td>?</td>\n",
       "      <td>Other-relative</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48839</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>374983</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Prof-specialty</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48840</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>83891</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Divorced</td>\n",
       "      <td>Adm-clerical</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Asian-Pac-Islander</td>\n",
       "      <td>Male</td>\n",
       "      <td>5455</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48841</th>\n",
       "      <td>35</td>\n",
       "      <td>Self-emp-inc</td>\n",
       "      <td>182148</td>\n",
       "      <td>Bachelors</td>\n",
       "      <td>13</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Exec-managerial</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>48842 rows × 15 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       age          workclass  fnlwgt   education  education-num  \\\n",
       "0       39          State-gov   77516   Bachelors             13   \n",
       "1       50   Self-emp-not-inc   83311   Bachelors             13   \n",
       "2       38            Private  215646     HS-grad              9   \n",
       "3       53            Private  234721        11th              7   \n",
       "4       28            Private  338409   Bachelors             13   \n",
       "...    ...                ...     ...         ...            ...   \n",
       "48837   39            Private  215419   Bachelors             13   \n",
       "48838   64                  ?  321403     HS-grad              9   \n",
       "48839   38            Private  374983   Bachelors             13   \n",
       "48840   44            Private   83891   Bachelors             13   \n",
       "48841   35       Self-emp-inc  182148   Bachelors             13   \n",
       "\n",
       "            marital-status          occupation     relationship  \\\n",
       "0            Never-married        Adm-clerical    Not-in-family   \n",
       "1       Married-civ-spouse     Exec-managerial          Husband   \n",
       "2                 Divorced   Handlers-cleaners    Not-in-family   \n",
       "3       Married-civ-spouse   Handlers-cleaners          Husband   \n",
       "4       Married-civ-spouse      Prof-specialty             Wife   \n",
       "...                    ...                 ...              ...   \n",
       "48837             Divorced      Prof-specialty    Not-in-family   \n",
       "48838              Widowed                   ?   Other-relative   \n",
       "48839   Married-civ-spouse      Prof-specialty          Husband   \n",
       "48840             Divorced        Adm-clerical        Own-child   \n",
       "48841   Married-civ-spouse     Exec-managerial          Husband   \n",
       "\n",
       "                      race      sex  capital-gain  capital-loss  \\\n",
       "0                    White     Male          2174             0   \n",
       "1                    White     Male             0             0   \n",
       "2                    White     Male             0             0   \n",
       "3                    Black     Male             0             0   \n",
       "4                    Black   Female             0             0   \n",
       "...                    ...      ...           ...           ...   \n",
       "48837                White   Female             0             0   \n",
       "48838                Black     Male             0             0   \n",
       "48839                White     Male             0             0   \n",
       "48840   Asian-Pac-Islander     Male          5455             0   \n",
       "48841                White     Male             0             0   \n",
       "\n",
       "       hours-per-week  native-country  income  \n",
       "0                  40   United-States   <=50K  \n",
       "1                  13   United-States   <=50K  \n",
       "2                  40   United-States   <=50K  \n",
       "3                  40   United-States   <=50K  \n",
       "4                  40            Cuba   <=50K  \n",
       "...               ...             ...     ...  \n",
       "48837              36   United-States   <=50K  \n",
       "48838              40   United-States   <=50K  \n",
       "48839              50   United-States   <=50K  \n",
       "48840              40   United-States   <=50K  \n",
       "48841              60   United-States    >50K  \n",
       "\n",
       "[48842 rows x 15 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the table whose column names are tokenized for generation (pandas elides the middle rows).\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Save the current weights and the loss curve side by side in one directory.\n",
    "save_dir = './ckpts/dgpt2/adult-allcol'\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "torch.save(model.state_dict(), f'{save_dir}/{len(losses)}.pt')\n",
    "plt.plot(losses)\n",
    "plt.savefig(f'{save_dir}/losses.png')\n",
    "\n",
    "# Draw samples and append them to samples.txt in chunks, one sample per line.\n",
    "n_samples = 10000\n",
    "flush_every = 1000\n",
    "samples = []\n",
    "for i in tqdm(range(n_samples)):\n",
    "    samples.append(tokenizer.batch_decode(model.generate(do_sample=True, num_beams=1, max_length=140))[0])\n",
    "    # Flush on every full chunk and once more on the last iteration so that\n",
    "    # a final partial chunk is written instead of being dropped.\n",
    "    if len(samples) == flush_every or i == n_samples - 1:\n",
    "        with open(f'{save_dir}/samples.txt', 'a') as f:\n",
    "            # Terminate every sample with a newline so consecutive chunks do not run together.\n",
    "            f.writelines(sample + '\\n' for sample in samples)\n",
    "        samples = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hdd2/sonia/miniconda3/envs/great/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MOEModelForCausalLM(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-5): 6 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): MOEMLP(\n",
       "          (mlps): ModuleList(\n",
       "            (0-14): 15 x GPT2MLP(\n",
       "              (c_fc): Conv1D()\n",
       "              (c_proj): Conv1D()\n",
       "              (act): NewGELUActivation()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "Embedding(50259, 768)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the MoE model on top of distilgpt2 and restore trained weights for sampling.\n",
    "dgpt2 = transformers.AutoModelForCausalLM.from_pretrained('distilgpt2')\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
    "# NOTE(review): pad_token is assigned before <EOS> is registered, so it presumably stays\n",
    "# the stock end-of-text token rather than <EOS> -- confirm this is intended.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "special_tokens_dict = {\"bos_token\": \"<BOS>\", 'eos_token': '<EOS>'}\n",
    "num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)\n",
    "num_experts = 15\n",
    "dgpt2copy = MOEModelForCausalLM(dgpt2, num_experts=num_experts)\n",
    "model = dgpt2copy # don't forget to change tokenizer name and optimizer too\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "# Grow the embeddings to the extended vocabulary so the checkpoint's tensor shapes line up.\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "# This section only samples, so switch to eval mode to turn dropout off.\n",
    "model.eval()\n",
    "# map_location keeps the CPU fallback above usable when the checkpoint was saved from a GPU.\n",
    "model.load_state_dict(torch.load('/hdd3/sonia/be_great/ckpts/moe/dgpt2/adult-allcol/jul21/48000.pt', map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'<|endoftext|>age is 50.<EOS>workclass is Private.<EOS>fnlwgt is 266061.<EOS>education is Assoc-voc.<EOS>education-num is 11.<EOS>marital-status is Married-civ-spouse.<EOS>occupation is Adm-clerical.<EOS>relationship is Not-in-family.<EOS>race is White.<EOS>sex is Male.<EOS>capital-gain is 0.<EOS>capital-loss is 0.<EOS>hours-per-week is 40.<EOS>native-country is United-States.<EOS>income is <=50K.<EOS>income is <=50K.<EOS>native-country is United-States.<EOS>income is <=50'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw a single sampled sequence (one beam, at most 140 tokens) and show its decoded text.\n",
    "sampled_ids = model.generate(do_sample=True, num_beams=1, max_length=140)\n",
    "tokenizer.batch_decode(sampled_ids)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 10,000 independent samples, then write them to disk separated by newlines.\n",
    "samples = []\n",
    "for i in tqdm(range(10000)):\n",
    "    token_ids = model.generate(do_sample=True, num_beams=1, max_length=140)\n",
    "    samples.append(tokenizer.batch_decode(token_ids)[0])\n",
    "\n",
    "samples_file = '/hdd3/sonia/be_great/ckpts/moe/dgpt2/adult-allcol/jul21/samples.txt'\n",
    "with open(samples_file, 'w') as f:\n",
    "    f.write('\\n'.join(samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
