{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SET Classification: Training CNN card image embedder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from setGame import *\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, Model\n",
    "from sklearn.model_selection import train_test_split\n",
    "sys.path.append('../..'); sys.path.append('../');\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the card helper that the data-generation cell queries through\n",
    "# `image_of_card(row, col)` and `attributes_of_card(row, col)`.\n",
    "# NOTE(review): SetGame is not imported by name; presumably it comes from\n",
    "# `from setGame import *` -- confirm.\n",
    "setgame = SetGame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " input_2 (InputLayer)        [(None, 70, 50, 4)]       0         \n",
      "                                                                 \n",
      " conv2d_2 (Conv2D)           (None, 66, 46, 32)        3232      \n",
      "                                                                 \n",
      " max_pooling2d_2 (MaxPoolin  (None, 16, 11, 32)        0         \n",
      " g2D)                                                            \n",
      "                                                                 \n",
      " conv2d_3 (Conv2D)           (None, 12, 7, 32)         25632     \n",
      "                                                                 \n",
      " max_pooling2d_3 (MaxPoolin  (None, 3, 1, 32)          0         \n",
      " g2D)                                                            \n",
      "                                                                 \n",
      " flatten_1 (Flatten)         (None, 96)                0         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 64)                6208      \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 64)                4160      \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 12)                780       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 40012 (156.30 KB)\n",
      "Trainable params: 40012 (156.30 KB)\n",
      "Non-trainable params: 0 (0.00 Byte)\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, Model, Sequential\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Seed Python, NumPy and TensorFlow in one call so that weight initialisation\n",
    "# (and anything downstream that uses the global NumPy RNG) repeats exactly\n",
    "# under Restart Kernel -> Run All.\n",
    "SEED = 0\n",
    "tf.keras.utils.set_random_seed(SEED)\n",
    "\n",
    "ff_dim1 = 64  # width of the first dense layer\n",
    "ff_dim2 = 64  # width of the tanh layer that is exposed as the card embedding\n",
    "\n",
    "# Card images are fed as 70 x 50 arrays with 4 channels (channels-last).\n",
    "img_input = layers.Input(shape=(70, 50, 4))\n",
    "\n",
    "# Two conv/pool stages reduce the image to a small feature map, then flatten it.\n",
    "x = layers.Conv2D(32, (5, 5), activation='relu')(img_input)\n",
    "x = layers.MaxPooling2D((4, 4))(x)\n",
    "x = layers.Conv2D(32, (5, 5), activation='relu')(x)\n",
    "x = layers.MaxPooling2D((4, 4))(x)\n",
    "x = layers.Flatten()(x)\n",
    "\n",
    "x = layers.Dense(ff_dim1, activation='relu')(x)\n",
    "embedding = layers.Dense(ff_dim2, activation='tanh')(x)\n",
    "\n",
    "# 12 independent sigmoid outputs, trained with binary cross-entropy below.\n",
    "outputs = layers.Dense(12, activation='sigmoid')(embedding)\n",
    "model = Model(inputs=img_input, outputs=outputs)\n",
    "\n",
    "# The embedder reuses the layers (and therefore the weights) of `model` but stops\n",
    "# at the tanh layer. Wiring it from the tensor itself, rather than from a\n",
    "# positional index into model.layers, keeps it correct if layers are added or removed.\n",
    "embed = Model(inputs=img_input, outputs=embedding)\n",
    "\n",
    "model.summary()\n",
    "opt = tf.keras.optimizers.Adam(learning_rate=0.001)\n",
    "model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['binary_accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_to_binary(attrs):\n",
    "    color = {'red':[1,0,0], 'green':[0,1,0], 'purple':[0,0,1]}\n",
    "    pattern = {'empty':[1,0,0], 'striped':[0,1,0], 'solid':[0,0,1]}\n",
    "    shape = {'diamond':[1,0,0], 'oval':[0,1,0], 'squiggle':[0,0,1]}\n",
    "    number = {'one':[1,0,0], 'two':[0,1,0], 'three':[0,0,1]}\n",
    "    binary_attrs = number[attrs[0]] + color[attrs[1]] + pattern[attrs[2]] + shape[attrs[3]]\n",
    "    return binary_attrs\n",
    "\n",
    "n = 1000       # number of (image, label) samples to draw\n",
    "grid_side = 9  # cards are addressed as (row, col) on a 9 x 9 grid, i.e. 81 cards\n",
    "\n",
    "X = np.empty((n, 70, 50, 4), dtype=np.float32)  # one 70 x 50 x 4 image per sample\n",
    "y = np.empty((n, 12), dtype=int)                # 12 label bits per sample\n",
    "\n",
    "# Draw every card index up front from a locally seeded generator so the dataset\n",
    "# is identical on each run and no per-iteration index array has to be rebuilt.\n",
    "rng = np.random.default_rng(0)\n",
    "card_ids = rng.integers(grid_side * grid_side, size=n)\n",
    "# Row-major mapping: card k sits at (k // 9, k % 9), the same order as\n",
    "# enumerating all (row, col) pairs with the row as the outer loop.\n",
    "rows, cols = np.divmod(card_ids, grid_side)\n",
    "\n",
    "# NOTE(review): cards are sampled with replacement from only 81 possibilities, so\n",
    "# the same card appears many times. If image_of_card is deterministic the dataset\n",
    "# holds at most 81 distinct examples -- confirm that this is intended.\n",
    "for i, (row, col) in enumerate(zip(rows, cols)):\n",
    "    X[i] = setgame.image_of_card(row, col)\n",
    "    y[i] = convert_to_binary(setgame.attributes_of_card(row, col))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/50\n",
      "6/6 [==============================] - 1s 22ms/step - loss: 0.6629 - binary_accuracy: 0.6292\n",
      "Epoch 2/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.6357 - binary_accuracy: 0.6667\n",
      "Epoch 3/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.6307 - binary_accuracy: 0.6667\n",
      "Epoch 4/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.6221 - binary_accuracy: 0.6667\n",
      "Epoch 5/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.6112 - binary_accuracy: 0.6690\n",
      "Epoch 6/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.5945 - binary_accuracy: 0.6696\n",
      "Epoch 7/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.5729 - binary_accuracy: 0.6902\n",
      "Epoch 8/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.5500 - binary_accuracy: 0.7090\n",
      "Epoch 9/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.5225 - binary_accuracy: 0.7331\n",
      "Epoch 10/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.4910 - binary_accuracy: 0.7577\n",
      "Epoch 11/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.4578 - binary_accuracy: 0.7866\n",
      "Epoch 12/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.4248 - binary_accuracy: 0.8149\n",
      "Epoch 13/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.3913 - binary_accuracy: 0.8439\n",
      "Epoch 14/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.3595 - binary_accuracy: 0.8613\n",
      "Epoch 15/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.3323 - binary_accuracy: 0.8843\n",
      "Epoch 16/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.3083 - binary_accuracy: 0.8964\n",
      "Epoch 17/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.2880 - binary_accuracy: 0.9099\n",
      "Epoch 18/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.2670 - binary_accuracy: 0.9259\n",
      "Epoch 19/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.2484 - binary_accuracy: 0.9370\n",
      "Epoch 20/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.2300 - binary_accuracy: 0.9461\n",
      "Epoch 21/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.2122 - binary_accuracy: 0.9599\n",
      "Epoch 22/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1971 - binary_accuracy: 0.9653\n",
      "Epoch 23/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1836 - binary_accuracy: 0.9716\n",
      "Epoch 24/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1705 - binary_accuracy: 0.9763\n",
      "Epoch 25/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1592 - binary_accuracy: 0.9774\n",
      "Epoch 26/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1483 - binary_accuracy: 0.9798\n",
      "Epoch 27/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.1382 - binary_accuracy: 0.9869\n",
      "Epoch 28/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1281 - binary_accuracy: 0.9914\n",
      "Epoch 29/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.1194 - binary_accuracy: 0.9924\n",
      "Epoch 30/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.1121 - binary_accuracy: 0.9943\n",
      "Epoch 31/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.1058 - binary_accuracy: 0.9961\n",
      "Epoch 32/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0995 - binary_accuracy: 0.9967\n",
      "Epoch 33/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0933 - binary_accuracy: 0.9978\n",
      "Epoch 34/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0872 - binary_accuracy: 0.9994\n",
      "Epoch 35/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0823 - binary_accuracy: 0.9989\n",
      "Epoch 36/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.0780 - binary_accuracy: 0.9992\n",
      "Epoch 37/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.0741 - binary_accuracy: 1.0000\n",
      "Epoch 38/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0702 - binary_accuracy: 1.0000\n",
      "Epoch 39/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.0667 - binary_accuracy: 1.0000\n",
      "Epoch 40/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0638 - binary_accuracy: 1.0000\n",
      "Epoch 41/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0607 - binary_accuracy: 1.0000\n",
      "Epoch 42/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.0578 - binary_accuracy: 1.0000\n",
      "Epoch 43/50\n",
      "6/6 [==============================] - 0s 21ms/step - loss: 0.0554 - binary_accuracy: 1.0000\n",
      "Epoch 44/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0530 - binary_accuracy: 1.0000\n",
      "Epoch 45/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.0507 - binary_accuracy: 1.0000\n",
      "Epoch 46/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0484 - binary_accuracy: 1.0000\n",
      "Epoch 47/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0464 - binary_accuracy: 1.0000\n",
      "Epoch 48/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0445 - binary_accuracy: 1.0000\n",
      "Epoch 49/50\n",
      "6/6 [==============================] - 0s 20ms/step - loss: 0.0428 - binary_accuracy: 1.0000\n",
      "Epoch 50/50\n",
      "6/6 [==============================] - 0s 19ms/step - loss: 0.0412 - binary_accuracy: 1.0000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.src.callbacks.History at 0x14a8c2360640>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hold out 25% of the samples. A fixed random_state keeps the split identical\n",
    "# between runs, so the held-out accuracy reported later is comparable.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)\n",
    "\n",
    "# Keep the History object (per-epoch loss/metric values) instead of letting its\n",
    "# repr become the cell output.\n",
    "history = model.fit(X_train, y_train, batch_size=128, epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy:  1.0\n"
     ]
    }
   ],
   "source": [
    "# Per-bit accuracy on the held-out split: threshold the sigmoid outputs by\n",
    "# rounding, then count how many of the label bits disagree with y_test.\n",
    "probs = model.predict(X_test, verbose=0)\n",
    "pred = np.round(probs).astype(int)\n",
    "n_wrong = (pred != y_test).sum()\n",
    "print('accuracy: ', 1 - n_wrong / pred.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: cnn_card_embedder/model/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: cnn_card_embedder/model/assets\n"
     ]
    }
   ],
   "source": [
    "# Persist the full classifier (image in, 12 sigmoid outputs out) to disk.\n",
    "model.save('cnn_card_embedder/model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: cnn_card_embedder/embedder/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: cnn_card_embedder/embedder/assets\n"
     ]
    }
   ],
   "source": [
    "# Persist the truncated network `embed` separately, so the card embedding can be\n",
    "# reloaded on its own without the final classification layer.\n",
    "embed.save('cnn_card_embedder/embedder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n"
     ]
    }
   ],
   "source": [
    "# Reload the embedder from the directory written by the save cell above.\n",
    "# NOTE(review): loading prints a 'No training configuration found' warning (see the\n",
    "# cell output); passing compile=False would presumably silence it -- confirm.\n",
    "card_embedder = tf.keras.models.load_model('cnn_card_embedder/embedder')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12 64-bit ('tf')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "c42447f1c4240406d64c4df4cca87b5465b8a2bbd2ae4f1d6d833906715d3ac1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
