{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "8dcd9d44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original tensor: tensor([10,  5, 10, 15, 20])\n",
      "Output matrix:\n",
      " tensor([[1, 0, 1, 2, 2],\n",
      "        [2, 1, 2, 2, 2],\n",
      "        [1, 0, 1, 2, 2],\n",
      "        [0, 0, 0, 1, 2],\n",
      "        [0, 0, 0, 0, 1]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Values to be compared against one another\n",
    "tensor = torch.tensor([10, 5, 10, 15, 20])\n",
    "n = tensor.size(0)\n",
    "\n",
    "# Broadcast a column of reference values against a row of candidates so that\n",
    "# entry [i, j] encodes how tensor[j] relates to tensor[i]:\n",
    "#   0 -> tensor[j] is smaller, 1 -> equal, 2 -> larger\n",
    "reference = tensor.unsqueeze(1)\n",
    "candidates = tensor.unsqueeze(0)\n",
    "smaller, same, larger = (torch.tensor(code, dtype=torch.long) for code in (0, 1, 2))\n",
    "output_matrix = torch.where(candidates < reference,\n",
    "                            smaller,\n",
    "                            torch.where(candidates == reference, same, larger))\n",
    "\n",
    "print(\"Original tensor:\", tensor)\n",
    "print(\"Output matrix:\\n\", output_matrix)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "1c09ef34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([25])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1-D view of the comparison matrix, row after row\n",
    "torch.flatten(output_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "b27b29fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[100,  50, 100, 150, 200],\n",
       "        [ 50,  25,  50,  75, 100],\n",
       "        [100,  50, 100, 150, 200],\n",
       "        [150,  75, 150, 225, 300],\n",
       "        [200, 100, 200, 300, 400]])"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Outer product: entry [i, j] is tensor[i] * tensor[j].\n",
    "# `tensor` is 1-D, so `tensor @ tensor.T` collapses to a scalar dot product\n",
    "# (`.T` does not turn a 1-D tensor into a column); torch.outer is explicit.\n",
    "torch.outer(tensor, tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c5648b21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(850)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dot product (sum of squares) of the 1-D `tensor`.\n",
    "# torch.dot states the intent directly and avoids `.T` on a 1-D tensor,\n",
    "# which is a deprecated no-op there.\n",
    "torch.dot(tensor, tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "9431f878",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding matrix shape: torch.Size([5, 128])\n",
      "Pairwise embeddings matrix shape: torch.Size([25, 256])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Define the embedding matrix (5 x 128)\n",
    "embedding_matrix = torch.randn(5, 128)\n",
    "\n",
    "# Number of embeddings (rows); every ordered pair (i, j) gets one output row\n",
    "num_rows = embedding_matrix.size(0)\n",
    "\n",
    "# Build all num_rows**2 ordered pairs -- including (i, i) and both (i, j) and\n",
    "# (j, i) -- without a Python loop. Row i * num_rows + j is [row i | row j].\n",
    "first_of_pair = embedding_matrix.repeat_interleave(num_rows, dim=0)  # row i, repeated num_rows times\n",
    "second_of_pair = embedding_matrix.repeat(num_rows, 1)  # all rows, tiled num_rows times\n",
    "pairwise_embeddings_matrix = torch.cat((first_of_pair, second_of_pair), dim=1)\n",
    "\n",
    "# Per-pair list form, kept so this cell still defines the same names as before\n",
    "pairwise_embeddings = list(pairwise_embeddings_matrix.unbind(dim=0))\n",
    "\n",
    "# Print shapes to verify\n",
    "print(\"Embedding matrix shape:\", embedding_matrix.shape)\n",
    "print(\"Pairwise embeddings matrix shape:\", pairwise_embeddings_matrix.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "55f9db3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-2.0134e+00,  1.4611e+00, -7.8293e-01,  6.9609e-01,  4.8444e-01,\n",
       "          1.1393e+00, -8.5221e-01, -1.6264e+00, -6.7714e-01, -2.4328e+00,\n",
       "          4.7671e-01, -1.1517e+00, -1.4596e+00, -5.0996e-01,  2.2365e-01,\n",
       "         -6.1374e-02, -1.0140e+00,  6.1356e-01, -8.4270e-02, -1.0102e+00,\n",
       "          4.4242e-01, -1.9984e+00,  1.2542e+00,  7.1929e-01, -5.6129e-01,\n",
       "          3.1783e-01,  4.9972e-01,  1.7425e-01,  4.8911e-01, -5.4969e-01,\n",
       "         -5.4936e-01,  5.1140e-01, -3.9898e-01, -1.0776e+00,  1.4730e-02,\n",
       "         -9.1950e-01,  7.3325e-01,  1.1437e+00, -1.2801e-01,  4.3065e-01,\n",
       "         -1.0361e+00, -1.8642e+00, -7.8827e-01, -1.5666e+00,  7.4162e-01,\n",
       "          1.1092e+00,  3.1726e-01, -1.1296e-01, -1.2023e+00,  1.6326e+00,\n",
       "          7.2594e-02, -1.0079e+00, -3.3135e-01,  2.0921e-01, -6.1901e-01,\n",
       "          6.4580e-01,  1.9848e+00, -7.3644e-01, -6.6276e-01, -8.6202e-01,\n",
       "          1.0710e+00, -5.1565e-02,  1.4550e+00,  1.1163e+00, -1.1616e-01,\n",
       "         -2.9639e-01, -7.6344e-01, -8.8615e-01, -3.4026e-01, -5.4889e-01,\n",
       "          1.1438e-02, -5.3045e-01, -8.1307e-01,  6.4803e-02, -3.7952e-01,\n",
       "         -2.5470e-01,  9.0181e-02,  1.3177e+00, -2.2659e+00, -3.1069e-01,\n",
       "         -9.2627e-02,  1.3178e+00, -3.3976e-01, -1.5849e-01,  2.5209e-01,\n",
       "         -6.5342e-01,  9.3370e-01, -3.5894e-01,  3.3042e-01, -2.4821e+00,\n",
       "          1.1851e+00, -2.3584e-01, -3.6034e-01, -1.4563e+00, -1.5143e+00,\n",
       "         -1.5747e+00, -2.7185e-01, -1.4397e+00, -5.2866e-01, -8.4599e-01,\n",
       "          9.2533e-02, -6.8697e-01, -1.0181e+00,  8.9502e-01, -1.1607e+00,\n",
       "         -1.5396e+00,  8.3605e-02, -1.7824e+00,  9.8231e-01, -2.1286e-01,\n",
       "          1.3434e-01,  3.1161e-01,  4.2628e-01,  2.3278e-01,  1.3003e+00,\n",
       "         -5.4635e-01, -3.6405e-01,  1.0227e+00,  5.3121e-01,  4.4347e-01,\n",
       "         -5.2786e-01, -3.6400e-01, -1.7958e+00, -4.8571e-01, -9.5088e-02,\n",
       "         -1.0532e+00, -5.6458e-01,  6.3541e-01],\n",
       "        [-1.1582e+00,  1.5823e+00,  2.0632e-01, -7.1496e-01, -2.1350e-01,\n",
       "          2.9906e-01, -1.4252e+00,  6.6909e-01, -5.6822e-02, -1.3616e+00,\n",
       "         -8.1868e-01, -1.7504e+00, -7.6134e-02,  5.5874e-01, -2.4673e-02,\n",
       "         -5.9711e-01, -1.1256e-01,  1.0841e+00, -5.4436e-01, -1.7183e+00,\n",
       "          9.1595e-01, -1.5331e+00, -1.0924e-01,  1.7293e-01, -1.4127e+00,\n",
       "          7.5891e-01, -3.3650e-01, -1.0880e+00,  1.0666e+00, -5.1553e-01,\n",
       "         -5.5053e-01,  1.1572e-01, -4.0138e-01,  1.7817e-01,  2.6834e-01,\n",
       "          1.1809e+00,  6.5813e-02,  1.2164e+00,  1.9828e+00,  2.6972e+00,\n",
       "          3.7548e-02,  6.8981e-01,  6.3364e-02,  3.7722e-02, -8.0755e-02,\n",
       "         -5.5779e-01,  6.9424e-01, -5.0134e-01, -7.4197e-01,  4.5543e-01,\n",
       "         -3.7099e-01,  1.5767e+00,  9.9549e-01, -2.8606e-01, -1.5824e-01,\n",
       "          1.5116e+00,  3.3429e-01,  6.8029e-01,  6.0226e-02,  1.1364e+00,\n",
       "          8.8543e-01,  3.9557e-01,  1.5737e+00, -1.8335e+00,  1.2930e+00,\n",
       "          1.5532e+00,  2.2958e+00, -7.6574e-01, -6.3778e-01,  1.9213e-01,\n",
       "         -1.3886e-01, -2.1224e-01,  1.7816e+00,  2.7320e-01,  4.5374e-01,\n",
       "         -3.5682e-01,  1.2188e-01, -1.1476e+00,  8.2890e-01,  6.4790e-01,\n",
       "         -3.8155e-01,  4.0475e-01, -2.9499e-01, -9.2163e-01,  3.7768e-01,\n",
       "         -1.6264e+00,  7.6664e-01,  5.2265e-01,  1.0340e+00,  1.2185e+00,\n",
       "         -3.6933e-01, -1.2207e+00,  5.5706e-02, -4.6160e-01,  1.0566e+00,\n",
       "         -5.0346e-01, -5.4531e-01,  2.5287e-01, -7.4296e-01, -4.5252e-01,\n",
       "          6.6108e-02,  9.7092e-02, -2.2537e-01,  9.9704e-01, -1.4082e-01,\n",
       "         -2.7469e-01,  5.5254e-01, -1.9273e+00, -1.0317e+00, -7.1823e-01,\n",
       "         -1.3326e+00, -4.7222e-01, -4.4600e-01,  6.9247e-01, -3.5115e-01,\n",
       "          1.6572e+00,  1.9461e-01,  1.0581e-01, -8.8518e-01, -1.1066e+00,\n",
       "         -1.2509e+00,  8.0872e-01, -3.1122e-01,  3.7622e-01, -8.1572e-01,\n",
       "         -5.1176e-01,  1.2119e+00,  6.1257e-01],\n",
       "        [ 4.2448e-01,  1.6528e-01,  4.6147e-01, -8.0854e-01,  1.9352e+00,\n",
       "          8.2445e-01, -1.0877e-01, -1.8159e-01,  1.6081e-01,  1.2900e+00,\n",
       "         -1.0351e+00, -9.7531e-01,  8.8397e-01, -6.4113e-01, -8.2553e-02,\n",
       "          8.0327e-01,  6.1016e-01,  8.0209e-01,  3.8444e-01,  7.6869e-02,\n",
       "          5.6899e-01, -1.1509e-01, -1.2009e+00, -1.4567e+00, -1.6180e+00,\n",
       "          1.3754e+00, -1.2696e+00,  6.1615e-01, -7.7708e-01,  1.4437e+00,\n",
       "          1.0889e+00, -3.5656e-01,  7.1678e-01,  9.1038e-01,  1.4040e+00,\n",
       "          7.2516e-01, -6.4983e-01, -7.1436e-01,  1.3298e+00,  1.9671e+00,\n",
       "          1.5788e+00, -1.4046e+00, -5.8867e-01,  9.6543e-01,  1.1703e+00,\n",
       "          1.6266e+00, -4.7478e-02, -1.5346e+00, -7.4536e-01, -1.4207e+00,\n",
       "          3.1010e-02,  1.1057e+00, -1.4179e+00,  6.7197e-01, -6.5861e-01,\n",
       "         -7.0904e-01, -5.5495e-01, -2.1004e+00,  5.2932e-01, -5.1431e-01,\n",
       "         -7.6594e-01,  4.7816e-01, -2.6932e-02, -5.7050e-01,  1.4703e-01,\n",
       "         -1.3869e+00,  1.1982e+00, -4.3753e-01,  5.8945e-02, -7.2775e-02,\n",
       "         -7.3746e-01,  7.9348e-01,  1.2087e+00, -2.4538e-01, -7.6416e-01,\n",
       "         -7.5206e-01,  4.5041e-01,  1.3988e+00, -5.6755e-01, -1.8313e-01,\n",
       "          1.1894e+00, -6.2169e-01, -1.5720e+00,  5.2607e-01,  3.9926e-01,\n",
       "         -2.3156e-01, -1.2486e+00, -4.7351e-01, -6.6318e-01, -5.0326e-01,\n",
       "         -2.6217e-01,  3.4123e-01,  7.6025e-01,  1.6029e+00, -1.5241e-01,\n",
       "          1.2605e+00, -5.3016e-02, -2.1379e-01, -6.1525e-01,  1.7114e+00,\n",
       "         -9.2737e-01,  5.9773e-01,  1.0443e-01, -1.1301e+00, -7.0624e-01,\n",
       "          6.6365e-01, -6.6577e-01, -6.0117e-01,  5.3345e-01, -3.4066e-01,\n",
       "         -6.0917e-01, -6.9918e-01,  3.1166e-01,  1.2914e+00,  1.7797e-01,\n",
       "          5.0615e-01,  1.7035e+00,  1.5324e+00, -1.9860e+00, -3.7501e-01,\n",
       "          5.9921e-01, -2.3895e-01,  4.7787e-03, -1.1232e+00, -1.1899e-01,\n",
       "          8.9119e-01,  8.2807e-01,  4.1147e-01],\n",
       "        [-3.2274e-01, -1.6447e+00, -1.4941e+00,  1.5564e+00,  8.7577e-01,\n",
       "         -4.0664e-01,  2.9437e-01,  1.0330e+00,  1.0412e-01,  7.2124e-01,\n",
       "          1.0017e+00,  1.3505e+00, -3.5812e-01, -6.1002e-01, -9.1083e-01,\n",
       "         -5.8789e-01,  1.4396e+00, -1.9612e+00, -1.4722e+00, -1.1682e-01,\n",
       "         -2.5402e+00, -1.5088e+00,  1.3657e+00, -2.4337e-01, -1.2388e-01,\n",
       "         -1.8182e-01, -1.0860e+00,  4.1020e-01,  1.5450e+00, -6.9139e-01,\n",
       "         -1.2327e+00, -6.5037e-01,  1.1460e+00, -7.3592e-01, -3.9285e-01,\n",
       "          2.7656e-04, -1.5939e+00, -7.8921e-02,  5.3482e-01,  3.9811e-01,\n",
       "          2.7338e-01, -2.2951e+00, -5.7553e-01,  6.5358e-01, -1.0165e+00,\n",
       "          1.8180e+00,  7.8271e-01, -1.4322e+00, -7.0498e-01, -1.4650e-01,\n",
       "          9.6470e-04,  7.5755e-01, -8.1531e-01,  8.9939e-03, -1.1252e+00,\n",
       "         -6.7457e-02, -1.3964e-01, -1.4068e+00,  1.7487e+00,  2.9406e-01,\n",
       "         -1.1453e+00,  1.9082e+00, -1.5543e+00,  6.1518e-01, -1.1388e+00,\n",
       "          1.8432e+00, -7.0646e-01, -1.7114e+00, -8.2441e-01,  2.0304e+00,\n",
       "         -2.6857e-01,  3.5826e+00, -3.4273e-02,  1.8874e+00, -3.5663e-01,\n",
       "          3.5015e-02,  1.5985e-01,  2.0182e-01,  6.5722e-01,  3.2553e-01,\n",
       "         -3.3572e-01, -7.7467e-01, -5.7608e-01, -1.6223e+00,  6.4393e-01,\n",
       "         -6.9665e-02,  1.3039e+00,  8.0893e-01, -2.6916e-01, -2.4810e-01,\n",
       "         -1.2325e+00, -1.8820e-01,  6.2564e-01, -5.1070e-04, -3.9301e-01,\n",
       "          5.0596e-01, -1.4205e-01, -1.2372e+00,  1.2906e+00,  7.3505e-01,\n",
       "          9.8374e-01,  6.1954e-01, -5.4411e-01,  1.4688e+00,  7.8158e-01,\n",
       "         -1.2972e+00, -3.8805e-03,  1.1372e+00, -3.7641e-01,  4.6800e-02,\n",
       "          1.1840e+00, -1.5828e+00, -9.3094e-01,  7.5609e-01,  4.7920e-01,\n",
       "         -3.0513e-01,  7.1563e-02,  1.5887e+00, -6.6971e-01,  7.4480e-01,\n",
       "          7.6107e-01, -7.8250e-01, -2.3769e+00,  1.9786e+00,  5.4449e-01,\n",
       "         -3.4245e-03, -8.2969e-01, -2.2529e-01],\n",
       "        [ 9.6998e-01, -1.5924e+00,  4.8953e-01,  5.2747e-01, -7.1090e-01,\n",
       "         -6.6565e-01,  4.4866e-01,  4.4399e-01,  5.0584e-01, -4.6529e-01,\n",
       "          7.1161e-02, -2.5923e-01, -8.7487e-01,  1.4718e+00,  5.0782e-01,\n",
       "          1.8166e-01,  9.2053e-02,  7.2098e-01, -4.9100e-01,  1.1587e+00,\n",
       "         -7.6289e-01, -4.4984e-01, -8.2588e-01, -7.2844e-02, -1.3531e+00,\n",
       "         -4.8092e-01,  2.7050e-02,  2.9040e-01,  9.3783e-01, -6.8764e-01,\n",
       "         -3.8456e-01, -1.3026e+00, -1.0518e+00,  8.6259e-01, -8.9978e-01,\n",
       "          8.6357e-01, -5.7688e-01,  7.8686e-01, -8.1527e-01, -2.0924e+00,\n",
       "         -2.7387e-01,  5.6314e-01,  5.2758e-01, -3.7329e-01, -9.4607e-01,\n",
       "          1.3547e+00,  1.5794e+00,  5.1166e-01,  4.2117e-01, -3.4773e-01,\n",
       "          1.0534e+00, -2.3076e-01,  1.5884e+00,  4.9079e-01,  1.1064e+00,\n",
       "         -2.1931e-01, -2.0133e-01, -1.3079e+00,  5.0017e-01, -3.0828e-01,\n",
       "         -2.0037e+00,  5.5012e-01, -4.3748e-01, -3.0356e-01, -7.8319e-01,\n",
       "          1.8885e+00, -2.6416e-01, -1.2861e+00,  5.5669e-01,  1.0294e+00,\n",
       "         -1.3725e+00, -1.5172e+00, -3.2086e-01, -1.5413e+00,  8.1616e-01,\n",
       "          2.0347e+00,  6.8251e-01, -4.4580e-01,  8.6337e-01,  5.6324e-01,\n",
       "         -6.0532e-02, -6.0403e-01,  5.0747e-01,  1.2301e+00,  3.3168e-02,\n",
       "          7.1696e-01,  1.1622e+00, -5.5628e-01, -4.6337e-01,  1.2836e+00,\n",
       "         -7.3712e-01,  2.2203e+00,  1.3923e+00,  1.5141e-01, -2.6996e-01,\n",
       "          2.2164e+00, -4.1337e-01,  1.2832e+00, -6.1152e-01,  1.1887e+00,\n",
       "         -2.4943e-01, -1.5051e-01,  3.9000e-01,  1.2420e+00, -9.4148e-01,\n",
       "         -1.9809e-01,  8.6274e-01, -2.0042e-01, -4.5690e-01,  7.4525e-01,\n",
       "          4.4850e-01, -2.7499e+00, -1.1713e+00, -6.3816e-01, -2.4366e-02,\n",
       "         -4.0580e-01, -1.2274e+00, -3.4743e-01, -4.5472e-01,  1.8210e-01,\n",
       "         -2.5627e+00, -2.0813e-01,  1.3244e-01,  2.5361e+00, -2.3962e-01,\n",
       "         -8.8830e-01,  2.4467e-01, -4.2976e-01]])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the full embedding matrix defined in an earlier cell\n",
    "embedding_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "4d6b3e96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([128])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of a single embedding row\n",
    "embedding_matrix[0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "f70ef2af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True, True, True, True, True,\n",
       "        True, True, True, True, True, True, True, True])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Elementwise check: first 128 entries of pair row 4 against embedding row 0\n",
    "pairwise_embeddings_matrix[4, :128].eq(embedding_matrix[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "623271b3",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (472) must match the size of tensor b (300) at non-singleton dimension 0",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[53], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mpairwise_embeddings_matrix\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43membedding_matrix\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (472) must match the size of tensor b (300) at non-singleton dimension 0"
     ]
    }
   ],
   "source": [
    "# Elementwise check: second half of pair row 0 against embedding row 0.\n",
    "# The split point is taken from the embedding width rather than hard-coded\n",
    "# as 128, which fails with a size mismatch whenever the embeddings are not\n",
    "# 128-dimensional.\n",
    "embedding_dim = embedding_matrix.size(1)\n",
    "pairwise_embeddings_matrix[0][embedding_dim:] == embedding_matrix[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "4b559ad9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16384, 600])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the stacked pair matrix\n",
    "pairwise_embeddings_matrix.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aed3dfe6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logits:\n",
      " tensor([[ 0.2220, -0.2920, -0.3507],\n",
      "        [-0.8946,  0.2897,  0.2469],\n",
      "        [-0.2287,  0.1828,  0.5999],\n",
      "        [ 0.4986, -0.3846, -1.1352],\n",
      "        [-0.1706, -0.5710,  0.2185],\n",
      "        [-0.5212,  0.3408,  0.4124],\n",
      "        [ 0.5458, -0.6454,  0.0977],\n",
      "        [-0.6668, -0.4272, -0.6196],\n",
      "        [-0.1009, -0.0832,  0.0188],\n",
      "        [ 0.7093, -0.3336, -0.1101],\n",
      "        [-0.3309,  0.3851, -0.8583],\n",
      "        [ 0.6945, -0.6904, -0.3053],\n",
      "        [-0.1590,  0.0448, -0.0356],\n",
      "        [ 0.2188,  0.6905,  0.0321],\n",
      "        [-0.1988, -0.7908,  0.5962],\n",
      "        [-0.1032,  0.0029, -0.1704],\n",
      "        [ 0.1417, -0.0117,  0.5900],\n",
      "        [-0.8979, -0.6044, -0.8794],\n",
      "        [ 0.8302, -0.0522,  0.3876],\n",
      "        [ 0.3450, -0.7952,  0.3969],\n",
      "        [-0.6418, -0.2482,  0.1200],\n",
      "        [ 0.6546, -0.6348, -0.7047],\n",
      "        [ 0.1877,  0.1375, -1.6608],\n",
      "        [ 0.2701, -0.2194, -0.2379],\n",
      "        [-0.0050, -0.0116,  0.2958]], grad_fn=<AddmmBackward0>)\n",
      "Labels:\n",
      " tensor([2, 1, 0, 1, 0, 1, 2, 0, 0, 2, 1, 1, 1, 2, 2, 1, 0, 1, 2, 1, 2, 2, 2, 0,\n",
      "        0])\n",
      "Loss:\n",
      " 1.235495686531067\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Toy inputs: 25 random pair embeddings of width 256, each tagged with a\n",
    "# random class index in {0, 1, 2}\n",
    "pairwise_embeddings = torch.randn(size=(25, 256))\n",
    "labels = torch.randint(low=0, high=3, size=(25,))\n",
    "\n",
    "# Linear layer mapping each 256-wide row to 3 logits\n",
    "linear_layer = nn.Linear(in_features=256, out_features=3)\n",
    "logits = linear_layer(pairwise_embeddings)\n",
    "\n",
    "# Cross-entropy between the logits and the integer class labels\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "loss = criterion(logits, labels)\n",
    "\n",
    "for heading, value in ((\"Logits:\\n\", logits), (\"Labels:\\n\", labels), (\"Loss:\\n\", loss.item())):\n",
    "    print(heading, value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ad6f2b86",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.models import MoleculeModel,mpn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1b7a032f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.args import TrainArgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "99bc38a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['smiles']\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# Command-line options for TrainArgs, written as (flag, value) pairs\n",
    "cli_options = (\n",
    "    ('--data_path', 'Utils/dataset_nmrshiftdb2_smiles_modified_1204.pkl'),\n",
    "    ('--dataset_type', 'kmgcl'),\n",
    "    ('--smiles_columns', 'smiles'),\n",
    ")\n",
    "\n",
    "# Flatten the pairs into the argv-style list handed to parse_args\n",
    "arguments = [token for pair in cli_options for token in pair]\n",
    "\n",
    "pass_args = TrainArgs().parse_args(arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "439059e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate mpn.MPN from the parsed arguments, then move it to the CPU\n",
    "graph_model = mpn.MPN(pass_args)\n",
    "graph_model = graph_model.to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "fde414c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 25 random class indices drawn from {0, 1, 2}\n",
    "labels = torch.randint(low=0, high=3, size=(25,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "c532495a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 1, 0, 0, 2, 1, 2, 1, 0, 2, 0, 0, 1, 0, 1, 0, 2, 2, 2, 2, 2, 0, 1,\n",
       "        2])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the label tensor defined in an earlier cell\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "57eb89fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 25 random rows of width 256 standing in for pair embeddings\n",
    "pairwise_embeddings = torch.randn(size=(25, 256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "5a99a811",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([25, 256])"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the pair-embedding tensor\n",
    "pairwise_embeddings.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "3ddcdb4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear layer mapping each 256-wide row to 3 logits\n",
    "linear_layer = nn.Linear(in_features=256, out_features=3)\n",
    "\n",
    "# One row of logits per row of pairwise_embeddings\n",
    "logits = linear_layer(pairwise_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "1b453fa1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([25, 3])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the logits tensor\n",
    "logits.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "8c86d328",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding matrix shape: torch.Size([128, 300])\n",
      "Pairwise embeddings matrix shape: torch.Size([16384, 600])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Define the embedding matrix (128 x 300)\n",
    "embedding_matrix = torch.randn(128, 300)\n",
    "\n",
    "# Number of embeddings (rows); every ordered pair (i, j) gets one output row\n",
    "num_rows = embedding_matrix.size(0)\n",
    "\n",
    "# Build all num_rows**2 ordered pairs -- including (i, i) and both (i, j) and\n",
    "# (j, i) -- without a Python loop. Row i * num_rows + j is [row i | row j].\n",
    "first_of_pair = embedding_matrix.repeat_interleave(num_rows, dim=0)  # row i, repeated num_rows times\n",
    "second_of_pair = embedding_matrix.repeat(num_rows, 1)  # all rows, tiled num_rows times\n",
    "pairwise_embeddings_matrix = torch.cat((first_of_pair, second_of_pair), dim=1)\n",
    "\n",
    "# Per-pair list form, kept so this cell still defines the same names as before\n",
    "pairwise_embeddings = list(pairwise_embeddings_matrix.unbind(dim=0))\n",
    "\n",
    "# Print shapes to verify\n",
    "print(\"Embedding matrix shape:\", embedding_matrix.shape)\n",
    "print(\"Pairwise embeddings matrix shape:\", pairwise_embeddings_matrix.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e70b4171",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
