{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "%pip install -q transformers==4.32.1 accelerate==0.22.0"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1ymE0R3EuIZd",
        "outputId": "646863fd-8b08-4ff7-ccd4-a37d7a82dd4c"
      },
      "id": "1ymE0R3EuIZd",
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.32.1)\n",
            "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.22.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
            "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
            "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
            "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (2023.6.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.7.1)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.2)\n",
            "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (2.0.0)\n",
            "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (3.27.2)\n",
            "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10.0->accelerate) (16.0.6)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.2.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import reservoirtransformers\n",
        "import math\n",
        "import os\n",
        "import warnings\n",
        "from dataclasses import dataclass\n",
        "from typing import List, Optional, Tuple, Union\n",
        "\n",
        "import torch\n",
        "import torch.utils.checkpoint\n",
        "from torch import nn\n",
        "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
        "import torch.nn.functional as F"
      ],
      "metadata": {
        "id": "mhmjveElujN2"
      },
      "id": "mhmjveElujN2",
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "id": "7eb58cec",
      "metadata": {
        "id": "7eb58cec"
      },
      "outputs": [],
      "source": [
        "from configuration import ReservoirTConfig\n",
        "\n",
        "# Start from the ReservoirTConfig defaults and override the fields below, in order.\n",
        "configuration = ReservoirTConfig()\n",
        "\n",
        "_overrides = {\n",
        "    \"output_size\": 7,\n",
        "    \"re_output_size\": 32,\n",
        "    \"max_sequence_length\": 1399,\n",
        "    \"sequence_length\": 332,\n",
        "    \"pred_len\": 60,\n",
        "    \"hidden_size\": 8,\n",
        "    \"num_attention_heads\": 4,\n",
        "    \"hidden_dropout_prob\": 0.0,\n",
        "    \"num_hidden_layers\": 8,\n",
        "    \"num_reservoirs\": 10,\n",
        "    \"intermediate_size\": 256,\n",
        "    # The four lists below each hold 10 entries, matching num_reservoirs.\n",
        "    \"reservoir_size\": [30, 15, 20, 25, 30, 35, 40, 45, 50, 50],\n",
        "    \"spectral_radius\": [0.6, 0.8, 0.55, 0.6, 0.5, 0.4, 0.3, 0.2, 0.81, 0.05],\n",
        "    \"sparsity\": [0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15],\n",
        "    \"leaky\": [0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39],\n",
        "    \"attention_probs_dropout_prob\": 0.0,\n",
        "}\n",
        "for _field, _setting in _overrides.items():\n",
        "    setattr(configuration, _field, _setting)\n",
        "\n",
        "# Build the model on CPU; dtype=float is Python's float, i.e. double precision in torch.\n",
        "model = reservoirtransformers.ReservoirTTimeSeries(config=configuration).to(\"cpu\", dtype=float)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "id": "0f6bf41e",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "id": "0f6bf41e",
        "outputId": "ba6b5f78-89ec-4d35-b6ab-e8394062e306"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\"\\nzero_col = np.zeros((X.shape[0], 1))\\n\\n# Concatenate the original array with the zero column along the second axis (columns)\\nX = np.hstack((X, zero_col))\\n#X =  dataset.drop(['ate'], axis = 1).values\\n\\n#X_train, X_test, y_train, y_test =train_test_split(X.values, y, test_size=0.2, shuffle=False)\\n\""
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 19
        }
      ],
      "source": [
        "import numpy as np\n",
        "# prepare data for lstm\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from pandas import read_csv\n",
        "from pandas import DataFrame\n",
        "import random\n",
        "from sklearn.model_selection import train_test_split\n",
        "from pandas import concat\n",
        "from sklearn.preprocessing import LabelEncoder\n",
        "from sklearn.preprocessing import MinMaxScaler\n",
        "dataset = read_csv('/content/national_illness.csv')\n",
        "# Drop rows with missing values, then remove the `date` column.\n",
        "dataset = dataset.dropna()\n",
        "dataset = dataset.drop(['date'], axis = 1)\n",
        "\n",
        "# Raw values of the OT column (taken before standardisation).\n",
        "y = dataset.OT.values\n",
        "\n",
        "# All remaining columns, OT included, form the feature matrix.\n",
        "X = dataset.values\n",
        "\n",
        "# Standardise every column.\n",
        "# NOTE(review): the scaler is fit on the full series, before any train/val/test\n",
        "# split is made -- confirm this is intended.\n",
        "scaler = StandardScaler()\n",
        "X = scaler.fit_transform(X)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "baffd6ba",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "baffd6ba",
        "outputId": "1eeb2feb-5d42-49eb-e44c-5d4ed5e561e6"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(966, 7)"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ],
      "source": [
        "X.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "6d0706c5",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84,
          "referenced_widgets": [
            "217144691d984ffb96ce490a6f607f5f",
            "d192a4ec3c2e4d1cbb27858e876ee55d",
            "6691c080e8254721845a7ae43815d6ac",
            "97050576a3574370a7f81b93be77045f",
            "1de60462c7744330ac86083d8e8be9ed",
            "b659737df91f43a3b127175b3a2e0e08",
            "864eb63501ec48c49f9d37430390f27c",
            "90c2bfe519bd4edba50bbd68cdeed1e9",
            "ae57a6bf4f9e47f487e0c7824d74cb74",
            "9b519074885b494182431f4eb9b1d3f8",
            "da85a3c3f0844a0fa424a4f9dadbc8a4"
          ]
        },
        "id": "6d0706c5",
        "outputId": "71a0ffb0-3c01-4832-a02c-425884c85799"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "creating seq\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/575 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "217144691d984ffb96ce490a6f607f5f"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([575, 332, 7])\n"
          ]
        }
      ],
      "source": [
        "from tqdm.auto import tqdm\n",
        "# 1. Preprocess the data into the required format\n",
        "def create_sequences(data, seq_length, pred_length):\n",
        "    \"\"\"Slice a series into overlapping (input window, target window) pairs.\n",
        "\n",
        "    Window ``i`` uses ``data[i:i+seq_length]`` as input and the\n",
        "    ``pred_length`` steps that immediately follow it as target.\n",
        "\n",
        "    Args:\n",
        "        data: sequence indexed by time step along its first axis.\n",
        "        seq_length: number of steps in each input window.\n",
        "        pred_length: number of steps in each target window.\n",
        "\n",
        "    Returns:\n",
        "        Tuple ``(sequences, targets)`` of tensors holding\n",
        "        ``len(data) - seq_length - pred_length + 1`` windows each.\n",
        "    \"\"\"\n",
        "    sequences = []\n",
        "    targets = []\n",
        "    print('creating seq')\n",
        "    for i in tqdm(range(len(data) - seq_length - pred_length + 1)):\n",
        "        split = i + seq_length\n",
        "        sequences.append(data[i:split])\n",
        "        targets.append(data[split:split+pred_length])\n",
        "\n",
        "    if isinstance(data, np.ndarray) and sequences:\n",
        "        # Stack into one contiguous array first: torch.tensor() on a Python list\n",
        "        # of ndarrays converts element by element and is very slow.\n",
        "        return torch.from_numpy(np.stack(sequences)), torch.from_numpy(np.stack(targets))\n",
        "    # Other inputs (and the no-window case) keep the direct conversion.\n",
        "    return torch.tensor(sequences), torch.tensor(targets)\n",
        "\n",
        "X, y = create_sequences(\n",
        "    data=X,\n",
        "    seq_length=configuration.sequence_length,\n",
        "    pred_length=configuration.pred_len,\n",
        ")\n",
        "print(X.shape)\n",
        "\n",
        "# One all-zero value per (window, time step), in the same dtype as X ...\n",
        "zeros = torch.zeros(X.size(0), X.size(1), 1, dtype=X.dtype)\n",
        "# ... appended to the inputs as an extra trailing feature channel.\n",
        "X = torch.cat((X, zeros), dim=-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "3e13e0d0",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3e13e0d0",
        "outputId": "cb8297d2-2b66-4537-f8ea-a03334e7fd81"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([68841, 120, 12]), torch.Size([68841, 720, 7]))"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "X.shape, y.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "7955f82a",
      "metadata": {
        "id": "7955f82a"
      },
      "outputs": [],
      "source": [
        "\n",
        "batch=64\n",
        "# Keep only as many samples as fill whole batches.\n",
        "indices = np.arange(len(X))\n",
        "barrier = int(len(indices)/batch)*batch\n",
        "indices = indices[0:barrier]\n",
        "# Number of batches by which the val/test slices start before their nominal border.\n",
        "soft_border = int((configuration.sequence_length/batch))+16\n",
        "\n",
        "# Chunk the sample indices into consecutive batches; the split is made per batch.\n",
        "indices = [indices[i:i+batch] for i in range(0, len(indices), batch)]\n",
        "\n",
        "border1 = int(len(indices)*0.7)\n",
        "border2 = border1+int(len(indices)*0.1)\n",
        "border3 = border2+int(len(indices)*0.2)\n",
        "\n",
        "# Clamp the slice starts at 0: a negative start is read by Python as an offset\n",
        "# from the end of the list and can silently select the wrong batches (or none).\n",
        "val_start = max(border1-soft_border, 0)\n",
        "test_start = max(border2-soft_border, 0)\n",
        "\n",
        "train_ind = indices[0:border1]\n",
        "# NOTE(review): because of soft_border the val/test slices overlap the split\n",
        "# before them; whenever soft_border >= border1 the validation slice starts at\n",
        "# batch 0 and therefore contains every training batch -- confirm this is intended.\n",
        "val_ind = indices[val_start: border2]\n",
        "test_ind = indices[test_start: border3]\n",
        "\n",
        "# Shuffle the order of the batches (not the samples inside a batch);\n",
        "# the test batches keep their original order.\n",
        "random.shuffle(train_ind)\n",
        "random.shuffle(val_ind)\n",
        "\n",
        "# Flatten each list of batches back into per-sample lists.\n",
        "X_train = [X[item] for sublist in train_ind for item in sublist]\n",
        "y_train = [y[item] for sublist in train_ind for item in sublist]\n",
        "\n",
        "X_val = [X[item] for sublist in val_ind for item in sublist]\n",
        "y_val = [y[item] for sublist in val_ind for item in sublist]\n",
        "\n",
        "X_test = [X[item] for sublist in test_ind for item in sublist]\n",
        "y_test = [y[item] for sublist in test_ind for item in sublist]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "26d5bda1",
      "metadata": {
        "id": "26d5bda1"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "class CustomDataset(Dataset):\n",
        "    \"\"\"Map-style dataset pairing each input sample with an optional target.\n",
        "\n",
        "    Args:\n",
        "        tokenized_inputs: indexable collection of input samples.\n",
        "        labels: optional indexable collection of targets aligned with\n",
        "            ``tokenized_inputs``; when omitted, items carry inputs only.\n",
        "        pos: stored on the instance but not read by this class.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, tokenized_inputs,  labels=None, pos=None):\n",
        "        self.tokenized_inputs = tokenized_inputs\n",
        "        self.labels = labels\n",
        "        self.pos = pos\n",
        "        # Placeholders: nothing in this class reads them.\n",
        "        self.id_list = None\n",
        "        self.re = None\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Return the number of input samples.\"\"\"\n",
        "        return len(self.tokenized_inputs)\n",
        "\n",
        "    @staticmethod\n",
        "    def _tensor_copy(sample):\n",
        "        \"\"\"Return an independent tensor copy of ``sample``.\n",
        "\n",
        "        Samples that are already tensors are copied with ``clone().detach()``;\n",
        "        passing them to ``torch.tensor`` makes the same copy but raises a\n",
        "        UserWarning on every single item.\n",
        "        \"\"\"\n",
        "        if isinstance(sample, torch.Tensor):\n",
        "            return sample.clone().detach()\n",
        "        return torch.tensor(sample)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        \"\"\"Return a dict with ``inputs_embeds`` and, when labels exist, ``labels_ids``.\"\"\"\n",
        "        item = {\"inputs_embeds\": self._tensor_copy(self.tokenized_inputs[idx])}\n",
        "        if self.labels is not None:\n",
        "            item[\"labels_ids\"] = self._tensor_copy(self.labels[idx])\n",
        "        return item\n",
        "\n",
        "# Wrap each split's (inputs, targets) pair in a CustomDataset.\n",
        "train_dataset, test_dataset, val_dataset = (\n",
        "    CustomDataset(inputs, targets)\n",
        "    for inputs, targets in ((X_train, y_train), (X_test, y_test), (X_val, y_val))\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "13806390",
      "metadata": {
        "id": "13806390"
      },
      "outputs": [],
      "source": [
        "from transformers import Trainer, TrainingArguments\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.cuda.amp import GradScaler\n",
        "class CustomTrainer(Trainer):\n",
        "    \"\"\"Trainer with a hand-written backward step and an unshuffled training loader.\n",
        "\n",
        "    Args:\n",
        "        *args: forwarded unchanged to ``Trainer.__init__``.\n",
        "        gradient_accumulation_steps: divisor applied to each step's loss before\n",
        "            the backward pass (default 1).\n",
        "        **kwargs: forwarded unchanged to ``Trainer.__init__``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, *args, gradient_accumulation_steps=1, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "        self.gradient_accumulation_steps = gradient_accumulation_steps\n",
        "        self.scaler = GradScaler()\n",
        "\n",
        "    def training_step(self, model, inputs):\n",
        "        \"\"\"Run the forward and backward pass for one batch.\n",
        "\n",
        "        Returns:\n",
        "            The detached loss, already divided by ``gradient_accumulation_steps``.\n",
        "        \"\"\"\n",
        "        model.train()\n",
        "        step_loss = self.compute_loss(model, self._prepare_inputs(inputs))\n",
        "\n",
        "        # Average the loss when training in parallel on several GPUs.\n",
        "        if self.args.n_gpu > 1:\n",
        "            step_loss = step_loss.mean()\n",
        "\n",
        "        step_loss = step_loss / self.gradient_accumulation_steps\n",
        "        # NOTE(review): the loss is scaled before backward, but nothing in this class\n",
        "        # calls scaler.step()/scaler.update() -- confirm the gradients are unscaled\n",
        "        # somewhere before the optimizer step when the scaler is active.\n",
        "        self.scaler.scale(step_loss).backward()\n",
        "        return step_loss.detach()\n",
        "\n",
        "    def get_train_dataloader(self) -> DataLoader:\n",
        "        \"\"\"Build the training DataLoader with shuffling disabled.\n",
        "\n",
        "        Samples are served in the order stored in ``train_dataset``.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if no training dataset was given to the trainer.\n",
        "        \"\"\"\n",
        "        if self.train_dataset is None:\n",
        "            raise ValueError(\"Trainer: training requires a train_dataset.\")\n",
        "\n",
        "        return DataLoader(\n",
        "            self.train_dataset,\n",
        "            batch_size=self._train_batch_size,\n",
        "            drop_last=self.args.dataloader_drop_last,\n",
        "            shuffle=False,\n",
        "        )\n",
        "\n",
        "#bert_model = TabularBertForSequenceClassification(config=configuration).to(\"cpu\", dtype=float)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "cb4d10de",
      "metadata": {
        "scrolled": true,
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "cb4d10de",
        "outputId": "c430f6bd-66c1-4796-95aa-edf62c2fb759"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='750' max='750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [750/750 17:52, Epoch 150/150]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>R2 Score</th>\n",
              "      <th>Mse</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>0.103500</td>\n",
              "      <td>0.061224</td>\n",
              "      <td>0.869211</td>\n",
              "      <td>0.061224</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>100</td>\n",
              "      <td>0.054500</td>\n",
              "      <td>0.051057</td>\n",
              "      <td>0.890931</td>\n",
              "      <td>0.051057</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>150</td>\n",
              "      <td>0.043100</td>\n",
              "      <td>0.041290</td>\n",
              "      <td>0.911796</td>\n",
              "      <td>0.041290</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>200</td>\n",
              "      <td>0.036100</td>\n",
              "      <td>0.031987</td>\n",
              "      <td>0.931667</td>\n",
              "      <td>0.031987</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>250</td>\n",
              "      <td>0.030600</td>\n",
              "      <td>0.029776</td>\n",
              "      <td>0.936391</td>\n",
              "      <td>0.029776</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>300</td>\n",
              "      <td>0.029300</td>\n",
              "      <td>0.026429</td>\n",
              "      <td>0.943541</td>\n",
              "      <td>0.026429</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>350</td>\n",
              "      <td>0.026200</td>\n",
              "      <td>0.025918</td>\n",
              "      <td>0.944634</td>\n",
              "      <td>0.025918</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>400</td>\n",
              "      <td>0.025700</td>\n",
              "      <td>0.024748</td>\n",
              "      <td>0.947132</td>\n",
              "      <td>0.024748</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>450</td>\n",
              "      <td>0.024800</td>\n",
              "      <td>0.024320</td>\n",
              "      <td>0.948047</td>\n",
              "      <td>0.024320</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>500</td>\n",
              "      <td>0.023200</td>\n",
              "      <td>0.021711</td>\n",
              "      <td>0.953620</td>\n",
              "      <td>0.021711</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>550</td>\n",
              "      <td>0.021700</td>\n",
              "      <td>0.021133</td>\n",
              "      <td>0.954854</td>\n",
              "      <td>0.021133</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>600</td>\n",
              "      <td>0.021000</td>\n",
              "      <td>0.020447</td>\n",
              "      <td>0.956320</td>\n",
              "      <td>0.020447</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>650</td>\n",
              "      <td>0.020100</td>\n",
              "      <td>0.019787</td>\n",
              "      <td>0.957730</td>\n",
              "      <td>0.019787</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>700</td>\n",
              "      <td>0.019700</td>\n",
              "      <td>0.019494</td>\n",
              "      <td>0.958356</td>\n",
              "      <td>0.019494</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>750</td>\n",
              "      <td>0.019500</td>\n",
              "      <td>0.019487</td>\n",
              "      <td>0.958370</td>\n",
              "      <td>0.019487</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"inputs_embeds\": torch.tensor(self.tokenized_inputs[idx]),\n",
            "<ipython-input-14-5df45e9992a5>:19: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
            "  \"labels_ids\": torch.tensor(self.labels[idx]),\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=750, training_loss=0.0332719509601593, metrics={'train_runtime': 1074.0487, 'train_samples_per_second': 44.691, 'train_steps_per_second': 0.698, 'total_flos': 0.0, 'train_loss': 0.0332719509601593, 'epoch': 150.0})"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ],
      "source": [
        "from transformers import BertForSequenceClassification, Trainer, TrainingArguments\n",
        "from sklearn.metrics import mean_squared_error\n",
        "from transformers import Trainer, TrainingArguments\n",
        "from transformers import EarlyStoppingCallback, IntervalStrategy\n",
        "from sklearn.metrics import r2_score, accuracy_score\n",
        "import numpy as np\n",
        "\n",
        "def compute_metrics1(p):\n",
        "\n",
        "    preds = p.predictions.flatten()\n",
        "    labels = p.label_ids.flatten()\n",
        "\n",
        "    r2 = r2_score(labels, preds)\n",
        "    mse = mean_squared_error(labels, preds)\n",
        "\n",
        "    return {\"r2_score\": r2, \"mse\": mse}\n",
        "\n",
        "def compute_metrics_classification(p):\n",
        "    preds = np.argmax(p.predictions , axis=1)\n",
        "    labels = p.label_ids\n",
        "\n",
        "    accuracy = accuracy_score(labels, preds)\n",
        "    return {\"accuracy_score\": accuracy}\n",
        "\n",
        "def compute_metrics(p):\n",
        "\n",
        "\n",
        "    prediction_scores, labels_ids = p\n",
        "    #print('here')\n",
        "    #print(prediction_scores)\n",
        "\n",
        "    mask = labels_ids != 100\n",
        "    #print(mask)\n",
        "    masked_predictions = prediction_scores[mask]\n",
        "    masked_labels = labels_ids[mask]\n",
        "\n",
        "    mse = mean_squared_error(masked_predictions, masked_labels)\n",
        "    return {\"mse\": mse}\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    output_dir='./results_task1',\n",
        "    num_train_epochs=150,\n",
        "    label_names=[\"labels_ids\"],\n",
        "    disable_tqdm = False,\n",
        "    #label_names=[\"labels_mask\"],\n",
        "    do_eval=True,\n",
        "    learning_rate=0.001,\n",
        "    per_device_train_batch_size=batch,\n",
        "    per_device_eval_batch_size=batch,\n",
        "    logging_dir='./logs',\n",
        "    logging_strategy=\"steps\",\n",
        "    logging_steps=50,\n",
        "    evaluation_strategy=\"steps\",\n",
        "    eval_steps = 50,\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=50,\n",
        "\n",
        "    save_total_limit=2,\n",
        "    load_best_model_at_end=True,\n",
        ")\n",
        "\n",
        "trainer = CustomTrainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    compute_metrics=compute_metrics1, #compute_metrics1,#compute_metrics_classification,\n",
        "    callbacks = [EarlyStoppingCallback(early_stopping_patience=10)]\n",
        ")\n",
        "\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "cac7d37e",
      "metadata": {
        "id": "cac7d37e"
      },
      "outputs": [],
      "source": [
        "model = reservoirtransformers.ReservoirTTimeSeries.from_pretrained(\"/content/results_task1/checkpoint-700\", config=configuration).to(\"cuda\", dtype=float)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "id": "827f3400",
      "metadata": {
        "id": "827f3400"
      },
      "outputs": [],
      "source": [
        "cnt = 0\n",
        "ln = len(X_test)\n",
        "y_pred = []\n",
        "y_test1 = []\n",
        "while cnt < ln:\n",
        "    #print(cnt, ln)\n",
        "    input_ids = torch.stack(X_test[cnt:cnt+batch], dim=0)\n",
        "    #y_test1 = y_test1 + [k.detach().numpy().flatten() for k in y_test[cnt:cnt+64]]\n",
        "\n",
        "    output = model(inputs_embeds = input_ids.to(model.device))['logits']\n",
        "    y_pred = y_pred + list(output.cpu().detach().numpy().reshape(output.size(0), -1))\n",
        "    #y_test = y_test + labels_ids\n",
        "\n",
        "    cnt=cnt+batch\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "id": "1fae404b",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1fae404b",
        "outputId": "6c61fe9e-28c4-4f7a-ff2b-f966c1467d21"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.11132165051813275"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ],
      "source": [
        "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
        "y_test1 = [i.detach().numpy().flatten() for i in y_test]\n",
        "\n",
        "mse = mean_squared_error(y_test1, y_pred)\n",
        "mse"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "AOenxRwjeFH6"
      },
      "id": "AOenxRwjeFH6",
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.4"
    },
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "217144691d984ffb96ce490a6f607f5f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_d192a4ec3c2e4d1cbb27858e876ee55d",
              "IPY_MODEL_6691c080e8254721845a7ae43815d6ac",
              "IPY_MODEL_97050576a3574370a7f81b93be77045f"
            ],
            "layout": "IPY_MODEL_1de60462c7744330ac86083d8e8be9ed"
          }
        },
        "d192a4ec3c2e4d1cbb27858e876ee55d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_b659737df91f43a3b127175b3a2e0e08",
            "placeholder": "​",
            "style": "IPY_MODEL_864eb63501ec48c49f9d37430390f27c",
            "value": "100%"
          }
        },
        "6691c080e8254721845a7ae43815d6ac": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_90c2bfe519bd4edba50bbd68cdeed1e9",
            "max": 575,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_ae57a6bf4f9e47f487e0c7824d74cb74",
            "value": 575
          }
        },
        "97050576a3574370a7f81b93be77045f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_9b519074885b494182431f4eb9b1d3f8",
            "placeholder": "​",
            "style": "IPY_MODEL_da85a3c3f0844a0fa424a4f9dadbc8a4",
            "value": " 575/575 [00:00&lt;00:00, 13674.31it/s]"
          }
        },
        "1de60462c7744330ac86083d8e8be9ed": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "b659737df91f43a3b127175b3a2e0e08": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "864eb63501ec48c49f9d37430390f27c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "90c2bfe519bd4edba50bbd68cdeed1e9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "ae57a6bf4f9e47f487e0c7824d74cb74": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "9b519074885b494182431f4eb9b1d3f8": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "da85a3c3f0844a0fa424a4f9dadbc8a4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}