{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of all saved checkpoints -- change this one constant to relocate them.\n",
    "SAVED_MODELS = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models'\n",
    "\n",
    "\n",
    "def _ckpt(dataset, run, d_model, epoch):\n",
    "    \"\"\"Build the path of the `epoch` checkpoint of a PatchTST_sim pretraining run.\n",
    "\n",
    "    Every run shares the same directory/file stem; only the dataset, the\n",
    "    ablation run name, d_model and the saved epoch differ.\n",
    "    \"\"\"\n",
    "    stem = (f'patchtst_sim_pretrained_D{d_model}_cw512_patch12_stride12'\n",
    "            '_epochs-pretrain100_mask0.5_model1_no_permute')\n",
    "    return f'{SAVED_MODELS}/{dataset}/{run}/based_model/{stem}/{stem}_{epoch}.pth'\n",
    "\n",
    "\n",
    "MLP_S = _ckpt('etth1', 'XY_ablation_FC2_O_OX', 32, 30)\n",
    "MLP_M = _ckpt('etth1', 'XY_ablation_FC2_O_OX', 64, 20)\n",
    "MLP_L = _ckpt('etth1', 'XY_ablation_FC2_O_OX', 128, 30)\n",
    "TRANS_S = _ckpt('etth1', 'XY_ablation_Transformer_O_OX', 128, 70)\n",
    "# NOTE(review): unlike the four paths above, TRANS_L uses the ettm1 dataset and\n",
    "# the XY_ablation_Transformer_O_O run, and it has the same d_model (128) as\n",
    "# TRANS_S -- confirm this is the intended 'large' Transformer checkpoint.\n",
    "TRANS_L = _ckpt('ettm1', 'XY_ablation_Transformer_O_O', 128, 90)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def _load_cpu(path):\n",
    "    \"\"\"Deserialize a saved .pth file, mapping every tensor onto the CPU.\n",
    "\n",
    "    torch.load unpickles the file, so only use it on trusted checkpoints.\n",
    "    \"\"\"\n",
    "    return torch.load(path, map_location='cpu')\n",
    "\n",
    "\n",
    "weights_MLP_S = _load_cpu(MLP_S)\n",
    "weights_MLP_M = _load_cpu(MLP_M)\n",
    "weights_MLP_L = _load_cpu(MLP_L)\n",
    "weights_TRANS_S = _load_cpu(TRANS_S)\n",
    "weights_TRANS_L = _load_cpu(TRANS_L)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indexing position 9 raised IndexError (MLP_S holds fewer than 10 tensors),\n",
    "# so list every key instead.\n",
    "list(weights_MLP_S.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3536)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stored attention scale of layer 0. The saved output 0.3536 ~= 8**-0.5, i.e.\n",
    "# presumably head_dim**-0.5 with head_dim == 8 -- confirm against the model.\n",
    "scale_key = 'backbone.encoder.layers.0.self_attn.sdp_attn.scale'\n",
    "weights_TRANS_L[scale_key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prettytable import PrettyTable\n",
    "\n",
    "def count_parameters(model, name_filter='backbone'):\n",
    "    \"\"\"Print a table of trainable parameters and return their element count.\n",
    "\n",
    "    Args:\n",
    "        model: module exposing ``named_parameters()`` (e.g. an nn.Module).\n",
    "        name_filter: only parameters whose qualified name contains this\n",
    "            substring are counted. The default 'backbone' reproduces the\n",
    "            original behaviour (head.* weights are excluded); pass None or ''\n",
    "            to count every trainable parameter.\n",
    "\n",
    "    Returns:\n",
    "        int: total number of trainable elements that passed the filter.\n",
    "    \"\"\"\n",
    "    table = PrettyTable([\"Modules\", \"Parameters\"])\n",
    "    total_params = 0\n",
    "    for name, parameter in model.named_parameters():\n",
    "        if name_filter and name_filter not in name:\n",
    "            continue\n",
    "        if not parameter.requires_grad:\n",
    "            continue\n",
    "        params = parameter.numel()\n",
    "        table.add_row([name, params])\n",
    "        total_params += params\n",
    "    print(table)\n",
    "    # State which subset was counted -- the old label claimed a model-wide\n",
    "    # total even though only 'backbone' parameters were included.\n",
    "    scope = f\" (name contains {name_filter!r})\" if name_filter else \"\"\n",
    "    print(f\"Total Trainable Params{scope}: {total_params}\")\n",
    "    return total_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backbone.W_pos\n",
      "torch.Size([42, 128])\n",
      "backbone.W_P.weight\n",
      "torch.Size([128, 12])\n",
      "backbone.W_P.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.self_attn.W_Q.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.0.self_attn.W_Q.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.self_attn.W_K.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.0.self_attn.W_K.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.self_attn.W_V.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.0.self_attn.W_V.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.self_attn.sdp_attn.scale\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.0.self_attn.to_out.0.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.0.self_attn.to_out.0.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_attn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_attn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_attn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_attn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_attn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.0.ff.0.weight\n",
      "torch.Size([256, 128])\n",
      "backbone.encoder.layers.0.ff.0.bias\n",
      "torch.Size([256])\n",
      "backbone.encoder.layers.0.ff.3.weight\n",
      "torch.Size([128, 256])\n",
      "backbone.encoder.layers.0.ff.3.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_ffn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_ffn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_ffn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_ffn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.0.norm_ffn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.1.self_attn.W_Q.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.1.self_attn.W_Q.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.self_attn.W_K.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.1.self_attn.W_K.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.self_attn.W_V.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.1.self_attn.W_V.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.self_attn.sdp_attn.scale\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.1.self_attn.to_out.0.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.1.self_attn.to_out.0.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_attn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_attn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_attn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_attn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_attn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.1.ff.0.weight\n",
      "torch.Size([256, 128])\n",
      "backbone.encoder.layers.1.ff.0.bias\n",
      "torch.Size([256])\n",
      "backbone.encoder.layers.1.ff.3.weight\n",
      "torch.Size([128, 256])\n",
      "backbone.encoder.layers.1.ff.3.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_ffn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_ffn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_ffn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_ffn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.1.norm_ffn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.2.self_attn.W_Q.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.2.self_attn.W_Q.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.self_attn.W_K.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.2.self_attn.W_K.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.self_attn.W_V.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.2.self_attn.W_V.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.self_attn.sdp_attn.scale\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.2.self_attn.to_out.0.weight\n",
      "torch.Size([128, 128])\n",
      "backbone.encoder.layers.2.self_attn.to_out.0.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_attn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_attn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_attn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_attn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_attn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "backbone.encoder.layers.2.ff.0.weight\n",
      "torch.Size([256, 128])\n",
      "backbone.encoder.layers.2.ff.0.bias\n",
      "torch.Size([256])\n",
      "backbone.encoder.layers.2.ff.3.weight\n",
      "torch.Size([128, 256])\n",
      "backbone.encoder.layers.2.ff.3.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_ffn.1.weight\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_ffn.1.bias\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_ffn.1.running_mean\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_ffn.1.running_var\n",
      "torch.Size([128])\n",
      "backbone.encoder.layers.2.norm_ffn.1.num_batches_tracked\n",
      "torch.Size([])\n",
      "head.linear.weight\n",
      "torch.Size([12, 128])\n",
      "head.linear.bias\n",
      "torch.Size([12])\n"
     ]
    }
   ],
   "source": [
    "# List every tensor in the TRANS_L checkpoint: name on one line, shape on the next.\n",
    "for k in weights_TRANS_L:\n",
    "    print(k, weights_TRANS_L[k].shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchinfo import summary\n",
    "\n",
    "#model = ConvNet()\n",
    "#batch_size = 16\n",
    "#summary(model, input_size=(batch_size, 1, 28, 28))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Work from the project root, presumably so the `src.models...` imports below resolve.\n",
    "PROJECT_ROOT = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised'\n",
    "os.chdir(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.ablation_model_FC import PatchTST_sim as PatchTST_FC\n",
    "from src.models.ablation_model_FC2 import PatchTST_sim as PatchTST_FC2\n",
    "from src.models.ablation_model_Transformer import PatchTST_sim as PatchTST_Trans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================\n",
      "+---------------------+------------+\n",
      "|       Modules       | Parameters |\n",
      "+---------------------+------------+\n",
      "| backbone.W_P.weight |    384     |\n",
      "|  backbone.W_P.bias  |     32     |\n",
      "+---------------------+------------+\n",
      "Total Trainable Params: 416\n",
      "=================\n",
      "+---------------------+------------+\n",
      "|       Modules       | Parameters |\n",
      "+---------------------+------------+\n",
      "| backbone.W_P.weight |    768     |\n",
      "|  backbone.W_P.bias  |     64     |\n",
      "+---------------------+------------+\n",
      "Total Trainable Params: 832\n",
      "=================\n",
      "+---------------------+------------+\n",
      "|       Modules       | Parameters |\n",
      "+---------------------+------------+\n",
      "| backbone.W_P.weight |    1536    |\n",
      "|  backbone.W_P.bias  |    128     |\n",
      "+---------------------+------------+\n",
      "Total Trainable Params: 1664\n"
     ]
    }
   ],
   "source": [
    "# Backbone parameter count of the FC ablation model for each d_model.\n",
    "fc_kwargs = dict(c_in=7, target_dim=96, patch_len=12, stride=12, num_patch=42)\n",
    "for d in (32, 64, 128):\n",
    "    print('=================')\n",
    "    model_FC = PatchTST_FC(d_model=d, **fc_kwargs)\n",
    "    count_parameters(model_FC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================\n",
      "+----------------------+------------+\n",
      "|       Modules        | Parameters |\n",
      "+----------------------+------------+\n",
      "| backbone.W_P1.weight |    384     |\n",
      "|  backbone.W_P1.bias  |     32     |\n",
      "| backbone.W_P2.weight |    1024    |\n",
      "|  backbone.W_P2.bias  |     32     |\n",
      "+----------------------+------------+\n",
      "Total Trainable Params: 1472\n",
      "=================\n",
      "+----------------------+------------+\n",
      "|       Modules        | Parameters |\n",
      "+----------------------+------------+\n",
      "| backbone.W_P1.weight |    768     |\n",
      "|  backbone.W_P1.bias  |     64     |\n",
      "| backbone.W_P2.weight |    4096    |\n",
      "|  backbone.W_P2.bias  |     64     |\n",
      "+----------------------+------------+\n",
      "Total Trainable Params: 4992\n",
      "=================\n",
      "+----------------------+------------+\n",
      "|       Modules        | Parameters |\n",
      "+----------------------+------------+\n",
      "| backbone.W_P1.weight |    1536    |\n",
      "|  backbone.W_P1.bias  |    128     |\n",
      "| backbone.W_P2.weight |   16384    |\n",
      "|  backbone.W_P2.bias  |    128     |\n",
      "+----------------------+------------+\n",
      "Total Trainable Params: 18176\n"
     ]
    }
   ],
   "source": [
    "# Backbone parameter count of the two-layer FC (FC2) ablation model for each d_model.\n",
    "fc2_kwargs = dict(c_in=7, target_dim=96, patch_len=12, stride=12, num_patch=42)\n",
    "for d in (32, 64, 128):\n",
    "    print('=================')\n",
    "    model_FC2 = PatchTST_FC2(d_model=d, **fc2_kwargs)\n",
    "    count_parameters(model_FC2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------------------------------------+------------+\n",
      "|                       Modules                       | Parameters |\n",
      "+-----------------------------------------------------+------------+\n",
      "|                    backbone.W_pos                   |    672     |\n",
      "|                 backbone.W_P.weight                 |    192     |\n",
      "|                  backbone.W_P.bias                  |     16     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_Q.weight   |    256     |\n",
      "|     backbone.encoder.layers.0.self_attn.W_Q.bias    |     16     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_K.weight   |    256     |\n",
      "|     backbone.encoder.layers.0.self_attn.W_K.bias    |     16     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_V.weight   |    256     |\n",
      "|     backbone.encoder.layers.0.self_attn.W_V.bias    |     16     |\n",
      "| backbone.encoder.layers.0.self_attn.to_out.0.weight |    256     |\n",
      "|  backbone.encoder.layers.0.self_attn.to_out.0.bias  |     16     |\n",
      "|     backbone.encoder.layers.0.norm_attn.1.weight    |     16     |\n",
      "|      backbone.encoder.layers.0.norm_attn.1.bias     |     16     |\n",
      "|        backbone.encoder.layers.0.ff.0.weight        |    2048    |\n",
      "|         backbone.encoder.layers.0.ff.0.bias         |    128     |\n",
      "|        backbone.encoder.layers.0.ff.3.weight        |    2048    |\n",
      "|         backbone.encoder.layers.0.ff.3.bias         |     16     |\n",
      "|     backbone.encoder.layers.0.norm_ffn.1.weight     |     16     |\n",
      "|      backbone.encoder.layers.0.norm_ffn.1.bias      |     16     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_Q.weight   |    256     |\n",
      "|     backbone.encoder.layers.1.self_attn.W_Q.bias    |     16     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_K.weight   |    256     |\n",
      "|     backbone.encoder.layers.1.self_attn.W_K.bias    |     16     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_V.weight   |    256     |\n",
      "|     backbone.encoder.layers.1.self_attn.W_V.bias    |     16     |\n",
      "| backbone.encoder.layers.1.self_attn.to_out.0.weight |    256     |\n",
      "|  backbone.encoder.layers.1.self_attn.to_out.0.bias  |     16     |\n",
      "|     backbone.encoder.layers.1.norm_attn.1.weight    |     16     |\n",
      "|      backbone.encoder.layers.1.norm_attn.1.bias     |     16     |\n",
      "|        backbone.encoder.layers.1.ff.0.weight        |    2048    |\n",
      "|         backbone.encoder.layers.1.ff.0.bias         |    128     |\n",
      "|        backbone.encoder.layers.1.ff.3.weight        |    2048    |\n",
      "|         backbone.encoder.layers.1.ff.3.bias         |     16     |\n",
      "|     backbone.encoder.layers.1.norm_ffn.1.weight     |     16     |\n",
      "|      backbone.encoder.layers.1.norm_ffn.1.bias      |     16     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_Q.weight   |    256     |\n",
      "|     backbone.encoder.layers.2.self_attn.W_Q.bias    |     16     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_K.weight   |    256     |\n",
      "|     backbone.encoder.layers.2.self_attn.W_K.bias    |     16     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_V.weight   |    256     |\n",
      "|     backbone.encoder.layers.2.self_attn.W_V.bias    |     16     |\n",
      "| backbone.encoder.layers.2.self_attn.to_out.0.weight |    256     |\n",
      "|  backbone.encoder.layers.2.self_attn.to_out.0.bias  |     16     |\n",
      "|     backbone.encoder.layers.2.norm_attn.1.weight    |     16     |\n",
      "|      backbone.encoder.layers.2.norm_attn.1.bias     |     16     |\n",
      "|        backbone.encoder.layers.2.ff.0.weight        |    2048    |\n",
      "|         backbone.encoder.layers.2.ff.0.bias         |    128     |\n",
      "|        backbone.encoder.layers.2.ff.3.weight        |    2048    |\n",
      "|         backbone.encoder.layers.2.ff.3.bias         |     16     |\n",
      "|     backbone.encoder.layers.2.norm_ffn.1.weight     |     16     |\n",
      "|      backbone.encoder.layers.2.norm_ffn.1.bias      |     16     |\n",
      "+-----------------------------------------------------+------------+\n",
      "Total Trainable Params: 17056\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "17056"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Small Transformer backbone. NOTE: despite its name, model_FC3 holds a\n",
    "# PatchTST_Trans model, and the same name is reassigned in the next cell.\n",
    "trans_small_kwargs = dict(d_model=16, n_heads=4, d_ff=128)\n",
    "model_FC3 = PatchTST_Trans(c_in=7, target_dim=96, patch_len=12, stride=12,\n",
    "                           num_patch=42, **trans_small_kwargs)\n",
    "count_parameters(model_FC3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------------------------------------------------+------------+\n",
      "|                       Modules                       | Parameters |\n",
      "+-----------------------------------------------------+------------+\n",
      "|                    backbone.W_pos                   |    5376    |\n",
      "|                 backbone.W_P.weight                 |    1536    |\n",
      "|                  backbone.W_P.bias                  |    128     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_Q.weight   |   16384    |\n",
      "|     backbone.encoder.layers.0.self_attn.W_Q.bias    |    128     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_K.weight   |   16384    |\n",
      "|     backbone.encoder.layers.0.self_attn.W_K.bias    |    128     |\n",
      "|    backbone.encoder.layers.0.self_attn.W_V.weight   |   16384    |\n",
      "|     backbone.encoder.layers.0.self_attn.W_V.bias    |    128     |\n",
      "| backbone.encoder.layers.0.self_attn.to_out.0.weight |   16384    |\n",
      "|  backbone.encoder.layers.0.self_attn.to_out.0.bias  |    128     |\n",
      "|     backbone.encoder.layers.0.norm_attn.1.weight    |    128     |\n",
      "|      backbone.encoder.layers.0.norm_attn.1.bias     |    128     |\n",
      "|        backbone.encoder.layers.0.ff.0.weight        |   32768    |\n",
      "|         backbone.encoder.layers.0.ff.0.bias         |    256     |\n",
      "|        backbone.encoder.layers.0.ff.3.weight        |   32768    |\n",
      "|         backbone.encoder.layers.0.ff.3.bias         |    128     |\n",
      "|     backbone.encoder.layers.0.norm_ffn.1.weight     |    128     |\n",
      "|      backbone.encoder.layers.0.norm_ffn.1.bias      |    128     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_Q.weight   |   16384    |\n",
      "|     backbone.encoder.layers.1.self_attn.W_Q.bias    |    128     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_K.weight   |   16384    |\n",
      "|     backbone.encoder.layers.1.self_attn.W_K.bias    |    128     |\n",
      "|    backbone.encoder.layers.1.self_attn.W_V.weight   |   16384    |\n",
      "|     backbone.encoder.layers.1.self_attn.W_V.bias    |    128     |\n",
      "| backbone.encoder.layers.1.self_attn.to_out.0.weight |   16384    |\n",
      "|  backbone.encoder.layers.1.self_attn.to_out.0.bias  |    128     |\n",
      "|     backbone.encoder.layers.1.norm_attn.1.weight    |    128     |\n",
      "|      backbone.encoder.layers.1.norm_attn.1.bias     |    128     |\n",
      "|        backbone.encoder.layers.1.ff.0.weight        |   32768    |\n",
      "|         backbone.encoder.layers.1.ff.0.bias         |    256     |\n",
      "|        backbone.encoder.layers.1.ff.3.weight        |   32768    |\n",
      "|         backbone.encoder.layers.1.ff.3.bias         |    128     |\n",
      "|     backbone.encoder.layers.1.norm_ffn.1.weight     |    128     |\n",
      "|      backbone.encoder.layers.1.norm_ffn.1.bias      |    128     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_Q.weight   |   16384    |\n",
      "|     backbone.encoder.layers.2.self_attn.W_Q.bias    |    128     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_K.weight   |   16384    |\n",
      "|     backbone.encoder.layers.2.self_attn.W_K.bias    |    128     |\n",
      "|    backbone.encoder.layers.2.self_attn.W_V.weight   |   16384    |\n",
      "|     backbone.encoder.layers.2.self_attn.W_V.bias    |    128     |\n",
      "| backbone.encoder.layers.2.self_attn.to_out.0.weight |   16384    |\n",
      "|  backbone.encoder.layers.2.self_attn.to_out.0.bias  |    128     |\n",
      "|     backbone.encoder.layers.2.norm_attn.1.weight    |    128     |\n",
      "|      backbone.encoder.layers.2.norm_attn.1.bias     |    128     |\n",
      "|        backbone.encoder.layers.2.ff.0.weight        |   32768    |\n",
      "|         backbone.encoder.layers.2.ff.0.bias         |    256     |\n",
      "|        backbone.encoder.layers.2.ff.3.weight        |   32768    |\n",
      "|         backbone.encoder.layers.2.ff.3.bias         |    128     |\n",
      "|     backbone.encoder.layers.2.norm_ffn.1.weight     |    128     |\n",
      "|      backbone.encoder.layers.2.norm_ffn.1.bias      |    128     |\n",
      "+-----------------------------------------------------+------------+\n",
      "Total Trainable Params: 404480\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "404480"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Larger Transformer backbone (d_model=128); overwrites model_FC3 from the\n",
    "# previous cell.\n",
    "trans_large_kwargs = dict(d_model=128, n_heads=16, d_ff=256)\n",
    "model_FC3 = PatchTST_Trans(c_in=7, target_dim=96, patch_len=12, stride=12,\n",
    "                           num_patch=42, **trans_large_kwargs)\n",
    "count_parameters(model_FC3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `model` is never defined in this notebook (it only appears in a commented-out\n",
    "# torchinfo example), so this raised NameError on a fresh kernel. Count the\n",
    "# most recently built model instead.\n",
    "count_parameters(model_FC3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
