{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.9041, -0.8796, -0.5632,  0.0797, -1.3168, -0.3419,  1.3030,\n",
      "          -0.3451, -0.6063, -0.5245],\n",
      "         [ 0.4215,  0.6848,  0.5073, -0.9174, -0.6083, -0.1182, -1.5017,\n",
      "          -0.5311, -0.4541, -1.9571],\n",
      "         [ 1.2158, -0.3853,  1.9686,  1.6066, -0.2141,  0.0164,  0.1152,\n",
      "           0.6832, -0.5217, -0.2778],\n",
      "         [ 0.4062, -0.3862,  0.2392, -0.8932, -0.3630,  0.6447, -0.3871,\n",
      "          -0.2544, -0.4813,  1.3624],\n",
      "         [-0.8946, -0.4769, -1.2336, -1.1485,  0.8232,  1.1682,  0.3541,\n",
      "           1.4274,  1.1515, -1.1845],\n",
      "         [ 1.5988, -1.7288, -0.7089, -1.3055, -1.1253, -1.4450, -0.4668,\n",
      "          -0.6925,  1.6598, -0.0912]],\n",
      "\n",
      "        [[ 0.9257,  0.5822, -1.3221, -0.8315,  0.1097, -0.6688, -1.7915,\n",
      "           0.1704,  0.8910,  1.2546],\n",
      "         [ 1.3104, -0.0158,  0.1452,  1.3375,  0.1436, -1.6218,  0.5001,\n",
      "           0.9592,  1.8101,  1.2246],\n",
      "         [-1.8643, -1.6465, -0.6212, -0.1397, -0.2657,  0.1240,  0.5665,\n",
      "           1.3655, -1.0273,  1.9611],\n",
      "         [ 0.6104,  1.1062, -0.2298,  0.2991,  0.2427,  2.1021, -0.7009,\n",
      "          -0.2435,  0.9315,  0.9464],\n",
      "         [ 2.5869,  0.0343, -0.2204, -0.3567, -0.4968, -0.4927,  0.5885,\n",
      "          -0.6755, -0.6929, -0.5042],\n",
      "         [-0.3602,  0.3932,  0.3927, -0.1670,  0.6043, -1.2876,  0.7542,\n",
      "          -0.6074, -0.3500,  0.4357]]])\n",
      "tensor([[-2.2906, -4.4742,  4.2068, -0.1128, -0.0136, -4.3055],\n",
      "        [-0.6803,  5.7931, -1.5476,  5.0642, -0.2296, -0.1919]])\n",
      "torch.Size([2, 6, 10])\n",
      "torch.Size([2, 6])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Example sizes for the three axes of a.\n",
    "N, H, T = 2, 6, 10\n",
    "\n",
    "# Random example tensor laid out as [N, H, T].\n",
    "a = torch.randn(N, H, T)\n",
    "\n",
    "# Step 1: reduce over the last axis (T) so that b keeps one sum per (n, h) pair -> [N, H].\n",
    "b = a.sum(dim=-1)\n",
    "\n",
    "# Values first, then shapes.\n",
    "print(a, b, sep='\\n')\n",
    "print(a.shape, b.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-2.2906000000000004"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hand check of Step 1: add up the ten printed entries of a[0, 0]; the total should equal b[0, 0].\n",
    "# The literals were copied from one particular run, so they go stale whenever a is re-sampled.\n",
    "0.9041 - 0.8796 - 0.5632 + 0.0797 - 1.3168 - 0.3419 + 1.3030 - 0.3451 - 0.6063 - 0.5245"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2, 4, 3, 0],\n",
      "        [1, 3, 5, 4]])\n",
      "torch.Size([2, 4])\n"
     ]
    }
   ],
   "source": [
    "# Step 2: for every row of b, find the k heads with the largest sums.\n",
    "# torch.topk returns (values, indices); only the indices are needed from here on.\n",
    "k = 4  # how many heads to keep\n",
    "topk_indices = torch.topk(b, k, dim=1).indices  # [N, k]\n",
    "\n",
    "print(topk_indices, topk_indices.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],\n",
      "         [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],\n",
      "         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
      "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
      "\n",
      "        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "         [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n",
      "         [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],\n",
      "         [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]])\n",
      "torch.Size([2, 4, 10])\n",
      "tensor([[[ 1.2158, -0.3853,  1.9686,  1.6066, -0.2141,  0.0164,  0.1152,\n",
      "           0.6832, -0.5217, -0.2778],\n",
      "         [-0.8946, -0.4769, -1.2336, -1.1485,  0.8232,  1.1682,  0.3541,\n",
      "           1.4274,  1.1515, -1.1845],\n",
      "         [ 0.4062, -0.3862,  0.2392, -0.8932, -0.3630,  0.6447, -0.3871,\n",
      "          -0.2544, -0.4813,  1.3624],\n",
      "         [ 0.9041, -0.8796, -0.5632,  0.0797, -1.3168, -0.3419,  1.3030,\n",
      "          -0.3451, -0.6063, -0.5245]],\n",
      "\n",
      "        [[ 1.3104, -0.0158,  0.1452,  1.3375,  0.1436, -1.6218,  0.5001,\n",
      "           0.9592,  1.8101,  1.2246],\n",
      "         [ 0.6104,  1.1062, -0.2298,  0.2991,  0.2427,  2.1021, -0.7009,\n",
      "          -0.2435,  0.9315,  0.9464],\n",
      "         [-0.3602,  0.3932,  0.3927, -0.1670,  0.6043, -1.2876,  0.7542,\n",
      "          -0.6074, -0.3500,  0.4357],\n",
      "         [ 2.5869,  0.0343, -0.2204, -0.3567, -0.4968, -0.4927,  0.5885,\n",
      "          -0.6755, -0.6929, -0.5042]]])\n",
      "torch.Size([2, 4, 10])\n"
     ]
    }
   ],
   "source": [
    "# Step 3: pull the selected heads out of the original tensor a.\n",
    "# gather reads one index per output element, so the [N, k] head indices get a new\n",
    "# trailing axis that is then repeated T times.\n",
    "topk_indices_expanded = topk_indices[:, :, None].expand(-1, -1, T)  # [N, k, T]\n",
    "\n",
    "# Pick along dim 1 (the H axis); the result a_topk is [N, k, T].\n",
    "a_topk = a.gather(1, topk_indices_expanded)\n",
    "\n",
    "print(topk_indices_expanded, topk_indices_expanded.shape, sep='\\n')\n",
    "print(a_topk, a_topk.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-1.6234,  0.5716, -3.0238, -0.1030, -0.6727, -0.0131],\n",
      "         [ 0.3394, -0.5712,  1.8612,  0.1429, -2.4633, -0.9110],\n",
      "         [ 0.1373,  0.6970, -1.1414,  1.7767,  0.6732,  1.2749],\n",
      "         [-0.8944,  1.0215, -1.2825, -0.3689, -0.0744, -0.2944],\n",
      "         [ 1.0076, -2.0619, -0.3731, -0.2602, -0.0738,  1.2605],\n",
      "         [ 0.9844,  1.2328, -0.6156,  0.2858, -0.1262,  0.6503],\n",
      "         [-0.8435,  0.4985,  1.4154,  1.6093, -1.0917,  0.1000],\n",
      "         [ 1.2428, -0.0728, -0.5672, -0.2836, -0.2023,  0.7800],\n",
      "         [ 2.0718, -1.2521,  0.1359, -1.3609,  0.6042,  0.7212],\n",
      "         [-0.5903,  0.2496, -0.1047,  0.9970, -0.3020,  0.2568]],\n",
      "\n",
      "        [[ 0.3867,  0.4413, -0.1570,  0.6820, -0.4854,  1.6211],\n",
      "         [-0.1357,  1.1282, -0.9414,  1.3190, -0.0502, -0.7298],\n",
      "         [ 1.5035,  0.0686, -0.3460, -1.1703,  0.4468, -2.4669],\n",
      "         [-0.3202,  0.7076,  0.8119, -0.9865,  0.3860,  2.0085],\n",
      "         [ 1.3758, -1.3889,  0.2530,  0.2001,  0.1194,  0.6831],\n",
      "         [-0.7684, -0.4617, -0.9504,  0.8323, -0.4154,  0.1932],\n",
      "         [-0.8428,  0.8054, -0.4595, -0.2538,  1.6860,  0.5545],\n",
      "         [ 0.7989, -1.2042, -0.9260,  0.2231, -0.7989, -0.6381],\n",
      "         [ 0.8892, -0.8179, -0.9941,  0.4395,  0.6840, -0.9122],\n",
      "         [ 0.0488, -0.0660, -0.9041, -1.6147, -0.2302,  1.3497]]])\n",
      "tensor([[-4.8645, -1.6021,  3.4178, -1.8931, -0.5008,  2.4115,  1.6880,  0.8970,\n",
      "          0.9200,  0.5064],\n",
      "        [ 2.4887,  0.5902, -1.9644,  2.6073,  1.2425, -1.5704,  1.4897, -2.5452,\n",
      "         -0.7116, -1.4165]])\n",
      "torch.Size([2, 10, 6])\n",
      "torch.Size([2, 10])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Second experiment: H = 10 and T = 6 this time.\n",
    "N, H, T = 2, 10, 6\n",
    "\n",
    "# Fresh random example tensor of shape [N, H, T].\n",
    "a = torch.randn(N, H, T)\n",
    "\n",
    "# Step 1: add up the T axis; b ends up with shape [N, H].\n",
    "b = a.sum(dim=2)\n",
    "\n",
    "print(a, b, sep='\\n')\n",
    "print(a.shape, b.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-4.8645, -1.6021],\n",
      "         [ 3.4178, -1.8931],\n",
      "         [-0.5008,  2.4115],\n",
      "         [ 1.6880,  0.8970],\n",
      "         [ 0.9200,  0.5064]],\n",
      "\n",
      "        [[ 2.4887,  0.5902],\n",
      "         [-1.9644,  2.6073],\n",
      "         [ 1.2425, -1.5704],\n",
      "         [ 1.4897, -2.5452],\n",
      "         [-0.7116, -1.4165]]])\n",
      "torch.Size([2, 5, 2])\n"
     ]
    }
   ],
   "source": [
    "# Step 2: split the H per-head sums in b into k groups of H // k consecutive heads.\n",
    "k = 5  # number of groups\n",
    "\n",
    "# H // k floors silently, and view() cannot drop leftover heads, so require an exact split\n",
    "# and fail with a readable message instead of an opaque shape error.\n",
    "assert H % k == 0, f'H={H} must be divisible by k={k}'\n",
    "H_group_size = H // k  # heads per group\n",
    "\n",
    "# b is [N, H]; neighbouring heads end up in the same group -> [N, k, H // k].\n",
    "b_reshaped = b.view(N, k, H_group_size)\n",
    "\n",
    "print(b_reshaped)\n",
    "print(b_reshaped.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1],\n",
      "         [0],\n",
      "         [1],\n",
      "         [0],\n",
      "         [0]],\n",
      "\n",
      "        [[0],\n",
      "         [1],\n",
      "         [0],\n",
      "         [0],\n",
      "         [0]]])\n",
      "tensor([[[1, 1, 1, 1, 1, 1],\n",
      "         [0, 0, 0, 0, 0, 0],\n",
      "         [1, 1, 1, 1, 1, 1],\n",
      "         [0, 0, 0, 0, 0, 0],\n",
      "         [0, 0, 0, 0, 0, 0]],\n",
      "\n",
      "        [[0, 0, 0, 0, 0, 0],\n",
      "         [1, 1, 1, 1, 1, 1],\n",
      "         [0, 0, 0, 0, 0, 0],\n",
      "         [0, 0, 0, 0, 0, 0],\n",
      "         [0, 0, 0, 0, 0, 0]]])\n",
      "torch.Size([2, 5, 6])\n",
      "tensor([[[ 0.3394, -0.5712,  1.8612,  0.1429, -2.4633, -0.9110],\n",
      "         [ 0.1373,  0.6970, -1.1414,  1.7767,  0.6732,  1.2749],\n",
      "         [ 0.9844,  1.2328, -0.6156,  0.2858, -0.1262,  0.6503],\n",
      "         [-0.8435,  0.4985,  1.4154,  1.6093, -1.0917,  0.1000],\n",
      "         [ 2.0718, -1.2521,  0.1359, -1.3609,  0.6042,  0.7212]],\n",
      "\n",
      "        [[ 0.3867,  0.4413, -0.1570,  0.6820, -0.4854,  1.6211],\n",
      "         [-0.3202,  0.7076,  0.8119, -0.9865,  0.3860,  2.0085],\n",
      "         [ 1.3758, -1.3889,  0.2530,  0.2001,  0.1194,  0.6831],\n",
      "         [-0.8428,  0.8054, -0.4595, -0.2538,  1.6860,  0.5545],\n",
      "         [ 0.8892, -0.8179, -0.9941,  0.4395,  0.6840, -0.9122]]])\n",
      "torch.Size([2, 5, 6])\n"
     ]
    }
   ],
   "source": [
    "# Step 3: inside every group, locate the head with the largest sum.\n",
    "# topk keeps the reduced axis, so with one value per group the indices are [N, k, 1].\n",
    "topk_indices = torch.topk(b_reshaped, 1, dim=2).indices\n",
    "\n",
    "# Step 4: fetch the winning head of each group from the original tensor a.\n",
    "# The in-group index already has a trailing singleton axis, so it can be repeated over T\n",
    "# directly (no squeeze/unsqueeze round trip needed) -> [N, k, T].\n",
    "topk_indices_expanded = topk_indices.expand(-1, -1, T)\n",
    "\n",
    "# View a as [N, k, H // k, T], pick one head along the in-group axis (dim 2) using an\n",
    "# index of shape [N, k, 1, T], then drop that singleton axis again.\n",
    "a_grouped = a.view(N, k, H_group_size, T)\n",
    "a_topk = a_grouped.gather(2, topk_indices_expanded.unsqueeze(2)).squeeze(2)  # [N, k, T]\n",
    "\n",
    "print(topk_indices)\n",
    "print(topk_indices_expanded)\n",
    "print(topk_indices_expanded.shape)\n",
    "print(a_topk)\n",
    "print(a_topk.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 6])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Example sizes\n",
    "N, T, HID = 2, 4, 6\n",
    "\n",
    "# Random input a of shape (N, T, HID).\n",
    "a = torch.randn(N, T, HID)\n",
    "\n",
    "# One random source position in [0, T-1] for every (n, t) slot -> shape (N, T).\n",
    "position_ids = torch.randint(T, (N, T))\n",
    "\n",
    "# result[n, t, :] = a[n, position_ids[n, t], :]. gather reads one index per output\n",
    "# element, hence the index is repeated across the HID axis.\n",
    "gather_index = position_ids[..., None].expand(N, T, HID)\n",
    "result = a.gather(1, gather_index)\n",
    "\n",
    "print(result.shape)  # expected: (N, T, HID)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[-0.7926, -0.8541, -0.1556, -1.0184,  0.1551,  1.0674],\n",
      "         [-0.0979,  1.4511, -1.9776,  1.0785,  1.0232,  1.0687],\n",
      "         [ 1.3305, -0.8816,  0.8945,  0.8439, -1.1265,  0.2250],\n",
      "         [ 2.3407, -0.5805,  0.7658, -0.6370,  0.5458, -0.9757]],\n",
      "\n",
      "        [[ 1.1554,  0.7077,  0.7813,  0.5542, -0.1342,  0.4433],\n",
      "         [ 1.6541,  0.8458,  1.5027,  0.5250, -1.9654,  0.7925],\n",
      "         [-2.8540,  0.6407, -1.1707,  1.3274,  0.3265, -1.4777],\n",
      "         [ 0.8109,  1.2004, -0.3398, -0.1448,  1.2881,  0.3154]]])\n",
      "torch.Size([2, 4, 6])\n",
      "tensor([[2, 1, 1, 3],\n",
      "        [1, 1, 3, 3]])\n",
      "torch.Size([2, 4])\n",
      "tensor([[[ 1.3305, -0.8816,  0.8945,  0.8439, -1.1265,  0.2250],\n",
      "         [-0.0979,  1.4511, -1.9776,  1.0785,  1.0232,  1.0687],\n",
      "         [-0.0979,  1.4511, -1.9776,  1.0785,  1.0232,  1.0687],\n",
      "         [ 2.3407, -0.5805,  0.7658, -0.6370,  0.5458, -0.9757]],\n",
      "\n",
      "        [[ 1.6541,  0.8458,  1.5027,  0.5250, -1.9654,  0.7925],\n",
      "         [ 1.6541,  0.8458,  1.5027,  0.5250, -1.9654,  0.7925],\n",
      "         [ 0.8109,  1.2004, -0.3398, -0.1448,  1.2881,  0.3154],\n",
      "         [ 0.8109,  1.2004, -0.3398, -0.1448,  1.2881,  0.3154]]])\n",
      "torch.Size([2, 4, 6])\n"
     ]
    }
   ],
   "source": [
    "# Show the input, the indices and the gathered rows, each followed by its shape.\n",
    "print(a, a.shape, sep='\\n')\n",
    "print(position_ids, position_ids.shape, sep='\\n')\n",
    "print(result, result.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.0228,  0.9206, -0.3209,  0.2776,  0.0873,  0.0663],\n",
      "         [ 0.1255,  0.9391, -1.2239,  1.9296, -0.8529,  0.5387],\n",
      "         [-0.4250, -0.3241, -0.3943, -0.1539, -0.4990,  1.0619],\n",
      "         [-1.2678, -1.7241,  0.4328, -1.7619,  0.9449,  0.4025]],\n",
      "\n",
      "        [[ 0.7703, -1.0317, -1.5817, -0.6732, -0.3202,  0.6651],\n",
      "         [ 0.6113, -1.5711,  0.9737, -0.4005, -0.4183, -1.4413],\n",
      "         [ 1.0087, -0.3567,  1.4314, -2.1380,  0.9284, -0.1183],\n",
      "         [ 0.5450, -1.0610,  1.0568, -0.8993,  0.0670,  1.2972]]])\n",
      "torch.Size([2, 4, 6])\n"
     ]
    }
   ],
   "source": [
    "# Current contents of a, followed by its shape.\n",
    "print(a, a.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 1, 2, 3],\n",
       "        [0, 1, 2, 3]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Identity indices: every one of the N rows is 0, 1, ..., T-1, so a gather along dim 1\n",
    "# with them should hand back a unchanged.\n",
    "position_ids = torch.arange(T).unsqueeze(0).repeat(N, 1)  # shape (N, T)\n",
    "position_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather along dim 1 using position_ids, with each index repeated over the HID axis.\n",
    "result = a.gather(1, position_ids[:, :, None].expand(N, T, HID))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0.0228,  0.9206, -0.3209,  0.2776,  0.0873,  0.0663],\n",
      "         [ 0.1255,  0.9391, -1.2239,  1.9296, -0.8529,  0.5387],\n",
      "         [-0.4250, -0.3241, -0.3943, -0.1539, -0.4990,  1.0619],\n",
      "         [-1.2678, -1.7241,  0.4328, -1.7619,  0.9449,  0.4025]],\n",
      "\n",
      "        [[ 0.7703, -1.0317, -1.5817, -0.6732, -0.3202,  0.6651],\n",
      "         [ 0.6113, -1.5711,  0.9737, -0.4005, -0.4183, -1.4413],\n",
      "         [ 1.0087, -0.3567,  1.4314, -2.1380,  0.9284, -0.1183],\n",
      "         [ 0.5450, -1.0610,  1.0568, -0.8993,  0.0670,  1.2972]]])\n",
      "torch.Size([2, 4, 6])\n"
     ]
    }
   ],
   "source": [
    "# Gathered tensor, followed by its shape.\n",
    "print(result, result.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 4.6774e-01,  1.6003e+00, -1.1803e+00, -1.1767e+00,  7.6288e-01,\n",
      "          -5.8245e-01],\n",
      "         [-9.4216e-01,  2.0977e-03,  9.1515e-01,  2.7961e-01, -1.1394e+00,\n",
      "          -2.2370e+00],\n",
      "         [-4.4635e-01,  9.4994e-01,  1.1487e-01, -1.4899e-01,  2.3911e-02,\n",
      "           3.0744e-01],\n",
      "         [-1.5316e-01,  2.0130e-01, -8.4066e-04,  1.3562e+00,  8.3957e-02,\n",
      "           5.2856e-01]],\n",
      "\n",
      "        [[ 1.4534e+00,  6.7893e-01,  7.2445e-01,  4.8842e-01, -3.3182e-01,\n",
      "           1.9541e-01],\n",
      "         [ 1.2886e+00, -1.1761e+00,  1.2085e+00,  1.7261e-01, -4.8055e-01,\n",
      "           5.4247e-02],\n",
      "         [-4.8610e-01,  5.1687e-01,  5.1413e-01,  6.0117e-03, -2.2492e+00,\n",
      "          -3.3767e-01],\n",
      "         [-1.0187e+00,  1.8028e-01,  5.2256e-01, -1.0976e+00,  2.2894e-01,\n",
      "           1.0910e+00]]])\n",
      "tensor([[0],\n",
      "        [0]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Example sizes\n",
    "N, T, HID = 2, 4, 6\n",
    "\n",
    "# Random input a of shape (N, T, HID).\n",
    "a = torch.randn(N, T, HID)\n",
    "\n",
    "# A single index per row of a, shape (N, 1). The upper bound of randint is exclusive, so\n",
    "# with high=1 the only value that can be drawn is 0: every entry points at position 0.\n",
    "# NOTE(review): use high=T if random positions in [0, T-1] were intended - confirm.\n",
    "position_ids = torch.randint(0, 1, (N, 1))\n",
    "\n",
    "print(a, position_ids, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 6])\n",
      "torch.Size([2, 1])\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the input and of the index tensor.\n",
    "print(a.shape, position_ids.shape, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 6])\n",
      "tensor([[[ 0.4677,  1.6003, -1.1803, -1.1767,  0.7629, -0.5824],\n",
      "         [ 0.4677,  1.6003, -1.1803, -1.1767,  0.7629, -0.5824],\n",
      "         [ 0.4677,  1.6003, -1.1803, -1.1767,  0.7629, -0.5824],\n",
      "         [ 0.4677,  1.6003, -1.1803, -1.1767,  0.7629, -0.5824]],\n",
      "\n",
      "        [[ 1.4534,  0.6789,  0.7245,  0.4884, -0.3318,  0.1954],\n",
      "         [ 1.4534,  0.6789,  0.7245,  0.4884, -0.3318,  0.1954],\n",
      "         [ 1.4534,  0.6789,  0.7245,  0.4884, -0.3318,  0.1954],\n",
      "         [ 1.4534,  0.6789,  0.7245,  0.4884, -0.3318,  0.1954]]])\n"
     ]
    }
   ],
   "source": [
    "# position_ids is (N, 1) here, so unsqueeze(-1) gives (N, 1, 1). Expanding that to\n",
    "# (N, T, HID) repeats the single index along dim 1 as well, which makes all T gathered\n",
    "# rows of a given n identical copies of the one selected row of a.\n",
    "repeated_index = position_ids.unsqueeze(-1).expand(N, T, HID)\n",
    "result = a.gather(1, repeated_index)\n",
    "\n",
    "print(result.shape)  # expected: (N, T, HID)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1, 6])\n",
      "tensor([[[ 0.4677,  1.6003, -1.1803, -1.1767,  0.7629, -0.5824]],\n",
      "\n",
      "        [[ 1.4534,  0.6789,  0.7245,  0.4884, -0.3318,  0.1954]]])\n"
     ]
    }
   ],
   "source": [
    "# Repeat the (N, 1) indices over HID only; gather along dim 1 then returns exactly one\n",
    "# row of a per n.\n",
    "row_index = position_ids[:, :, None].expand(-1, -1, HID)  # (N, 1, HID)\n",
    "result = a.gather(1, row_index)\n",
    "\n",
    "print(result.shape)  # expected: (N, 1, HID)\n",
    "print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
