{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7d3c1e0-5a2f-4c8e-9f61-2d4e8a7b9c10",
   "metadata": {},
   "source": [
    "# Scratch: pairwise similarity and KL-distillation losses\n",
    "\n",
    "Quick PyTorch experiments with `nn.CosineSimilarity`, `nn.KLDivLoss` and soft cross-entropy on random tensors. Intended to be run top to bottom on a fresh kernel (**Restart & Run All**)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "feeb637a-2e7e-4136-9393-c6781f297942",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,\n",
       "        4])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# `repeat` tiles the whole sequence end to end: [0..4, 0..4, ...]\n",
    "torch.arange(5).repeat(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49506f1f-e3b9-4e17-926c-8a37a9557f20",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Config: fixed seed and the shared input, defined before the first cell that reads `a`\n",
    "torch.manual_seed(0)\n",
    "a = torch.randn(8, 75, 10)\n",
    "\n",
    "cos = nn.CosineSimilarity(dim=-1)\n",
    "# KLDivLoss expects log-probabilities as `input` and probabilities as `target`;\n",
    "# reduction='batchmean' divides the summed KL by input.size(0).\n",
    "kl = nn.KLDivLoss(reduction='batchmean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "689170f6-6be9-4aec-9f4a-03e50dab15cd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Probabilities over the last dim (functional form; no throwaway nn.Softmax module)\n",
    "c = torch.softmax(a, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4865e43b-db26-4a04-8818-23e958b566eb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1657, 0.2451, 0.0601, 0.0816, 0.1520, 0.1082, 0.0199, 0.1094, 0.0126,\n",
       "        0.0455])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One softmax row (last-dim slice at index [0, 1]); its entries sum to 1\n",
    "c[0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "0504bd5c-074a-4c71-9a15-eef92679ae80",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 75, 10])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softmax_pre = torch.softmax(a, dim=-1)\n",
    "# log_softmax is the stable form of log(softmax(x)): no -inf when a probability underflows to 0\n",
    "log_softmax_pre = torch.log_softmax(a, dim=-1)\n",
    "(log_softmax_pre * log_softmax_pre).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e246e20-8d49-4e49-84b2-7967a31ce7b5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "a = torch.randn(8, 75, 10)\n",
    "# Cosine similarity needs equal sizes along dim=-1: b's width must match a's.\n",
    "# (The previous b of shape (3, 5) raised a size-mismatch RuntimeError.)\n",
    "b = torch.randn(3, a.shape[-1])\n",
    "# (8, 75, 1, 10) vs (3, 10) broadcasts to (8, 75, 3, 10) -> similarity of every last-dim vector of `a` to every row of `b`.\n",
    "# NOTE(review): assumes all-pairs similarity against b's rows was the intent - confirm.\n",
    "cos(a.unsqueeze(-2), b).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4568042-98a6-45d0-bd4a-39ceaa6d10c6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Recompute from the current `a` (an earlier cell re-samples it)\n",
    "softmax_pre = torch.softmax(a, dim=-1)\n",
    "# log_softmax is the stable form of log(softmax(x)): no -inf when a probability underflows to 0\n",
    "log_softmax_pre = torch.log_softmax(a, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a472d9ad-b195-4741-bda6-d27a7ea46ce4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# kl(log P_i, P_j) = KL(P_j || P_i); sum it over all ordered pairs (i, j), i != j, along dim 0\n",
    "n_items = log_softmax_pre.shape[0]\n",
    "distill_loss = sum(\n",
    "    kl(log_softmax_pre[i], softmax_pre[j])\n",
    "    for i in range(n_items)\n",
    "    for j in range(n_items)\n",
    "    if i != j\n",
    ")\n",
    "distill_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "8fe659b9-7db6-47c2-988c-c084f7c76be6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-1.2107e-08)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: KL(P || P) should be ~0 (up to float rounding)\n",
    "kl(input=log_softmax_pre, target=softmax_pre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "f9375008-60cf-4d7b-9a56-17d31a9d350a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.1924, -0.9271,  0.7655,  0.5822,  1.8794],\n",
       "         [ 1.4008,  0.5272, -0.0524,  1.6675, -0.7865],\n",
       "         [ 0.6004,  0.6094,  0.3761,  0.6691,  1.3732]]])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Leading singleton dim via None-indexing (same view as b.unsqueeze(0))\n",
    "b[None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "1b1a95bf-05cd-4f9b-b92f-690375af6b3f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# KLDivLoss wants log-probabilities as input and probabilities as target;\n",
    "# raw `a` (with negative entries) is not a valid target distribution.\n",
    "kl(torch.log_softmax(a, dim=-1), torch.softmax(a, dim=-1)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6e480e6-13bd-4bc1-9d37-6b3107666d24",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pairwise soft cross-entropy H(P_j, Q_i) = -sum_k P_j[k] * log Q_i[k] over the last dim,\n",
    "# then averaged over a's dim 1 -> square matrix indexed [i, j] along a's dim 0.\n",
    "# The previous F.cross_entropy(a.unsqueeze(0), a.unsqueeze(1)) raised NameError (F not imported);\n",
    "# its probability target also did not match the input shape, which cross_entropy requires.\n",
    "# NOTE(review): assumes all-pairs soft cross-entropy was the intent - confirm.\n",
    "log_q = torch.log_softmax(a, dim=-1)\n",
    "p = log_q.exp()\n",
    "pairwise_ce = -(p.unsqueeze(0) * log_q.unsqueeze(1)).sum(dim=-1).mean(dim=-1)\n",
    "pairwise_ce.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c088194-602e-4f0f-8221-3514a8233a58",
   "metadata": {},
   "source": [
    "## Notes\n",
    "\n",
    "- `nn.KLDivLoss()(input, target)` computes KL(target || exp(input)): pass log-probabilities (`torch.log_softmax`) as `input` and probabilities as `target`.\n",
    "- Prefer `torch.log_softmax` over `torch.log(torch.softmax(...))`; the latter returns `-inf` once a probability underflows to 0.\n",
    "- `nn.CosineSimilarity(dim=-1)` needs equal sizes along `dim`; the remaining dims broadcast."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
