{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "875ff0af-ccae-480e-8ea0-7abce93ce6ac",
   "metadata": {},
   "source": [
    "From https://github.com/pilancilab/Randomized-Geometric-Algebra-Methods-for-Convex-Neural-Networks/blob/main/Readme.md\n",
    "Used to download cola"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f66347f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dfb5c2f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading readme: 100%|██████████████████| 7.62k/7.62k [00:00<00:00, 8.18MB/s]\n",
      "Downloading data: 100%|████████████████████| 14.5M/14.5M [00:01<00:00, 12.4MB/s]\n",
      "Downloading data: 100%|████████████████████| 1.82M/1.82M [00:00<00:00, 5.40MB/s]\n",
      "Generating train split: 100%|██| 87599/87599 [00:00<00:00, 916336.01 examples/s]\n",
      "Generating validation split: 100%|█| 10570/10570 [00:00<00:00, 939215.59 example\n"
     ]
    }
   ],
   "source": [
    "#didn't work - EZ\n",
    "#dataset = load_dataset(\"parquet\", data_files={'train': 'data/glue_cola_train.parquet', 'test': 'data/glue_cola_val.parquet'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "dbea8fcc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading readme: 100%|██████████████████| 35.3k/35.3k [00:00<00:00, 9.40MB/s]\n",
      "Downloading data: 100%|██████████████████████| 251k/251k [00:00<00:00, 1.19MB/s]\n",
      "Downloading data: 100%|█████████████████████| 37.6k/37.6k [00:00<00:00, 201kB/s]\n",
      "Downloading data: 100%|█████████████████████| 37.7k/37.7k [00:00<00:00, 192kB/s]\n",
      "Generating train split: 100%|████| 8551/8551 [00:00<00:00, 887144.89 examples/s]\n",
      "Generating validation split: 100%|█| 1043/1043 [00:00<00:00, 296305.82 examples/\n",
      "Generating test split: 100%|█████| 1063/1063 [00:00<00:00, 398440.14 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# cola\n",
    "dataset = load_dataset('glue','cola')\n",
    "df_train = dataset['train'].to_pandas()\n",
    "df_val = dataset['validation'].to_pandas()\n",
    "df = pd.concat([df_train,df_val],ignore_index=True)\n",
    "df.to_csv('data/glue_cola_raw.csv',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "20fa9e2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading data: 100%|████████████████████| 33.6M/33.6M [00:02<00:00, 11.3MB/s]\n",
      "Downloading data: 100%|████████████████████| 3.73M/3.73M [00:00<00:00, 8.65MB/s]\n",
      "Downloading data: 100%|████████████████████| 36.7M/36.7M [00:02<00:00, 14.4MB/s]\n",
      "Generating train split: 100%|█| 363846/363846 [00:00<00:00, 2943567.38 examples/\n",
      "Generating validation split: 100%|█| 40430/40430 [00:00<00:00, 3481189.66 exampl\n",
      "Generating test split: 100%|█| 390965/390965 [00:00<00:00, 3985868.27 examples/s\n"
     ]
    }
   ],
   "source": [
    "# qqp\n",
    "dataset = load_dataset('glue','qqp')\n",
    "df_train = dataset['train'].to_pandas()\n",
    "df_val = dataset['validation'].to_pandas()\n",
    "df = pd.concat([df_train,df_val],ignore_index=True)\n",
    "df.to_csv('data/glue_qqp_raw.csv',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4ec3a37b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>label</th>\n",
       "      <th>idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Our friends won't buy this analysis, let alone...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>One more pseudo generalization and I'm giving up.</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>One more pseudo generalization or I'm giving up.</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>The more we study verbs, the crazier they get.</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Day by day the facts are getting murkier.</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9589</th>\n",
       "      <td>John considers Bill silly.</td>\n",
       "      <td>1</td>\n",
       "      <td>1038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9590</th>\n",
       "      <td>John considers Bill to be silly.</td>\n",
       "      <td>1</td>\n",
       "      <td>1039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9591</th>\n",
       "      <td>John bought a dog for himself to play with.</td>\n",
       "      <td>0</td>\n",
       "      <td>1040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9592</th>\n",
       "      <td>John arranged for himself to get the prize.</td>\n",
       "      <td>1</td>\n",
       "      <td>1041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9593</th>\n",
       "      <td>John talked to Bill about himself.</td>\n",
       "      <td>1</td>\n",
       "      <td>1042</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>9594 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               sentence  label   idx\n",
       "0     Our friends won't buy this analysis, let alone...      1     0\n",
       "1     One more pseudo generalization and I'm giving up.      1     1\n",
       "2      One more pseudo generalization or I'm giving up.      1     2\n",
       "3        The more we study verbs, the crazier they get.      1     3\n",
       "4             Day by day the facts are getting murkier.      1     4\n",
       "...                                                 ...    ...   ...\n",
       "9589                         John considers Bill silly.      1  1038\n",
       "9590                   John considers Bill to be silly.      1  1039\n",
       "9591        John bought a dog for himself to play with.      0  1040\n",
       "9592        John arranged for himself to get the prize.      1  1041\n",
       "9593                 John talked to Bill about himself.      1  1042\n",
       "\n",
       "[9594 rows x 3 columns]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7a0687bf",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'question1'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
      "File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 'question1'",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mreview\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mQuestion 1: \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mquestion1\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m Question 2: \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39mdf[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mquestion2\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m   4101\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m   4104\u001b[0m     indexer \u001b[38;5;241m=\u001b[39m [indexer]\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   3807\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m   3808\u001b[0m         \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m   3809\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m   3810\u001b[0m     ):\n\u001b[1;32m   3811\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m   3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m   3814\u001b[0m     \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m   3815\u001b[0m     \u001b[38;5;66;03m#  InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m   3816\u001b[0m     \u001b[38;5;66;03m#  the TypeError.\u001b[39;00m\n\u001b[1;32m   3817\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
      "\u001b[0;31mKeyError\u001b[0m: 'question1'"
     ]
    }
   ],
   "source": [
    "df['review'] = 'Question 1: ' + df['question1']+ '\\n Question 2: ' +df['question2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b0d2055f",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyError",
     "evalue": "'review'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n",
      "File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 'review'",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[22], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mreview\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m   4101\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m   4104\u001b[0m     indexer \u001b[38;5;241m=\u001b[39m [indexer]\n",
      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   3807\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m   3808\u001b[0m         \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m   3809\u001b[0m         \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m   3810\u001b[0m     ):\n\u001b[1;32m   3811\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m   3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m   3814\u001b[0m     \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m   3815\u001b[0m     \u001b[38;5;66;03m#  InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m   3816\u001b[0m     \u001b[38;5;66;03m#  the TypeError.\u001b[39;00m\n\u001b[1;32m   3817\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n",
      "\u001b[0;31mKeyError\u001b[0m: 'review'"
     ]
    }
   ],
   "source": [
    "df['review']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead2d9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"cifar10\").with_format(\"np\")\n",
    "data = np.concatenate([dataset['train']['img'],dataset['test']['img']],axis=0)\n",
    "label = np.concatenate([dataset['train']['label'],dataset['test']['label']],axis=0)\n",
    "data = data.reshape([60000,-1])\n",
    "np.savez('./data/cifar10.npz',data=data,label=label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2abb9343",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b61fa3ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5109dc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(trainset, batch_size=len(trainset))\n",
    "train_data = next(iter(train_loader))[0].numpy()\n",
    "train_label = next(iter(train_loader))[1].numpy()\n",
    "\n",
    "test_loader = DataLoader(testset, batch_size=len(testset))\n",
    "test_data = next(iter(test_loader))[0].numpy()\n",
    "test_label = next(iter(test_loader))[1].numpy()\n",
    "np.savez('./data/cifar10_transformed.npz',train_data=train_data,train_label=train_label,test_data=test_data,test_label=test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3b97d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms_MNIST = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean = (0.1307,), std = (0.3081,))\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True,\n",
    "                                        download=True, transform=transforms_MNIST)\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False,\n",
    "                                       download=True, transform=transforms_MNIST)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7d04e23",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(trainset, batch_size=len(trainset))\n",
    "train_data = next(iter(train_loader))[0].numpy()\n",
    "train_label = next(iter(train_loader))[1].numpy()\n",
    "\n",
    "test_loader = DataLoader(testset, batch_size=len(testset))\n",
    "test_data = next(iter(test_loader))[0].numpy()\n",
    "test_label = next(iter(test_loader))[1].numpy()\n",
    "np.savez('./data/mnist_transformed.npz',train_data=train_data,train_label=train_label,test_data=test_data,test_label=test_label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
