{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FC2Encoder(nn.Module):\n",
    "    def __init__(self, c_in, patch_len,  d_model=128, shared_embedding=True, **kwargs):\n",
    "        super().__init__()\n",
    "        self.n_vars = c_in\n",
    "        self.patch_len = patch_len\n",
    "        self.d_model = d_model\n",
    "        self.shared_embedding = shared_embedding        \n",
    "        self.act = nn.ReLU(inplace=True)\n",
    "        # Input encoding: projection of feature vectors onto a d-dim vector space\n",
    "        if not shared_embedding: \n",
    "            self.W_P1 = nn.ModuleList()\n",
    "            self.W_P2 = nn.ModuleList()\n",
    "            for _ in range(self.n_vars): \n",
    "                self.W_P1.append(nn.Linear(patch_len, d_model))\n",
    "                self.W_P2.append(nn.Linear(d_model, patch_len))\n",
    "        else:\n",
    "            self.W_P1 = nn.Linear(patch_len, d_model)      \n",
    "            self.W_P2 = nn.Linear(d_model, patch_len)      \n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: tensor [bs x num_patch x nvars x patch_len]\n",
    "        \"\"\"\n",
    "        bs, num_patch, n_vars, patch_len = x.shape\n",
    "        # Input encoding\n",
    "        if not self.shared_embedding:\n",
    "            x_out = []\n",
    "            for i in range(n_vars):\n",
    "                z = self.W_P1[i](x[:,:,i,:])\n",
    "                z = self.act(z)\n",
    "                z = self.W_P2[i](z) # ??\n",
    "                x_out.append(z)\n",
    "            x = torch.stack(x_out, dim=2)\n",
    "        else:\n",
    "            x = self.W_P1(x)                                                      # x: [bs x num_patch x nvars x d_model]\n",
    "            x = self.act(x)\n",
    "            x = self.W_P2(x)                                                      # x: [bs x num_patch x nvars x d_model]\n",
    "        # [64, 42, 7, 128]\n",
    "        x = x.transpose(1,2)                                                     # x: [bs x nvars x num_patch x d_model]        \n",
    "        # [64, 7, 42, 128])\n",
    "        x = x.permute(0,2,1,3)\n",
    "        # [64, 7, 128, 42]\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PretrainHead(nn.Module):\n",
    "    def __init__(self, d_model, patch_len, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.linear = nn.Linear(d_model, patch_len)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: tensor [bs x nvars x d_model x num_patch]\n",
    "        output: tensor [bs x nvars x num_patch x patch_len]\n",
    "        \"\"\"\n",
    "\n",
    "        x = x.transpose(2,3)                     # [bs x nvars x num_patch x d_model]\n",
    "        x = self.linear( self.dropout(x) )      # [bs x nvars x num_patch x patch_len]\n",
    "        x = x.permute(0,2,1,3)                  # [bs x num_patch x nvars x patch_len]\n",
    "        return x\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = FC2Encoder(c_in=7,patch_len=12,d_model=64)\n",
    "head = PretrainHead(d_model=64, patch_len=12, dropout=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=8\n",
    "num_patch =42\n",
    "n_vars = 7  \n",
    "patch_len=12\n",
    "\n",
    "x = torch.randn(bs,num_patch,n_vars,patch_len)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 42, 7, 12])\n",
      "torch.Size([8, 42, 7, 12])\n"
     ]
    }
   ],
   "source": [
    "x_hat = model(x)\n",
    "print(x.shape)\n",
    "print(x_hat.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[-8.9403e-02, -3.5433e-01,  1.4283e-01,  ..., -2.2666e-01,\n",
       "            3.0635e-02, -3.7658e-01],\n",
       "          [ 3.7186e-01, -1.5576e-01, -1.1272e-04,  ...,  1.9290e-01,\n",
       "           -7.6845e-02, -2.2469e-02],\n",
       "          [ 2.5821e-01,  9.2790e-03, -1.7310e-01,  ...,  2.4596e-01,\n",
       "           -7.8483e-02,  1.2535e-01],\n",
       "          ...,\n",
       "          [ 4.1165e-01, -3.4394e-01, -6.6287e-02,  ...,  2.0222e-01,\n",
       "            5.3038e-03, -4.7927e-01],\n",
       "          [-5.2598e-01, -2.2799e-01,  4.1862e-01,  ...,  3.7808e-02,\n",
       "            3.4504e-01, -7.0980e-01],\n",
       "          [ 1.6345e-02,  2.0064e-02,  1.9610e-01,  ...,  5.3741e-01,\n",
       "            7.3447e-01, -1.8715e-01]],\n",
       "\n",
       "         [[ 2.5248e-01, -2.3375e-02, -1.3653e-01,  ..., -1.0224e-01,\n",
       "            9.7010e-02, -5.4321e-02],\n",
       "          [-2.0465e-01,  9.7795e-02, -1.3769e-01,  ...,  1.3499e-01,\n",
       "            4.2820e-01, -1.4831e-01],\n",
       "          [-1.8063e-01, -1.7593e-01,  9.6050e-02,  ...,  1.0404e-02,\n",
       "            1.7845e-01, -2.0802e-01],\n",
       "          ...,\n",
       "          [-3.0736e-01, -2.3962e-01,  3.2502e-01,  ...,  1.7215e-01,\n",
       "            3.5867e-01, -3.8662e-01],\n",
       "          [-3.7137e-01, -2.4694e-01,  3.1087e-01,  ...,  1.1619e-02,\n",
       "            2.7243e-01, -5.6619e-01],\n",
       "          [ 8.5210e-02,  9.0755e-02, -1.2650e-01,  ...,  3.2149e-01,\n",
       "            8.5327e-02, -1.1749e-01]],\n",
       "\n",
       "         [[ 1.3276e-01, -5.2976e-02, -4.5115e-03,  ...,  1.1166e-01,\n",
       "            2.9793e-02, -1.4485e-01],\n",
       "          [ 2.1621e-01, -4.1801e-01,  1.0044e-01,  ...,  2.1016e-01,\n",
       "           -6.1569e-02, -2.5898e-01],\n",
       "          [-1.5469e-01, -2.4114e-01, -6.2858e-02,  ...,  1.5206e-01,\n",
       "           -3.8228e-02, -2.8447e-01],\n",
       "          ...,\n",
       "          [-2.2740e-01, -1.5985e-01,  1.9741e-01,  ...,  1.4837e-01,\n",
       "            1.8301e-01, -3.4753e-01],\n",
       "          [ 1.8831e-01, -4.5104e-01,  8.3307e-02,  ...,  2.7889e-01,\n",
       "            1.9779e-01, -4.4104e-01],\n",
       "          [-1.3325e-01, -3.1734e-01,  1.9063e-01,  ...,  2.0091e-01,\n",
       "           -1.3009e-01, -4.0139e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5667e-01, -2.2458e-01,  1.8515e-01,  ..., -7.8965e-02,\n",
       "            5.4220e-02, -5.6887e-01],\n",
       "          [-1.1172e-02,  5.7363e-02, -1.3214e-01,  ...,  9.6724e-02,\n",
       "            1.6736e-01, -6.8752e-02],\n",
       "          [ 5.7901e-01, -3.0209e-01,  1.9962e-01,  ...,  6.2320e-02,\n",
       "            4.8759e-02, -4.1665e-01],\n",
       "          ...,\n",
       "          [-1.9690e-01, -2.0746e-01,  1.4071e-01,  ...,  1.0494e-01,\n",
       "            1.4686e-01, -4.4981e-01],\n",
       "          [ 2.5978e-01, -3.9497e-01,  2.3829e-01,  ...,  2.2958e-01,\n",
       "           -2.1703e-02, -2.7167e-01],\n",
       "          [ 5.2336e-02, -2.1554e-01, -5.0392e-02,  ...,  1.8316e-01,\n",
       "           -7.2300e-02, -1.3969e-01]],\n",
       "\n",
       "         [[ 4.6489e-01, -5.9151e-01,  5.5386e-02,  ...,  3.3980e-01,\n",
       "            1.2944e-01, -2.6029e-01],\n",
       "          [ 9.1259e-02, -4.2757e-01,  4.7471e-02,  ...,  1.6649e-01,\n",
       "            2.2965e-01, -1.7923e-01],\n",
       "          [-5.4312e-01, -2.5526e-01,  2.9484e-01,  ..., -1.2184e-01,\n",
       "            1.2624e-01, -4.4561e-01],\n",
       "          ...,\n",
       "          [ 1.6282e-03, -1.6363e-01,  3.7006e-01,  ..., -1.7061e-03,\n",
       "            2.7178e-01, -3.8396e-01],\n",
       "          [-1.4215e-01, -2.0147e-01,  1.5067e-01,  ..., -7.1187e-02,\n",
       "            2.7008e-01, -4.0522e-01],\n",
       "          [ 9.2469e-02, -2.2202e-01,  1.4102e-01,  ...,  3.1030e-01,\n",
       "            2.3090e-01, -1.2817e-01]],\n",
       "\n",
       "         [[-3.0645e-01, -1.3544e-01,  2.7533e-01,  ...,  4.0767e-02,\n",
       "            2.4171e-01, -1.8513e-01],\n",
       "          [ 6.9909e-02, -3.8334e-01,  2.2484e-01,  ...,  1.3506e-01,\n",
       "            1.3696e-01, -3.9538e-01],\n",
       "          [-2.7162e-01, -6.5567e-02,  2.8488e-01,  ..., -2.0186e-01,\n",
       "           -9.5526e-02, -2.6294e-01],\n",
       "          ...,\n",
       "          [ 3.3735e-01, -2.9220e-01,  5.9808e-02,  ..., -4.4062e-02,\n",
       "           -1.3832e-01, -3.5009e-01],\n",
       "          [ 2.7952e-01, -5.5331e-01, -9.2047e-02,  ...,  1.3483e-01,\n",
       "           -2.8216e-01, -3.0503e-01],\n",
       "          [-3.3701e-02, -6.6606e-02, -1.6031e-02,  ...,  1.4441e-01,\n",
       "            2.0764e-01, -3.0103e-01]]],\n",
       "\n",
       "\n",
       "        [[[ 3.1031e-01, -4.2169e-01, -9.9515e-02,  ...,  1.6613e-01,\n",
       "           -2.4158e-01, -2.0335e-01],\n",
       "          [-4.0328e-01, -2.2370e-01,  1.6740e-01,  ..., -4.9789e-02,\n",
       "            1.1159e-01, -3.8654e-01],\n",
       "          [-8.4383e-02, -2.0079e-02,  1.4638e-01,  ..., -1.7751e-02,\n",
       "            1.8943e-01, -5.1139e-03],\n",
       "          ...,\n",
       "          [ 1.9074e-01, -1.5206e-01,  6.0615e-02,  ...,  1.4822e-01,\n",
       "            6.3257e-02, -1.7751e-01],\n",
       "          [ 2.0096e-01, -5.8735e-01,  2.2324e-03,  ...,  1.7335e-01,\n",
       "            8.7184e-02, -2.2545e-01],\n",
       "          [-2.2418e-01,  2.9650e-02,  1.4126e-01,  ...,  1.2900e-01,\n",
       "            1.3765e-02, -2.7511e-01]],\n",
       "\n",
       "         [[ 1.6262e-01, -1.6960e-01,  2.4113e-01,  ...,  6.9796e-02,\n",
       "            4.9710e-01, -2.6327e-01],\n",
       "          [ 3.1163e-01, -2.1390e-01,  9.3061e-02,  ..., -2.5701e-01,\n",
       "           -1.8800e-01, -6.7944e-01],\n",
       "          [ 1.3094e-01, -1.7834e-01,  2.9224e-02,  ...,  3.4413e-01,\n",
       "           -6.1257e-03, -1.5247e-01],\n",
       "          ...,\n",
       "          [ 4.2749e-03, -2.3387e-01,  3.2417e-03,  ..., -9.8155e-02,\n",
       "            5.9106e-02, -3.0657e-01],\n",
       "          [-9.5016e-02,  2.8307e-02,  1.8044e-01,  ...,  2.4020e-01,\n",
       "            2.1055e-01, -1.9322e-01],\n",
       "          [-5.1084e-01, -2.4540e-01,  1.0357e-01,  ..., -4.6259e-02,\n",
       "            1.2909e-01, -3.6974e-01]],\n",
       "\n",
       "         [[-2.9371e-03, -1.4102e-01,  2.3721e-01,  ...,  2.3661e-02,\n",
       "            5.5452e-02, -1.5020e-01],\n",
       "          [ 8.1383e-02, -2.2331e-01,  1.0852e-01,  ...,  3.2749e-01,\n",
       "            1.4611e-02,  8.0208e-02],\n",
       "          [-4.2331e-01, -2.6790e-01,  2.7601e-01,  ..., -1.6343e-01,\n",
       "            2.5084e-01, -7.3210e-01],\n",
       "          ...,\n",
       "          [-5.3952e-01,  2.5421e-01, -3.3563e-02,  ..., -1.8242e-01,\n",
       "            5.3937e-01, -3.1347e-01],\n",
       "          [-1.9461e-01,  1.9607e-02,  8.4832e-02,  ...,  9.5553e-02,\n",
       "            1.2653e-01, -1.8944e-01],\n",
       "          [ 5.1225e-02, -4.7103e-01,  1.7638e-01,  ...,  1.8473e-01,\n",
       "           -5.1577e-03, -2.5818e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 1.8208e-01, -2.2781e-01,  2.8086e-01,  ...,  1.8755e-01,\n",
       "            1.4383e-01, -2.1261e-01],\n",
       "          [ 1.0665e-01, -3.8374e-01,  1.5743e-01,  ..., -4.8052e-02,\n",
       "            1.5648e-01, -1.4524e-01],\n",
       "          [-4.5316e-01, -3.2344e-01,  1.6220e-01,  ..., -3.8766e-02,\n",
       "            1.1738e-01, -6.5967e-01],\n",
       "          ...,\n",
       "          [-2.2019e-01, -1.5886e-01,  2.5567e-01,  ..., -1.9022e-02,\n",
       "            1.9480e-02, -5.8032e-01],\n",
       "          [ 2.4929e-01,  5.4586e-02, -2.5587e-01,  ..., -4.8621e-03,\n",
       "           -2.4426e-01,  7.0167e-02],\n",
       "          [ 5.0042e-01, -2.5598e-01, -2.4811e-02,  ...,  1.9934e-01,\n",
       "            2.3626e-01, -1.2735e-01]],\n",
       "\n",
       "         [[-9.9028e-02, -2.3922e-01,  2.3031e-01,  ...,  1.6588e-01,\n",
       "            5.2585e-01, -1.4879e-01],\n",
       "          [-7.2392e-01,  1.9377e-03,  5.0340e-01,  ...,  9.7689e-04,\n",
       "            2.4785e-01, -9.5579e-01],\n",
       "          [-1.8752e-01, -1.4306e-01,  8.9758e-02,  ...,  5.1803e-02,\n",
       "           -2.3336e-03, -2.9452e-01],\n",
       "          ...,\n",
       "          [-1.5772e-01,  1.5636e-02, -1.5192e-01,  ...,  1.7358e-01,\n",
       "            4.4287e-01, -2.5008e-01],\n",
       "          [-3.3574e-01, -3.6910e-01,  1.9237e-01,  ...,  3.9468e-03,\n",
       "            2.7638e-01, -4.6183e-01],\n",
       "          [ 4.9640e-01, -4.0076e-01,  1.6892e-01,  ...,  4.2551e-01,\n",
       "            2.7451e-02, -2.5648e-01]],\n",
       "\n",
       "         [[ 1.2571e-01, -3.7431e-02, -6.8012e-02,  ...,  1.4248e-01,\n",
       "           -1.0832e-01,  3.9125e-02],\n",
       "          [ 1.0689e-01, -2.5974e-01,  1.0615e-01,  ..., -6.4870e-03,\n",
       "            9.3929e-03, -3.6153e-01],\n",
       "          [-2.4058e-01,  1.6408e-01, -5.3426e-02,  ..., -3.8627e-01,\n",
       "            4.8652e-02, -3.5219e-01],\n",
       "          ...,\n",
       "          [ 5.1862e-01, -2.8926e-01, -9.1261e-02,  ..., -7.3495e-03,\n",
       "           -3.0833e-01, -3.8697e-01],\n",
       "          [-2.4440e-01, -5.7323e-02,  2.0769e-01,  ...,  1.1791e-01,\n",
       "            1.5722e-01, -1.6100e-01],\n",
       "          [ 2.9860e-01, -4.6335e-01, -1.0500e-01,  ...,  3.9517e-01,\n",
       "           -6.2479e-02, -1.5349e-01]]],\n",
       "\n",
       "\n",
       "        [[[-1.1878e-01, -1.1876e-01,  1.2248e-01,  ...,  5.1062e-02,\n",
       "            1.3422e-02, -6.6154e-01],\n",
       "          [-3.0072e-01, -2.1284e-01,  4.1849e-02,  ..., -6.3308e-02,\n",
       "            3.8692e-02, -3.6043e-01],\n",
       "          [-2.7009e-01, -5.8381e-01,  3.9408e-01,  ...,  6.0796e-02,\n",
       "            3.9342e-01, -5.6193e-01],\n",
       "          ...,\n",
       "          [-1.9798e-01,  1.8148e-02,  4.6478e-02,  ...,  1.4275e-01,\n",
       "            4.1575e-02, -1.4965e-01],\n",
       "          [-2.7918e-01, -2.2545e-01,  1.5789e-01,  ..., -2.7596e-01,\n",
       "           -1.0390e-02, -4.8947e-01],\n",
       "          [-2.5011e-01, -2.7812e-01,  2.2182e-01,  ...,  1.0706e-01,\n",
       "            1.5091e-01, -5.2015e-01]],\n",
       "\n",
       "         [[-2.8018e-01, -1.4497e-01,  1.7410e-01,  ..., -6.6085e-02,\n",
       "            1.2464e-01, -3.7399e-01],\n",
       "          [ 2.1456e-02, -1.7604e-01,  8.5156e-02,  ...,  9.9052e-02,\n",
       "           -2.3320e-02, -3.1358e-01],\n",
       "          [-3.1760e-02, -1.4795e-01,  5.2993e-02,  ..., -3.4973e-01,\n",
       "            2.6997e-01, -1.6840e-01],\n",
       "          ...,\n",
       "          [-2.4240e-01,  6.8158e-02,  1.1867e-01,  ...,  5.1530e-04,\n",
       "            1.9275e-02, -4.5134e-01],\n",
       "          [ 1.1354e-01, -9.3986e-02, -1.2860e-01,  ..., -2.0393e-02,\n",
       "            1.3368e-02, -2.3637e-01],\n",
       "          [-1.6979e-01, -6.6096e-02,  2.4409e-01,  ...,  2.5313e-01,\n",
       "            3.5438e-01, -4.1879e-01]],\n",
       "\n",
       "         [[ 9.8551e-02, -3.3630e-01,  4.2280e-02,  ...,  5.3071e-03,\n",
       "            7.6842e-02, -2.4965e-01],\n",
       "          [-3.8365e-01, -4.2979e-01,  2.4944e-01,  ..., -6.8884e-02,\n",
       "            3.7846e-01, -5.8473e-01],\n",
       "          [-4.4990e-01, -3.3646e-01,  2.1499e-01,  ...,  3.7781e-01,\n",
       "            2.8057e-01, -5.1865e-01],\n",
       "          ...,\n",
       "          [ 5.0361e-02,  6.9150e-02, -1.3564e-01,  ...,  1.1821e-01,\n",
       "           -8.9925e-02, -2.9316e-03],\n",
       "          [ 1.2687e-01, -4.5032e-01,  1.3220e-01,  ...,  9.4394e-02,\n",
       "            1.5475e-01, -2.2406e-01],\n",
       "          [-3.2047e-01, -3.4949e-01,  1.6642e-01,  ..., -1.6282e-01,\n",
       "            2.3358e-01, -5.4362e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.5034e-02, -2.4495e-01, -2.2647e-02,  ...,  3.5685e-01,\n",
       "           -1.3195e-01, -1.0915e-01],\n",
       "          [-2.6098e-01, -2.4859e-01,  8.6908e-02,  ..., -3.4467e-02,\n",
       "            2.7817e-01, -3.3705e-01],\n",
       "          [-8.4617e-02, -1.5070e-01,  1.2933e-01,  ...,  2.1612e-01,\n",
       "            3.4238e-02, -1.3999e-01],\n",
       "          ...,\n",
       "          [ 2.7438e-01, -4.8637e-01,  9.1514e-02,  ...,  3.3551e-01,\n",
       "            7.3375e-02, -2.6785e-01],\n",
       "          [-2.4717e-01, -3.8241e-01,  2.1035e-01,  ...,  1.8048e-01,\n",
       "            1.8331e-01, -4.9448e-01],\n",
       "          [-2.1069e-01, -1.4401e-01,  2.8427e-01,  ..., -4.9342e-02,\n",
       "            1.0414e-01, -5.3848e-01]],\n",
       "\n",
       "         [[-1.4477e-01, -2.8325e-01,  4.7989e-02,  ...,  4.2794e-02,\n",
       "            4.7066e-02, -3.5370e-01],\n",
       "          [ 2.9244e-01, -5.6705e-01, -4.2419e-02,  ...,  5.5824e-01,\n",
       "            2.7560e-01, -1.5307e-01],\n",
       "          [-2.0052e-01, -4.8609e-01,  1.7641e-01,  ...,  2.4216e-01,\n",
       "            2.6205e-02, -3.6005e-01],\n",
       "          ...,\n",
       "          [-2.3508e-01, -3.3126e-01,  1.9385e-01,  ..., -4.3818e-02,\n",
       "           -1.3189e-02, -7.1941e-01],\n",
       "          [ 5.9029e-02, -2.7786e-02, -6.4542e-02,  ...,  1.1411e-01,\n",
       "            9.7010e-02,  4.5768e-03],\n",
       "          [-2.3739e-01, -7.5663e-02,  1.1077e-01,  ...,  1.6504e-01,\n",
       "            1.7766e-01, -5.6546e-02]],\n",
       "\n",
       "         [[ 1.1200e-01, -1.3254e-01, -3.0316e-02,  ...,  2.8159e-01,\n",
       "           -1.4999e-02,  1.1484e-02],\n",
       "          [ 7.1562e-01, -4.0629e-01, -7.4625e-02,  ...,  2.5144e-01,\n",
       "           -1.6009e-01, -1.3318e-01],\n",
       "          [ 9.6309e-03, -9.8310e-02,  2.0824e-01,  ...,  3.3506e-02,\n",
       "            2.4297e-01, -1.5338e-01],\n",
       "          ...,\n",
       "          [-3.0453e-01,  1.7329e-01,  6.3234e-02,  ...,  2.3181e-02,\n",
       "            7.1577e-02, -1.7428e-01],\n",
       "          [-4.5862e-01, -1.8171e-02,  1.3836e-01,  ...,  7.1593e-02,\n",
       "            1.2114e-01, -2.5528e-01],\n",
       "          [ 1.9562e-02, -3.7647e-01,  2.0142e-01,  ...,  1.6961e-01,\n",
       "            9.9618e-02, -3.1104e-01]]],\n",
       "\n",
       "\n",
       "        ...,\n",
       "\n",
       "\n",
       "        [[[ 1.1000e-01, -1.3887e-01,  1.8193e-01,  ..., -1.5578e-01,\n",
       "            1.9797e-01, -2.2461e-01],\n",
       "          [ 1.9331e-02, -1.3039e-02, -2.2461e-02,  ...,  1.0785e-02,\n",
       "           -5.3731e-02, -3.5540e-01],\n",
       "          [ 1.9442e-01, -1.2250e-01, -8.0762e-02,  ..., -1.0269e-02,\n",
       "           -7.7678e-02, -1.0614e-01],\n",
       "          ...,\n",
       "          [ 6.8548e-04, -6.2502e-02, -5.4517e-02,  ...,  2.5474e-01,\n",
       "            3.0998e-01, -2.5216e-01],\n",
       "          [-6.1673e-02, -8.4825e-02,  7.9229e-02,  ..., -2.6679e-01,\n",
       "            2.5147e-01, -3.3604e-01],\n",
       "          [ 2.9420e-01, -2.6170e-02, -2.2630e-01,  ...,  3.4798e-01,\n",
       "            1.8987e-01, -1.4923e-01]],\n",
       "\n",
       "         [[-2.4474e-01, -1.3679e-01,  2.6613e-01,  ..., -2.7980e-01,\n",
       "            2.3270e-01, -6.1247e-01],\n",
       "          [-6.4875e-01, -4.0271e-02,  3.8571e-01,  ..., -3.5875e-02,\n",
       "           -2.4562e-02, -4.4364e-01],\n",
       "          [-5.8365e-02, -6.4851e-02,  1.9247e-01,  ...,  9.5064e-02,\n",
       "            6.0830e-02, -4.6128e-01],\n",
       "          ...,\n",
       "          [ 1.4693e-01, -3.0425e-01,  1.9674e-01,  ...,  6.1006e-01,\n",
       "            2.1898e-01, -2.9977e-01],\n",
       "          [ 3.7531e-02, -2.3951e-01,  1.9009e-01,  ...,  2.3027e-01,\n",
       "           -2.8928e-02, -1.7565e-01],\n",
       "          [ 2.4334e-01, -2.7141e-01, -1.0846e-02,  ...,  2.0989e-01,\n",
       "           -1.5814e-01, -2.3820e-01]],\n",
       "\n",
       "         [[-2.4774e-01, -7.3479e-02,  1.6108e-01,  ..., -2.3513e-01,\n",
       "           -4.9823e-03, -5.3451e-01],\n",
       "          [-1.5908e-01,  1.0700e-01, -1.5501e-01,  ...,  3.1946e-02,\n",
       "           -1.2981e-01, -1.8137e-01],\n",
       "          [-4.1694e-01,  1.0881e-01, -1.2186e-01,  ...,  2.0156e-01,\n",
       "            1.1056e-01,  7.3046e-02],\n",
       "          ...,\n",
       "          [ 2.1502e-01, -3.1488e-01,  2.2970e-01,  ...,  1.8117e-01,\n",
       "            4.8855e-02, -2.0825e-01],\n",
       "          [ 3.5182e-02, -1.0303e-01, -3.8730e-02,  ...,  1.8812e-01,\n",
       "            1.1054e-02, -2.8664e-01],\n",
       "          [-4.3875e-02, -6.1327e-02,  5.7363e-03,  ...,  1.1783e-01,\n",
       "            9.1426e-02, -2.2747e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-9.6492e-02, -2.4240e-01,  2.1452e-01,  ..., -1.8793e-01,\n",
       "            2.0617e-03, -5.5681e-01],\n",
       "          [-3.1441e-02, -4.0893e-01,  3.9155e-01,  ...,  1.4485e-01,\n",
       "            3.0177e-01, -4.1881e-01],\n",
       "          [-1.1157e-01,  5.7870e-02, -5.5412e-02,  ...,  4.9435e-03,\n",
       "            3.1537e-01, -1.7016e-01],\n",
       "          ...,\n",
       "          [ 2.0077e-01, -1.3087e-01,  1.2083e-01,  ..., -7.9385e-02,\n",
       "            1.7939e-02, -1.2398e-01],\n",
       "          [-1.5954e-01, -2.9012e-01,  1.5970e-01,  ..., -2.8601e-02,\n",
       "            2.3542e-02, -3.4834e-01],\n",
       "          [ 1.4506e-01, -2.6039e-01, -6.8290e-02,  ...,  5.7947e-02,\n",
       "           -1.0257e-02, -2.3320e-01]],\n",
       "\n",
       "         [[-2.5303e-01,  2.9886e-03,  5.0322e-02,  ...,  1.8034e-01,\n",
       "            8.3342e-02, -2.1469e-01],\n",
       "          [ 1.4369e-01, -1.9285e-01,  1.0694e-01,  ...,  9.0475e-02,\n",
       "            2.3242e-02, -1.4495e-01],\n",
       "          [-2.7940e-01, -2.2342e-01,  1.0188e-01,  ...,  8.6000e-03,\n",
       "            4.0350e-01, -1.8810e-01],\n",
       "          ...,\n",
       "          [-3.8778e-01, -2.2073e-01,  4.0704e-01,  ...,  7.9300e-02,\n",
       "            3.3998e-01, -5.3174e-01],\n",
       "          [-2.7106e-01, -2.1997e-01,  1.0765e-01,  ...,  1.7647e-01,\n",
       "            3.0805e-01, -5.4890e-01],\n",
       "          [-2.2935e-01, -1.0717e-01,  1.3206e-01,  ..., -1.1847e-01,\n",
       "            3.4826e-03, -5.0792e-01]],\n",
       "\n",
       "         [[-1.9652e-01, -1.8408e-01,  1.1271e-01,  ...,  1.2052e-01,\n",
       "           -4.7234e-02, -4.8810e-01],\n",
       "          [ 6.3203e-02, -4.2883e-01,  2.9625e-02,  ...,  1.4415e-01,\n",
       "           -1.8366e-02, -2.7604e-01],\n",
       "          [ 7.2458e-02, -1.7369e-01, -1.7014e-01,  ..., -1.7414e-01,\n",
       "           -6.8152e-02, -5.1533e-01],\n",
       "          ...,\n",
       "          [-1.3628e-01, -6.4673e-03,  1.8293e-01,  ...,  1.9630e-03,\n",
       "            9.4225e-02, -3.9900e-01],\n",
       "          [-2.4204e-01, -2.1550e-02,  2.0455e-01,  ..., -1.4410e-01,\n",
       "            5.9946e-02, -2.6284e-01],\n",
       "          [-1.2495e-01, -2.2957e-01,  6.4673e-02,  ..., -1.7613e-02,\n",
       "            9.5474e-02, -1.1746e-01]]],\n",
       "\n",
       "\n",
       "        [[[-6.0992e-01, -2.2933e-01,  2.2621e-01,  ..., -5.4525e-02,\n",
       "            2.2725e-01, -6.0927e-01],\n",
       "          [-1.1847e-01, -2.3678e-01,  4.4462e-02,  ..., -9.7998e-02,\n",
       "           -1.4132e-01, -3.2503e-01],\n",
       "          [ 7.0702e-01, -3.1005e-01, -2.2138e-01,  ...,  3.8034e-01,\n",
       "           -7.9105e-02, -1.5966e-02],\n",
       "          ...,\n",
       "          [ 2.0906e-01, -3.8682e-01,  1.0908e-01,  ...,  1.1938e-01,\n",
       "            4.8154e-02, -4.5999e-01],\n",
       "          [-2.3244e-01, -7.9752e-02, -1.2786e-01,  ...,  4.8567e-02,\n",
       "            3.7135e-01, -8.2101e-02],\n",
       "          [-4.6176e-01,  1.3491e-01,  1.2323e-01,  ..., -1.1110e-01,\n",
       "            3.8294e-02, -6.4207e-01]],\n",
       "\n",
       "         [[ 2.0531e-01, -1.4183e-01,  2.0450e-01,  ..., -3.1012e-02,\n",
       "            8.9918e-02, -2.6481e-01],\n",
       "          [-3.6817e-01, -6.4683e-02,  1.8882e-01,  ...,  1.2663e-01,\n",
       "            5.3977e-02, -3.9028e-01],\n",
       "          [-1.3864e-01, -1.0889e-01,  1.5049e-01,  ...,  3.1534e-02,\n",
       "           -4.0780e-02, -2.3963e-01],\n",
       "          ...,\n",
       "          [-1.0887e-01,  2.1364e-02, -1.3451e-02,  ..., -1.3971e-02,\n",
       "           -5.1014e-02, -2.2664e-01],\n",
       "          [-4.0686e-01, -1.3372e-01,  3.4439e-01,  ...,  1.6152e-01,\n",
       "            1.5443e-01, -3.7471e-01],\n",
       "          [-1.5258e-01, -1.0346e-01,  4.6261e-01,  ..., -1.7810e-01,\n",
       "            1.1142e-01, -4.1518e-01]],\n",
       "\n",
       "         [[ 1.4786e-02,  6.2323e-02,  1.3192e-01,  ..., -3.3450e-01,\n",
       "           -8.4372e-03, -9.7437e-01],\n",
       "          [ 6.1941e-01, -2.5463e-01,  6.0846e-02,  ...,  2.4845e-01,\n",
       "           -1.9653e-01, -2.4436e-01],\n",
       "          [ 6.0945e-01, -5.9619e-01, -3.9144e-01,  ..., -9.1081e-02,\n",
       "            1.9182e-01, -2.9338e-01],\n",
       "          ...,\n",
       "          [-5.6323e-01, -1.8227e-02,  1.3183e-01,  ..., -1.0790e-01,\n",
       "            2.8655e-01, -8.3087e-01],\n",
       "          [-5.1140e-02, -1.2594e-01,  1.5847e-01,  ...,  5.2916e-02,\n",
       "            1.3275e-01, -3.2561e-01],\n",
       "          [-3.6427e-01, -1.1682e-01,  2.9118e-01,  ..., -7.7537e-02,\n",
       "            2.5990e-01, -4.2688e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[-1.3412e-01, -1.0141e-01,  5.5271e-02,  ..., -1.5010e-01,\n",
       "            2.9364e-01, -1.3394e-01],\n",
       "          [-1.3579e-01, -1.3157e-01,  2.0900e-01,  ...,  5.9234e-02,\n",
       "            7.8313e-02, -1.7037e-01],\n",
       "          [-1.0554e-01, -3.1916e-01,  2.3500e-01,  ..., -2.4292e-02,\n",
       "            1.8174e-01, -2.6190e-01],\n",
       "          ...,\n",
       "          [-5.8722e-01, -2.0912e-01,  4.0938e-02,  ...,  1.7767e-01,\n",
       "            8.7382e-02, -4.1229e-01],\n",
       "          [ 1.7826e-01, -3.6874e-01, -5.1746e-04,  ..., -1.3254e-01,\n",
       "           -2.6460e-02, -3.4770e-01],\n",
       "          [ 3.1021e-01, -2.5542e-01,  1.1738e-01,  ...,  2.3144e-01,\n",
       "           -2.8994e-02, -3.1944e-01]],\n",
       "\n",
       "         [[-1.1804e-01,  1.3441e-01, -6.5561e-02,  ...,  7.9366e-02,\n",
       "            1.1020e-01, -1.8083e-01],\n",
       "          [ 1.5017e-01, -1.6917e-01, -1.4039e-01,  ...,  2.8809e-02,\n",
       "            1.8464e-01, -2.1490e-01],\n",
       "          [-6.7731e-02, -4.3276e-01,  3.0015e-02,  ...,  1.3243e-01,\n",
       "            3.4302e-02, -4.6283e-01],\n",
       "          ...,\n",
       "          [ 9.3522e-02, -2.5829e-01, -5.2836e-02,  ...,  9.5823e-02,\n",
       "           -4.6416e-02, -1.7507e-01],\n",
       "          [-1.3411e-01, -3.0066e-01,  3.7297e-01,  ...,  2.6083e-01,\n",
       "            3.0747e-01, -2.7143e-01],\n",
       "          [-9.3217e-02, -1.8691e-01,  2.6598e-02,  ...,  2.0081e-01,\n",
       "            3.6825e-01, -3.9163e-01]],\n",
       "\n",
       "         [[-3.6479e-02, -1.0955e-02, -1.4612e-02,  ..., -2.7851e-01,\n",
       "           -8.9817e-02, -3.9386e-01],\n",
       "          [ 4.1189e-02, -3.3799e-01,  2.1660e-01,  ...,  2.7624e-01,\n",
       "            2.6213e-01, -3.2562e-01],\n",
       "          [ 2.5518e-02, -4.0150e-01,  4.9547e-01,  ...,  2.2614e-01,\n",
       "            2.3152e-01, -7.1973e-01],\n",
       "          ...,\n",
       "          [ 2.7557e-02, -4.9770e-01,  2.1764e-01,  ...,  1.9633e-01,\n",
       "            8.2254e-02, -3.0023e-01],\n",
       "          [ 8.3433e-02, -5.0831e-01, -9.4343e-02,  ...,  1.0361e-01,\n",
       "            2.1635e-01, -3.5144e-01],\n",
       "          [-2.9960e-01, -8.7830e-02,  4.3048e-01,  ..., -9.9437e-02,\n",
       "            1.9518e-02, -5.0016e-01]]],\n",
       "\n",
       "\n",
       "        [[[ 3.7614e-02, -1.0308e-01, -1.0815e-01,  ...,  2.1865e-01,\n",
       "            2.5904e-01, -3.7440e-01],\n",
       "          [ 2.2685e-01, -3.1859e-01,  9.7374e-02,  ...,  7.2647e-03,\n",
       "           -2.1581e-01, -5.8819e-01],\n",
       "          [ 2.3630e-01, -1.1074e-01,  1.0542e-01,  ...,  2.0853e-02,\n",
       "            2.6675e-02, -3.3876e-01],\n",
       "          ...,\n",
       "          [-3.4630e-01,  5.7766e-02,  3.1961e-01,  ..., -6.9655e-03,\n",
       "            1.4667e-01, -4.9123e-01],\n",
       "          [ 2.0190e-01, -2.0498e-01,  5.3031e-02,  ..., -4.7309e-02,\n",
       "            2.0199e-01, -7.4395e-02],\n",
       "          [-4.1928e-01, -2.7520e-01,  2.7043e-01,  ...,  4.3005e-02,\n",
       "            6.5940e-02, -5.4762e-01]],\n",
       "\n",
       "         [[ 2.8005e-02, -1.3932e-01,  1.1770e-01,  ...,  4.7516e-03,\n",
       "            2.0408e-01, -4.6684e-01],\n",
       "          [ 1.5496e-01, -2.2350e-01,  1.0219e-01,  ...,  1.8232e-01,\n",
       "           -8.4636e-03, -8.1784e-02],\n",
       "          [-1.4305e-01, -9.1057e-02, -1.7183e-01,  ...,  5.2050e-02,\n",
       "           -9.8260e-02, -1.5492e-01],\n",
       "          ...,\n",
       "          [-2.8584e-02, -3.3198e-01,  1.8567e-01,  ..., -5.4179e-03,\n",
       "           -1.1556e-01, -2.7175e-01],\n",
       "          [-9.0771e-02, -3.8853e-01,  2.1934e-02,  ...,  1.8532e-01,\n",
       "           -8.4876e-02, -2.4394e-01],\n",
       "          [ 8.6759e-02, -3.4776e-01,  2.0987e-01,  ...,  1.3750e-01,\n",
       "           -1.9692e-02, -3.7351e-01]],\n",
       "\n",
       "         [[-1.2731e-02, -1.3620e-01,  1.0677e-01,  ...,  8.9432e-02,\n",
       "           -6.3019e-02, -2.0007e-01],\n",
       "          [ 1.4012e-01, -3.4520e-01,  2.7179e-03,  ..., -8.4436e-02,\n",
       "           -2.2161e-01, -4.5608e-01],\n",
       "          [-9.1472e-02, -1.3622e-01,  2.3464e-01,  ...,  4.9327e-02,\n",
       "            3.8040e-02, -1.0272e-01],\n",
       "          ...,\n",
       "          [-1.9454e-01, -4.2005e-01,  2.2309e-01,  ...,  1.0346e-02,\n",
       "            1.6404e-01, -2.5865e-01],\n",
       "          [-2.1897e-01, -1.3645e-01,  3.4683e-01,  ..., -2.0970e-02,\n",
       "            2.2262e-01, -4.5473e-01],\n",
       "          [-2.1534e-01, -2.5452e-01,  2.8031e-01,  ..., -1.9222e-02,\n",
       "            9.4498e-02, -7.1534e-01]],\n",
       "\n",
       "         ...,\n",
       "\n",
       "         [[ 6.4578e-04, -3.7023e-01,  1.2392e-01,  ...,  2.2340e-01,\n",
       "            1.4455e-01, -2.4977e-01],\n",
       "          [-4.2229e-03,  9.0808e-02,  8.5946e-03,  ..., -8.4072e-03,\n",
       "           -4.2960e-02, -1.1946e-01],\n",
       "          [ 7.0649e-02, -2.9363e-01,  5.3666e-02,  ...,  3.3862e-01,\n",
       "            1.7118e-01, -5.1617e-01],\n",
       "          ...,\n",
       "          [ 7.9416e-02, -4.4547e-02, -6.6971e-02,  ...,  2.5412e-01,\n",
       "           -2.0677e-01, -1.8849e-02],\n",
       "          [-6.8677e-02, -1.9253e-01, -4.2236e-02,  ...,  2.3251e-01,\n",
       "           -1.1351e-01, -8.0993e-02],\n",
       "          [-2.8029e-01, -1.4126e-01, -8.9227e-03,  ..., -2.2650e-02,\n",
       "           -6.5792e-03, -3.3545e-01]],\n",
       "\n",
       "         [[ 3.5075e-02, -1.2339e-01,  1.5296e-01,  ...,  6.0709e-02,\n",
       "            1.7421e-01, -5.3647e-01],\n",
       "          [ 1.7555e-01, -1.6051e-01,  9.8376e-02,  ...,  1.6234e-01,\n",
       "           -3.1849e-02, -3.1021e-01],\n",
       "          [ 2.3196e-01, -2.5512e-01,  2.5374e-01,  ...,  2.8462e-01,\n",
       "            2.2334e-01, -3.2031e-01],\n",
       "          ...,\n",
       "          [-1.1698e-01,  7.7261e-02,  2.9130e-02,  ...,  1.4626e-01,\n",
       "            1.1713e-01, -1.6120e-01],\n",
       "          [ 4.0070e-01, -4.5602e-01,  2.8910e-01,  ...,  3.9310e-01,\n",
       "            2.7933e-01, -3.5828e-01],\n",
       "          [-3.4040e-01, -3.0837e-01,  2.0349e-01,  ...,  7.5498e-02,\n",
       "            2.3399e-02, -3.4294e-01]],\n",
       "\n",
       "         [[-3.6694e-01, -1.9047e-01,  3.8358e-01,  ..., -5.2632e-02,\n",
       "            9.8086e-02, -5.6276e-01],\n",
       "          [ 4.2076e-01, -3.3349e-01, -7.1157e-02,  ...,  6.9784e-02,\n",
       "           -6.4288e-02, -4.2920e-01],\n",
       "          [-2.0522e-01,  5.4094e-03, -4.8855e-03,  ...,  2.8053e-01,\n",
       "            1.6901e-01, -3.5432e-01],\n",
       "          ...,\n",
       "          [ 3.8053e-02,  1.5825e-01,  7.0238e-02,  ...,  2.7451e-01,\n",
       "            1.5555e-01, -3.3223e-01],\n",
       "          [-2.0286e-02, -2.2492e-01,  3.4508e-01,  ...,  3.6792e-01,\n",
       "            2.7610e-01, -4.9960e-01],\n",
       "          [-2.8043e-01, -2.3797e-01,  1.2524e-01,  ..., -3.2507e-01,\n",
       "           -6.5065e-02, -5.2208e-01]]]], grad_fn=<PermuteBackward0>)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect x_hat as left by earlier cells (stored output is a 4-D tensor; NOTE(review): confirm which model produced it)\n",
    "x_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (100x10 and 64x64)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[50], line 39\u001b[0m\n\u001b[1;32m     36\u001b[0m sample_input \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(\u001b[39m100\u001b[39m, input_dim)\n\u001b[1;32m     38\u001b[0m \u001b[39m# Use the autoencoder to reconstruct input data (no training required)\u001b[39;00m\n\u001b[0;32m---> 39\u001b[0m reconstructed_data \u001b[39m=\u001b[39m autoencoder(sample_input)\n",
      "File \u001b[0;32m~/miniconda3/envs/ssl_ts/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "Cell \u001b[0;32mIn[50], line 25\u001b[0m, in \u001b[0;36mAutoencoder.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[1;32m     24\u001b[0m     encoded \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mencoder(x)\n\u001b[0;32m---> 25\u001b[0m     decoded \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdecoder(encoded)\n\u001b[1;32m     26\u001b[0m     \u001b[39mreturn\u001b[39;00m decoded\n",
      "File \u001b[0;32m~/miniconda3/envs/ssl_ts/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/ssl_ts/lib/python3.10/site-packages/torch/nn/modules/container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    215\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m    216\u001b[0m     \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 217\u001b[0m         \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m    218\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
      "File \u001b[0;32m~/miniconda3/envs/ssl_ts/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/ssl_ts/lib/python3.10/site-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mlinear(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (100x10 and 64x64)"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Define the autoencoder class\n",
    "class Autoencoder(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim):\n",
    "        super(Autoencoder, self).__init__()\n",
    "\n",
    "        # Encoder layers with weights initialized as the identity matrix\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim, bias=False),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.encoder[0].weight.data = torch.eye(input_dim)\n",
    "\n",
    "        # Decoder layers with weights initialized as the identity matrix\n",
    "        self.decoder = nn.Sequential(\n",
    "            nn.Linear(hidden_dim, input_dim, bias=False),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.decoder[0].weight.data = torch.eye(hidden_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded\n",
    "\n",
    "# Set the input dimension and hidden dimension\n",
    "input_dim = 10\n",
    "hidden_dim = 64\n",
    "\n",
    "# Create an instance of the autoencoder\n",
    "autoencoder = Autoencoder(input_dim, hidden_dim)\n",
    "\n",
    "# Create sample input data (you should replace this with your actual data)\n",
    "sample_input = torch.randn(100, input_dim)\n",
    "\n",
    "# Use the autoencoder to reconstruct input data (no training required)\n",
    "reconstructed_data = autoencoder(sample_input)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class Autoencoder(nn.Module):\n",
    "    \"\"\"Untrained autoencoder that reconstructs its input exactly.\n",
    "\n",
    "    Relies on x == ReLU(x) - ReLU(-x): the encoder maps x to [-x, x, 0, ..., 0]\n",
    "    (hidden_dim units), the ReLU keeps the negative and positive parts apart,\n",
    "    and the decoder recombines them. Requires hidden_dim >= 2 * input_dim.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim):\n",
    "        super(Autoencoder, self).__init__()\n",
    "        n_pad = hidden_dim - 2 * input_dim\n",
    "        if n_pad < 0:\n",
    "            raise ValueError(\n",
    "                f\"hidden_dim ({hidden_dim}) must be >= 2 * input_dim ({2 * input_dim})\"\n",
    "            )\n",
    "        # w = [-I | I | 0], shape (input_dim, hidden_dim)\n",
    "        w = torch.hstack([-torch.eye(input_dim),\n",
    "                          torch.eye(input_dim),\n",
    "                          torch.zeros((input_dim, n_pad))])\n",
    "\n",
    "        # Encoder weight is w.T, shape (hidden_dim, input_dim): z = [-x, x, 0]\n",
    "        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)\n",
    "        # Decoder weight is w: x_hat = -ReLU(-x) + ReLU(x) = x\n",
    "        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)\n",
    "        with torch.no_grad():\n",
    "            self.encoder.weight.copy_(w.T)\n",
    "            self.decoder.weight.copy_(w)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = torch.relu(self.encoder(x))  # [ReLU(-x), ReLU(x), 0]\n",
    "        x_hat = self.decoder(z)\n",
    "        return x_hat\n",
    "\n",
    "input_dim = 6\n",
    "hidden_dim = 64\n",
    "\n",
    "torch.manual_seed(0)  # make the sample input below reproducible\n",
    "\n",
    "autoencoder = Autoencoder(input_dim, hidden_dim)\n",
    "\n",
    "x = torch.randn((3, input_dim))\n",
    "x_hat = autoencoder(x)\n",
    "assert torch.allclose(x_hat, x), \"autoencoder should reproduce its input\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1844, -0.0475,  0.4164,  1.0868,  0.0952,  0.0557],\n",
       "        [ 0.7522,  0.9648, -0.9937,  1.4687, -0.7575, -0.1436],\n",
       "        [ 0.8576,  0.7479,  0.1420, -1.5696,  0.4289,  0.2307]])"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample input that was fed to the +/- identity autoencoder defined above\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1844, -0.0475,  0.4164,  1.0868,  0.0952,  0.0557],\n",
       "        [ 0.7522,  0.9648, -0.9937,  1.4687, -0.7575, -0.1436],\n",
       "        [ 0.8576,  0.7479,  0.1420, -1.5696,  0.4289,  0.2307]],\n",
       "       grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reconstruction of x; the +/- identity weights give ReLU(x) - ReLU(-x), which equals x\n",
    "x_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0174, -1.8656,  0.0601, -2.3853,  1.3686,  0.8493, -0.4814,  0.4057,\n",
       "          0.8693,  0.0657],\n",
       "        [ 2.3877, -0.4045, -1.5180,  1.3057, -0.7241,  0.5446,  1.2187,  0.0744,\n",
       "         -0.2578,  2.1727],\n",
       "        [-0.7112, -1.1036,  0.4630,  0.1597,  0.7562,  0.2722, -1.7186, -0.8981,\n",
       "         -0.8740,  0.7227],\n",
       "        [ 0.8849, -0.6897,  0.0693, -0.4665,  0.8278,  0.8166,  0.5399, -0.6445,\n",
       "          0.7462, -0.6332],\n",
       "        [-0.5154, -0.6035, -2.3274,  0.3267, -0.8014, -0.9588, -2.6186, -0.6371,\n",
       "          0.0490, -0.5791],\n",
       "        [ 2.2823, -0.3598,  0.2948,  0.5487, -0.0978,  2.1875,  2.8801, -0.1620,\n",
       "         -0.3807,  0.5082],\n",
       "        [-0.7348,  1.2022,  0.8836,  2.8562,  1.6879, -0.0264, -1.0877, -0.3618,\n",
       "          0.2822, -1.4223],\n",
       "        [ 0.1404, -0.8036, -0.4324, -0.2638,  2.8903, -1.0265,  0.1243,  0.2342,\n",
       "          0.4538, -1.0774],\n",
       "        [ 0.4438,  0.4353,  0.5972,  0.0999,  0.4556, -1.4620, -0.5464, -0.8039,\n",
       "          0.9624, -2.8401],\n",
       "        [-0.9826, -0.3062, -0.5039,  1.1493, -0.2780, -0.5464, -0.1323,  0.5608,\n",
       "          2.0972,  1.1899]])"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the stored output below (execution 105) is 10x10, but the autoencoder cell builds\n",
    "# sample_input with torch.randn(100, input_dim), input_dim = 10 -> stale output; re-run to refresh.\n",
    "sample_input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000e+00, 3.9591e-01, 1.4807e-01, 1.5686e-01, 2.4313e-01, 9.6605e-02,\n",
       "         0.0000e+00, 1.5925e-01, 2.4663e-01, 9.9515e-02],\n",
       "        [0.0000e+00, 3.3625e-01, 5.4557e-02, 1.3601e-02, 2.8132e-01, 2.4776e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 6.6838e-01, 3.3385e-01, 2.9290e-01, 6.2592e-01, 2.7627e-01,\n",
       "         0.0000e+00, 8.7257e-01, 4.4312e-01, 0.0000e+00],\n",
       "        [2.0877e-01, 4.0870e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.6888e-01,\n",
       "         0.0000e+00, 3.4732e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [2.0471e-01, 6.4470e-02, 0.0000e+00, 2.5423e-01, 0.0000e+00, 2.8507e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.3009e-01],\n",
       "        [0.0000e+00, 5.0830e-01, 2.4382e-01, 1.6457e-01, 0.0000e+00, 5.2556e-01,\n",
       "         0.0000e+00, 5.3050e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 6.8208e-01, 1.1152e-01, 0.0000e+00, 6.4832e-02, 5.2638e-01,\n",
       "         0.0000e+00, 9.6160e-02, 4.9891e-01, 6.2357e-03],\n",
       "        [3.9628e-01, 1.0366e-01, 0.0000e+00, 3.0479e-02, 2.5999e-01, 2.4805e-02,\n",
       "         0.0000e+00, 1.2886e-01, 2.0998e-01, 1.2239e-01],\n",
       "        [0.0000e+00, 2.3133e-01, 0.0000e+00, 0.0000e+00, 4.3891e-01, 4.2699e-01,\n",
       "         0.0000e+00, 2.1379e-01, 6.0083e-01, 5.9468e-02],\n",
       "        [2.1559e-01, 4.6737e-01, 1.7519e-01, 0.0000e+00, 0.0000e+00, 5.7920e-01,\n",
       "         0.0000e+00, 5.3929e-01, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 2.6796e-01, 5.7224e-02, 6.2386e-02, 9.7306e-02, 3.4419e-01,\n",
       "         0.0000e+00, 0.0000e+00, 5.5747e-02, 0.0000e+00],\n",
       "        [2.3101e-01, 1.1006e-01, 5.2695e-02, 1.0797e-01, 1.0392e-01, 1.3687e-01,\n",
       "         0.0000e+00, 0.0000e+00, 1.2937e-02, 1.6651e-01],\n",
       "        [5.9844e-03, 3.1540e-01, 2.9647e-01, 0.0000e+00, 2.1410e-01, 3.1443e-01,\n",
       "         0.0000e+00, 5.1859e-04, 0.0000e+00, 1.6971e-01],\n",
       "        [5.5179e-02, 2.8599e-01, 0.0000e+00, 1.9867e-01, 1.3964e-01, 1.0541e-01,\n",
       "         4.9677e-02, 9.2282e-02, 1.2462e-01, 0.0000e+00],\n",
       "        [1.8987e-02, 5.2079e-01, 0.0000e+00, 3.6586e-01, 2.0302e-01, 2.2495e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.8215e-01],\n",
       "        [3.7433e-02, 1.6663e-01, 0.0000e+00, 7.3768e-02, 1.6678e-01, 1.1504e-01,\n",
       "         0.0000e+00, 0.0000e+00, 5.5199e-02, 0.0000e+00],\n",
       "        [0.0000e+00, 3.4356e-02, 5.4194e-01, 0.0000e+00, 6.7876e-01, 3.2437e-01,\n",
       "         0.0000e+00, 0.0000e+00, 3.9391e-01, 0.0000e+00],\n",
       "        [1.4226e-01, 6.3041e-01, 0.0000e+00, 1.1469e-01, 3.8210e-01, 1.3733e-01,\n",
       "         8.3534e-02, 0.0000e+00, 3.6119e-01, 2.7212e-01],\n",
       "        [1.2761e-01, 4.6993e-01, 0.0000e+00, 0.0000e+00, 1.8878e-01, 4.3295e-01,\n",
       "         2.6565e-02, 0.0000e+00, 3.4305e-01, 0.0000e+00],\n",
       "        [3.8772e-02, 4.2946e-01, 0.0000e+00, 7.5570e-02, 1.7849e-01, 1.5083e-01,\n",
       "         0.0000e+00, 0.0000e+00, 1.7091e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 3.8877e-01, 4.6267e-02, 0.0000e+00, 5.8135e-01, 2.8509e-01,\n",
       "         0.0000e+00, 2.0605e-01, 3.0130e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 4.9308e-01, 3.9984e-01, 0.0000e+00, 1.9657e-01, 4.5430e-01,\n",
       "         0.0000e+00, 0.0000e+00, 6.3160e-02, 0.0000e+00],\n",
       "        [1.6261e-01, 4.3525e-01, 1.2132e-01, 0.0000e+00, 1.6856e-02, 3.7696e-01,\n",
       "         0.0000e+00, 1.6251e-01, 1.7227e-01, 0.0000e+00],\n",
       "        [9.3133e-02, 5.4613e-01, 3.0942e-01, 2.2668e-01, 4.0545e-01, 1.0144e-01,\n",
       "         0.0000e+00, 0.0000e+00, 2.1673e-01, 0.0000e+00],\n",
       "        [1.2149e-01, 5.5121e-01, 3.2091e-01, 0.0000e+00, 3.1151e-01, 3.5131e-01,\n",
       "         0.0000e+00, 1.1528e-01, 3.4622e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 7.5393e-01, 0.0000e+00, 0.0000e+00, 4.7593e-01, 9.7482e-02,\n",
       "         0.0000e+00, 1.1488e-01, 6.3108e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 2.2809e-01, 9.3735e-01, 0.0000e+00, 4.6998e-01, 6.2950e-01,\n",
       "         0.0000e+00, 8.5188e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 4.0206e-01, 4.2391e-01, 0.0000e+00, 2.7130e-01, 3.3686e-01,\n",
       "         0.0000e+00, 1.4701e-01, 1.5584e-01, 1.1492e-01],\n",
       "        [0.0000e+00, 5.1483e-01, 3.4618e-01, 4.1662e-02, 8.9025e-02, 3.4880e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [1.9037e-01, 3.3347e-01, 1.6915e-01, 0.0000e+00, 0.0000e+00, 3.4225e-01,\n",
       "         0.0000e+00, 4.4682e-01, 7.7623e-03, 0.0000e+00],\n",
       "        [0.0000e+00, 4.2785e-01, 5.7199e-01, 1.3387e-02, 8.2213e-02, 6.2310e-01,\n",
       "         0.0000e+00, 1.2942e-01, 0.0000e+00, 0.0000e+00],\n",
       "        [2.2047e-01, 1.5001e-01, 0.0000e+00, 0.0000e+00, 1.4740e-01, 2.7857e-01,\n",
       "         0.0000e+00, 4.9126e-02, 2.6461e-01, 0.0000e+00],\n",
       "        [2.2033e-01, 2.2283e-01, 0.0000e+00, 2.0852e-01, 0.0000e+00, 1.5015e-01,\n",
       "         0.0000e+00, 6.6880e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [2.7092e-02, 4.8072e-01, 3.4202e-01, 0.0000e+00, 4.2588e-01, 2.3158e-01,\n",
       "         0.0000e+00, 0.0000e+00, 2.8641e-01, 0.0000e+00],\n",
       "        [8.9448e-03, 9.4627e-02, 0.0000e+00, 9.6618e-02, 1.1452e-01, 1.2615e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 5.0403e-01, 0.0000e+00, 1.1154e-01, 6.2038e-01, 4.1293e-01,\n",
       "         0.0000e+00, 3.4794e-01, 5.7102e-01, 1.7239e-01],\n",
       "        [3.4337e-01, 2.6681e-01, 2.0919e-02, 0.0000e+00, 4.3048e-01, 5.1583e-02,\n",
       "         1.0806e-01, 0.0000e+00, 2.7871e-01, 2.0296e-02],\n",
       "        [9.5720e-02, 2.9626e-02, 5.4899e-02, 0.0000e+00, 3.4316e-02, 3.9052e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 4.3110e-04],\n",
       "        [0.0000e+00, 2.5689e-01, 0.0000e+00, 5.7398e-02, 2.0332e-03, 3.6419e-01,\n",
       "         0.0000e+00, 3.0988e-01, 2.0275e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 3.1490e-01, 1.4708e-01, 1.1704e-01, 4.8266e-02, 2.1294e-01,\n",
       "         0.0000e+00, 8.8372e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 3.1821e-01, 4.4715e-01, 2.2968e-02, 1.8420e-01, 3.4193e-01,\n",
       "         0.0000e+00, 1.1105e-01, 0.0000e+00, 0.0000e+00],\n",
       "        [2.3009e-01, 5.7465e-01, 1.7543e-01, 0.0000e+00, 0.0000e+00, 4.9222e-01,\n",
       "         0.0000e+00, 4.3074e-01, 1.2506e-01, 0.0000e+00],\n",
       "        [9.8978e-02, 3.9067e-01, 8.2960e-02, 0.0000e+00, 4.4113e-01, 3.8719e-01,\n",
       "         0.0000e+00, 0.0000e+00, 2.3447e-01, 6.8928e-02],\n",
       "        [0.0000e+00, 2.8977e-01, 1.4735e-01, 2.7550e-02, 2.1846e-01, 9.5139e-02,\n",
       "         0.0000e+00, 0.0000e+00, 7.5183e-02, 0.0000e+00],\n",
       "        [1.3971e-01, 6.8105e-01, 3.3828e-01, 1.0309e-01, 3.1895e-01, 1.5120e-01,\n",
       "         0.0000e+00, 0.0000e+00, 4.5697e-01, 0.0000e+00],\n",
       "        [2.8975e-02, 4.7176e-01, 2.4516e-01, 0.0000e+00, 6.0244e-02, 5.4450e-01,\n",
       "         0.0000e+00, 1.2406e-01, 2.7828e-02, 0.0000e+00],\n",
       "        [3.3371e-03, 4.3411e-01, 2.9472e-03, 1.7819e-01, 1.8070e-01, 0.0000e+00,\n",
       "         0.0000e+00, 2.3921e-02, 2.1053e-01, 7.6531e-02],\n",
       "        [6.1931e-02, 5.0865e-01, 2.4989e-01, 0.0000e+00, 3.3470e-01, 4.8018e-01,\n",
       "         0.0000e+00, 5.3618e-01, 5.8489e-01, 8.3584e-03],\n",
       "        [0.0000e+00, 3.6069e-01, 0.0000e+00, 1.1092e-01, 1.4939e-01, 2.8229e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0426e-01],\n",
       "        [0.0000e+00, 2.3010e-01, 4.1897e-01, 3.4832e-02, 3.6025e-01, 2.6620e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 2.9854e-01, 1.3991e-01, 0.0000e+00, 2.9783e-01, 3.6763e-01,\n",
       "         0.0000e+00, 4.8179e-01, 4.1527e-01, 0.0000e+00],\n",
       "        [2.3519e-02, 3.4104e-01, 1.9821e-01, 1.9365e-01, 3.6810e-01, 1.6408e-01,\n",
       "         0.0000e+00, 2.6606e-01, 2.3349e-02, 0.0000e+00],\n",
       "        [0.0000e+00, 4.9590e-01, 1.4199e-01, 1.3245e-01, 8.1167e-02, 3.1311e-01,\n",
       "         0.0000e+00, 6.7402e-01, 5.0006e-01, 5.3113e-02],\n",
       "        [6.2544e-02, 3.1966e-01, 2.4880e-01, 1.1557e-01, 2.3747e-01, 2.5974e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [1.0316e-01, 2.0927e-01, 0.0000e+00, 0.0000e+00, 3.1005e-01, 2.1844e-01,\n",
       "         3.8177e-02, 3.5363e-03, 3.1193e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 4.2135e-01, 5.6370e-01, 0.0000e+00, 4.2477e-01, 2.3090e-01,\n",
       "         0.0000e+00, 3.9542e-01, 3.4099e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 3.3781e-01, 1.7804e-01, 0.0000e+00, 9.7886e-01, 4.8793e-01,\n",
       "         0.0000e+00, 3.5065e-01, 7.2910e-01, 0.0000e+00],\n",
       "        [4.4947e-02, 6.1535e-01, 0.0000e+00, 2.7918e-01, 2.5570e-01, 7.2054e-02,\n",
       "         0.0000e+00, 0.0000e+00, 1.7005e-01, 0.0000e+00],\n",
       "        [1.6665e-01, 1.0299e-01, 4.1741e-02, 0.0000e+00, 2.4670e-02, 3.8523e-02,\n",
       "         0.0000e+00, 2.0848e-01, 1.5692e-01, 1.1339e-01],\n",
       "        [6.0387e-02, 4.9885e-01, 1.2727e-01, 0.0000e+00, 2.4759e-01, 4.1431e-01,\n",
       "         0.0000e+00, 2.9514e-01, 7.7364e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 6.5757e-01, 4.8634e-01, 1.3966e-01, 4.2882e-01, 4.1778e-01,\n",
       "         0.0000e+00, 7.2470e-02, 3.5444e-01, 7.6412e-02],\n",
       "        [3.6030e-01, 7.1012e-01, 0.0000e+00, 0.0000e+00, 1.9758e-01, 3.5305e-01,\n",
       "         0.0000e+00, 1.4980e-01, 6.1264e-01, 8.5748e-02],\n",
       "        [1.0509e-01, 6.8702e-01, 2.8665e-01, 7.7752e-02, 3.8244e-01, 1.3337e-01,\n",
       "         2.4752e-01, 1.6085e-01, 6.0254e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 4.4895e-01, 3.5113e-01, 1.2002e-01, 4.4492e-01, 2.8897e-01,\n",
       "         0.0000e+00, 3.6285e-01, 0.0000e+00, 0.0000e+00],\n",
       "        [1.6177e-01, 5.6407e-01, 1.2024e-01, 1.4285e-01, 2.3327e-01, 9.0886e-02,\n",
       "         1.0369e-01, 2.0634e-01, 4.7996e-01, 2.1300e-02],\n",
       "        [0.0000e+00, 5.1389e-01, 4.9136e-02, 4.0287e-02, 3.6974e-01, 5.0238e-01,\n",
       "         0.0000e+00, 0.0000e+00, 3.6704e-01, 7.3841e-02],\n",
       "        [5.2548e-02, 7.9984e-01, 3.3096e-01, 1.3503e-01, 5.1144e-01, 6.4635e-02,\n",
       "         0.0000e+00, 2.2562e-01, 6.4548e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 4.5841e-01, 6.1658e-01, 1.1580e-01, 3.5932e-01, 4.4292e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0272e-01],\n",
       "        [0.0000e+00, 3.1601e-01, 2.7057e-01, 0.0000e+00, 2.3781e-01, 4.9598e-01,\n",
       "         0.0000e+00, 1.7333e-01, 2.6259e-01, 0.0000e+00],\n",
       "        [2.1474e-01, 1.0333e-01, 0.0000e+00, 1.2585e-01, 0.0000e+00, 4.6769e-02,\n",
       "         0.0000e+00, 1.3840e-01, 0.0000e+00, 1.2265e-01],\n",
       "        [0.0000e+00, 4.5410e-01, 3.2946e-02, 2.8125e-01, 3.1215e-01, 4.0317e-01,\n",
       "         0.0000e+00, 0.0000e+00, 8.0310e-02, 1.7801e-01],\n",
       "        [2.2958e-01, 5.7088e-01, 1.0311e-01, 0.0000e+00, 1.9533e-02, 7.2333e-01,\n",
       "         0.0000e+00, 2.0296e-01, 7.5823e-02, 0.0000e+00],\n",
       "        [3.9312e-02, 4.2880e-01, 3.7651e-01, 1.5567e-01, 2.3564e-01, 3.1402e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 5.3335e-01, 5.0086e-01, 0.0000e+00, 1.2261e-01, 4.6135e-01,\n",
       "         0.0000e+00, 1.4445e-01, 4.3281e-02, 0.0000e+00],\n",
       "        [1.6574e-01, 2.4939e-01, 0.0000e+00, 2.9213e-01, 0.0000e+00, 2.2121e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [5.1534e-01, 3.3479e-01, 0.0000e+00, 3.1898e-01, 3.3459e-01, 1.7649e-02,\n",
       "         0.0000e+00, 0.0000e+00, 2.7494e-01, 3.5351e-01],\n",
       "        [6.8270e-02, 6.6574e-01, 8.1116e-02, 0.0000e+00, 3.6191e-01, 3.8809e-01,\n",
       "         0.0000e+00, 2.6190e-01, 7.3165e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 8.8263e-01, 3.1771e-01, 0.0000e+00, 7.2857e-01, 7.1344e-01,\n",
       "         0.0000e+00, 2.4067e-01, 1.0766e+00, 0.0000e+00],\n",
       "        [0.0000e+00, 4.6826e-01, 2.6164e-01, 2.2408e-01, 1.6630e-01, 1.9927e-01,\n",
       "         0.0000e+00, 3.3738e-01, 2.2124e-01, 1.9941e-01],\n",
       "        [0.0000e+00, 1.1947e-01, 6.1582e-01, 0.0000e+00, 4.2244e-01, 2.0968e-01,\n",
       "         0.0000e+00, 1.8471e-01, 1.6855e-01, 2.7517e-02],\n",
       "        [0.0000e+00, 2.6528e-01, 4.3205e-02, 1.0462e-01, 1.1106e-01, 4.7919e-01,\n",
       "         0.0000e+00, 0.0000e+00, 4.4715e-02, 9.9195e-02],\n",
       "        [3.6042e-01, 4.2523e-01, 0.0000e+00, 2.8737e-01, 2.4873e-01, 1.9856e-01,\n",
       "         1.2695e-01, 0.0000e+00, 2.8523e-01, 0.0000e+00],\n",
       "        [5.7771e-01, 2.8200e-01, 2.4695e-02, 5.0210e-02, 0.0000e+00, 2.0174e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 3.2632e-01],\n",
       "        [0.0000e+00, 4.4611e-01, 3.1356e-01, 0.0000e+00, 4.4132e-01, 2.5638e-01,\n",
       "         0.0000e+00, 0.0000e+00, 3.5552e-01, 2.8202e-02],\n",
       "        [2.5990e-03, 9.2232e-01, 0.0000e+00, 0.0000e+00, 1.0586e-01, 6.0660e-01,\n",
       "         0.0000e+00, 1.6886e-01, 6.3656e-01, 4.8919e-02],\n",
       "        [0.0000e+00, 3.2084e-01, 0.0000e+00, 2.0025e-02, 6.3495e-03, 4.6816e-01,\n",
       "         0.0000e+00, 1.7421e-02, 1.1906e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 4.2457e-01, 0.0000e+00, 2.6380e-01, 2.7084e-01, 2.9049e-01,\n",
       "         0.0000e+00, 4.1490e-01, 2.5032e-01, 1.6869e-01],\n",
       "        [0.0000e+00, 3.4244e-01, 7.8211e-01, 0.0000e+00, 7.5249e-01, 5.9258e-01,\n",
       "         0.0000e+00, 5.0380e-01, 6.5882e-01, 1.7082e-01],\n",
       "        [9.7293e-02, 3.5476e-01, 3.5031e-01, 1.2627e-01, 1.9158e-01, 3.4014e-01,\n",
       "         0.0000e+00, 2.3014e-02, 0.0000e+00, 0.0000e+00],\n",
       "        [9.8005e-02, 5.3430e-01, 9.6826e-02, 2.3291e-01, 3.8750e-01, 2.2138e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [1.2784e-01, 1.4535e-01, 0.0000e+00, 0.0000e+00, 4.4076e-01, 3.0036e-01,\n",
       "         3.4917e-02, 2.1575e-01, 4.7737e-01, 1.7738e-01],\n",
       "        [0.0000e+00, 5.0352e-01, 6.8305e-01, 1.6319e-01, 4.4097e-01, 4.6093e-01,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 5.2586e-02],\n",
       "        [1.8718e-01, 1.6083e-01, 0.0000e+00, 4.4001e-02, 3.6696e-01, 1.2639e-01,\n",
       "         0.0000e+00, 4.2252e-01, 5.9222e-01, 4.7654e-02],\n",
       "        [6.4173e-02, 2.4301e-01, 0.0000e+00, 5.0171e-02, 0.0000e+00, 6.6937e-02,\n",
       "         0.0000e+00, 2.1519e-01, 0.0000e+00, 2.9367e-01],\n",
       "        [1.7360e-01, 4.3262e-01, 9.9577e-02, 2.1859e-02, 5.1283e-01, 2.0299e-01,\n",
       "         0.0000e+00, 4.3683e-01, 4.3393e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 7.2570e-01, 6.5482e-01, 0.0000e+00, 0.0000e+00, 8.7435e-01,\n",
       "         0.0000e+00, 1.2932e-01, 0.0000e+00, 1.9342e-02],\n",
       "        [8.2455e-03, 4.8427e-01, 0.0000e+00, 0.0000e+00, 5.5801e-01, 5.8107e-01,\n",
       "         2.3393e-02, 0.0000e+00, 6.2686e-01, 0.0000e+00],\n",
       "        [0.0000e+00, 3.0481e-01, 3.1452e-01, 2.6345e-02, 2.4294e-02, 4.2178e-01,\n",
       "         0.0000e+00, 1.4916e-01, 3.0435e-02, 0.0000e+00],\n",
       "        [0.0000e+00, 2.5579e-01, 1.3658e-01, 6.2642e-02, 3.1601e-01, 3.8670e-01,\n",
       "         0.0000e+00, 4.0626e-02, 1.1173e-01, 1.3773e-03],\n",
       "        [1.7890e-01, 3.0938e-01, 1.5988e-01, 1.2571e-01, 0.0000e+00, 7.4686e-02,\n",
       "         0.0000e+00, 2.7499e-01, 4.8048e-02, 7.3613e-02]],\n",
       "       grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the stored output (execution 52) presumably does not match the autoencoder cell\n",
    "# as saved; treat it as stale and re-run that cell and this one to refresh.\n",
    "reconstructed_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
