{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc1cf0bb-0a1a-4525-8a33-0fe08cd6cee8",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import struct\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "class MnistParser:\n",
    "\n",
    "    def load_image(self, file_path):\n",
    "\n",
    "\n",
    "        binary = open(file_path,'rb').read()\n",
    "\n",
    "        fmt_head = '>iiii'\n",
    "        offset = 0\n",
    "\n",
    "        magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)\n",
    "\n",
    "        image_size = rows_number * columns_number\n",
    "        fmt_data = '>'+str(image_size)+'B'\n",
    "        offset = offset + struct.calcsize(fmt_head)\n",
    "\n",
    "        images = np.empty((images_number,rows_number,columns_number))\n",
    "        for i in range(images_number):\n",
    "            images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))\n",
    "            offset = offset + struct.calcsize(fmt_data)\n",
    "\n",
    "        return images_number,rows_number,columns_number,images\n",
    "\n",
    "\n",
    "\n",
    "    def load_labels(self, file_path):\n",
    "\n",
    "        binary = open(file_path,'rb').read()\n",
    "\n",
    "        fmt_head = '>ii'\n",
    "        offset = 0\n",
    "\n",
    "        magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)\n",
    "\n",
    "        fmt_data = '>B'\n",
    "        offset = offset + struct.calcsize(fmt_head)\n",
    "\n",
    "        labels = np.empty((items_number))\n",
    "        for i in range(items_number):\n",
    "            labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]\n",
    "            offset = offset + struct.calcsize(fmt_data)\n",
    "\n",
    "        return items_number,labels\n",
    "\n",
    "    def visualaztion(self, images, labels, path):\n",
    "        d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}\n",
    "        for i in range(images.__len__()):\n",
    "            im = Image.fromarray(np.uint8(images[i]))\n",
    "            im.save(path + \"%d_%d.png\"%(labels[i], d[labels[i]]))\n",
    "            d[labels[i]] += 1\n",
    "\n",
    "\n",
    "def change_and_save():\n",
    "    mnist =  MnistParser()\n",
    "\n",
    "    trainImageFile = 'file/mnist/train-images-idx3-ubyte'\n",
    "    _, _, _, images = mnist.load_image(trainImageFile)\n",
    "    trainLabelFile = 'file/mnist/train-labels-idx1-ubyte'\n",
    "    _, labels = mnist.load_labels(trainLabelFile)\n",
    "    mnist.visualaztion(images, labels, \"file/mnist/train/\")\n",
    "\n",
    "    testImageFile = 'file/mnist/t10k-images-idx3-ubyte'\n",
    "    _, _, _, images = mnist.load_image(testImageFile)\n",
    "    testLabelFile = 'file/mnist/t10k-labels-idx1-ubyte'\n",
    "    _, labels = mnist.load_labels(testLabelFile)\n",
    "    mnist.visualaztion(images, labels, \"file/mnist/test/\")\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    change_and_save()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
