{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4c1b9294",
   "metadata": {},
   "source": [
    "# Libraries, Automaton, and Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e0dd236e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import warnings\n",
    "\n",
    "# Silence warnings *before* the heavy imports below so import-time warnings stay hidden too.\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Explicit imports (instead of `from utils import *`) so readers can see where every\n",
    "# project name used in this notebook comes from.\n",
    "from utils import (\n",
    "    CNNLSTM,\n",
    "    import_automata,\n",
    "    import_ordered_fashion_mnist,\n",
    "    print_automata,\n",
    "    test_cnn_lstm,\n",
    "    test_cnn_lstm_with_automata,\n",
    "    train_cnn_lstm_with_frozen_cnn,\n",
    ")\n",
    "\n",
    "# Seed the RNGs that model initialisation / training may draw from so Restart & Run All\n",
    "# is repeatable (bit-exact GPU runs may additionally need deterministic cuDNN settings).\n",
    "# random.seed is last because it returns None, keeping this cell free of output.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "84e26ff8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state          distance  | t_shirt_top  trouser      pullover     dress        coat         sandal       shirt        sneaker      bag          ankle_boot   \n",
      "-------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
      "1               2         | 8            5            4            6            2            3            8            3            2            3            \n",
      "2(Deadlock)     inf       | 2            2            2            2            2            2            2            2            2            2            \n",
      "3               1         | 2            2            2            7            2            2            2            2            2            2            \n",
      "4               2         | 2            6            2            2            11           2            2            2            10           2            \n",
      "5               2         | 9            2            6            2            2            14           9            14           2            14           \n",
      "6               1         | 2            2            2            2            13           7            2            7            12           7            \n",
      "7(Accepting)    0         | 2            2            2            2            17           2            2            2            16           2            \n",
      "8               2         | 2            9            4            2            11           2            2            2            10           2            \n",
      "9               1         | 2            2            6            2            13           15           2            15           12           15           \n",
      "10              2         | 2            12           2            2            2            2            2            2            2            2            \n",
      "11              2         | 2            13           2            2            2            2            2            2            10           2            \n",
      "12              1         | 2            2            2            2            2            16           2            16           2            16           \n",
      "13              1         | 2            2            2            2            2            17           2            17           12           17           \n",
      "14              1         | 15           2            7            2            2            2            15           2            2            2            \n",
      "15(Accepting)   0         | 2            2            7            2            17           2            2            2            16           2            \n",
      "16(Accepting)   0         | 2            2            2            2            2            2            2            2            2            2            \n",
      "17(Accepting)   0         | 2            2            2            2            2            2            2            2            16           2            \n"
     ]
    }
   ],
   "source": [
    "# Display names for the ten class labels, in label-index order\n",
    "# (presumably Fashion-MNIST labels 0-9 -- verify against the dataset loader).\n",
    "clothing_items = [\n",
    "    \"t_shirt_top\", \"trouser\", \"pullover\", \"dress\", \"coat\",\n",
    "    \"sandal\", \"shirt\", \"sneaker\", \"bag\", \"ankle_boot\",\n",
    "]\n",
    "\n",
    "# Load the ordering automaton and print its transition table (one column per class name).\n",
    "automata_path = \"../data/automata/ordered_fashion_mnist_automata.json\"\n",
    "automata = import_automata(automata_path)\n",
    "print_automata(automata, clothing_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5971e4ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits live under the same directory; only the CSV file and the train flag differ.\n",
    "fmnist_dir = \"../data/Ordered Fashion MNIST\"\n",
    "\n",
    "train_dataset, train_custom_indices, train_custom_labels = import_ordered_fashion_mnist(\n",
    "    fmnist_dir, f\"{fmnist_dir}/train_ordered_fashion_mnist.csv\", train=True\n",
    ")\n",
    "test_dataset, test_custom_indices, test_custom_labels = import_ordered_fashion_mnist(\n",
    "    fmnist_dir, f\"{fmnist_dir}/test_ordered_fashion_mnist.csv\", train=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "676b3681",
   "metadata": {},
   "source": [
    "# Training and Evaluating the CNN-LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6addff5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh CNN-LSTM. The training cell below does not rebind `model` and the\n",
    "# evaluation cells reuse it, so training presumably updates this object in place.\n",
    "model = CNNLSTM()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0b9a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the CNN-LSTM. Judging by the function name, the CNN part is initialised from the\n",
    "# pretrained weights below and kept frozen -- TODO confirm in utils.\n",
    "cnn_weights_path = '../data/CNN/fashionConvNet_model_on_original_fmnist.pth'\n",
    "train_cnn_lstm_with_frozen_cnn(\n",
    "    model,\n",
    "    train_dataset,\n",
    "    train_custom_indices,\n",
    "    train_custom_labels,\n",
    "    cnn_weights_path,\n",
    "    num_epochs=30,\n",
    "    batch_size=16,\n",
    "    lr=1e-3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62a57317",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Single label accuracy: 0.9407\n",
      "Sequence accuracy:    0.7469\n"
     ]
    }
   ],
   "source": [
    "# Baseline evaluation on the test split (no automaton involved);\n",
    "# reports single-label accuracy and sequence accuracy.\n",
    "test_cnn_lstm(\n",
    "    model,\n",
    "    test_dataset,\n",
    "    test_custom_indices,\n",
    "    test_custom_labels,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e598a661",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing with k=2\n",
      "Single label accuracy: 0.9350\n",
      "Sequence accuracy:    0.7737\n",
      "Testing with k=3\n",
      "Single label accuracy: 0.9405\n",
      "Sequence accuracy:    0.7798\n",
      "Testing with k=4\n",
      "Single label accuracy: 0.9449\n",
      "Sequence accuracy:    0.7840\n",
      "Testing with k=5\n",
      "Single label accuracy: 0.9453\n",
      "Sequence accuracy:    0.7850\n",
      "Testing with k=6\n",
      "Single label accuracy: 0.9466\n",
      "Sequence accuracy:    0.7860\n",
      "Testing with k=7\n",
      "Single label accuracy: 0.9490\n",
      "Sequence accuracy:    0.7891\n",
      "Testing with k=8\n",
      "Single label accuracy: 0.9490\n",
      "Sequence accuracy:    0.7891\n",
      "Testing with k=9\n",
      "Single label accuracy: 0.9497\n",
      "Sequence accuracy:    0.7901\n",
      "Testing with k=10\n",
      "Single label accuracy: 0.9497\n",
      "Sequence accuracy:    0.7901\n"
     ]
    }
   ],
   "source": [
    "# Re-evaluate with the automaton for k = 2..10. k is forwarded unchanged to\n",
    "# test_cnn_lstm_with_automata; its meaning is defined in utils.\n",
    "# NOTE(review): unlike test_cnn_lstm, this call takes the indices first and is not given\n",
    "# test_custom_labels -- confirm that is intended.\n",
    "k_values = range(2, 11)\n",
    "for k in k_values:\n",
    "    print(f\"Testing with k={k}\")\n",
    "    test_cnn_lstm_with_automata(test_custom_indices, model, test_dataset, automata, k=k)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trident",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
