{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e9f23061-fc15-4ddd-a0db-72cc1fd966c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from typing import Union"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ebabb159-2c00-41fa-87b8-f9238b4e4e16",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.        , 7.68114575, 5.19615242],\n",
       "       [7.68114575, 0.        , 5.65685425],\n",
       "       [5.19615242, 5.65685425, 0.        ]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = np.array([[1,2,3],[8,5,2],[4,5,6]])\n",
    "diff = a[:,None] - a[None,:]\n",
    "np.linalg.norm(diff,axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1ab532bf-b9f3-4f8e-a590-928496d39aa2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 7.6811, 5.1962],\n",
       "        [7.6811, 0.0000, 5.6569],\n",
       "        [5.1962, 5.6569, 0.0000]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor([[1,2,3],[8,5,2],[4,5,6]],dtype=torch.float32)\n",
    "a.detach()\n",
    "diff = a[:,None] - a[None,:]\n",
    "torch.norm(diff,dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8ddb6d34-2894-4825-90fa-e4985cd2b0c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cosine_sim(vecs:Union[np.ndarray, torch.Tensor],rm_mean=False)->np.ndarray:\n",
    "    '''\n",
    "    vecs: is a N*D matrix contains N vectors of D dimension\n",
    "    return a distance matrix sim, sim[i,j] is the cosine similarity between vector i and j\n",
    "    '''\n",
    "    if isinstance(vecs,torch.Tensor):\n",
    "        vecs = vecs.detach()\n",
    "        if rm_mean:\n",
    "            vecs -= torch.mean(vecs,dim=0)\n",
    "        sim = torch.nn.functional.cosine_similarity(vecs[:,None],vecs[None,:],dim=-1)\n",
    "        return sim.detach().cpu().numpy()\n",
    "    elif isinstance(vecs,np.ndarray):\n",
    "        if rm_mean:\n",
    "            vecs -= np.mean(vecs,dim=0)\n",
    "        norm = np.linalg.norm(vecs,axis=-1)\n",
    "        norm_prod = norm[None,:]*norm[:,None]\n",
    "        dot = np.matmul(vecs,vecs.transpose())\n",
    "        return dot/(norm_prod + 1e-6)\n",
    "    else:\n",
    "        raise TypeError(\"input must be an ndarray or tensor\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "fd31f4b8-6ee6-49be-a468-2aa2290bba34",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.9999999 , 0.99999994, 0.        ],\n",
       "       [0.99999994, 0.99999994, 0.        ],\n",
       "       [0.        , 0.        , 0.99999976]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch_vecs = torch.tensor([[1,2,3,4],[2,4,6,8],[-1,1,1,-1]],dtype=torch.float32)\n",
    "np_vecs = np.array([[1,2,3,4],[2,4,6,8],[-1,1,1,-1]],dtype=np.float32)\n",
    "get_cosine_sim(np_vecs,rm_mean=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "287fd56b-d160-464d-a60f-ce7f75a31d07",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_space_align(covs:Union[np.ndarray, torch.Tensor])->np.ndarray:\n",
    "    '''\n",
    "    covs: is a N*D*D matrix contains N covariance matrices\n",
    "    sim[i,j] = trace(covs[i]*covs[j])/(|covs[i]|*|covs[j]|)\n",
    "    return a distance matrix sim, sim[i,j] is the cosine similarity between cov i and j\n",
    "    '''\n",
    "    if isinstance(covs,torch.Tensor):\n",
    "        covs = covs.detach()\n",
    "        mat_prod = torch.einsum(\"imn,jnm->ij\",covs,covs)\n",
    "        norm = torch.linalg.matrix_norm(covs)\n",
    "        norm_prod = norm[None,:]*norm[:,None]\n",
    "        sim = mat_prod/(norm_prod + 1e-6)\n",
    "        return sim.detach().cpu().numpy()\n",
    "    elif isinstance(covs,np.ndarray):\n",
    "        mat_prod = np.einsum(\"imn,jnm->ij\",covs,covs)\n",
    "        # norm is a vecotor of shape (N,)\n",
    "        norm = np.linalg.norm(covs,axis=(1,2))\n",
    "        norm_prod = norm[None,:]*norm[:,None]\n",
    "        sim = mat_prod/(norm_prod + 1e-6)\n",
    "        return sim\n",
    "    else:\n",
    "        raise TypeError(\"input must be an ndarray or tensor\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ba87e6a7-85ae-4fbf-b7f7-ef8c3c278a2d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.9999999 , 0.44721356, 0.9999999 ],\n",
       "       [0.44721356, 1.        , 0.44721356],\n",
       "       [0.9999999 , 0.44721356, 1.        ]], dtype=float32)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cov0 = [[1,2],[2,1]]\n",
    "cov1 = [[3,0],[0,3]]\n",
    "cov2 = [[2,4],[4,2]]\n",
    "torch_covs = torch.tensor([cov0,cov1,cov2],dtype=torch.float32)\n",
    "get_space_align(torch_covs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "999378b7-488c-49e9-a4f4-bc74ff2f360d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.9999999 , 0.44721356, 0.9999999 ],\n",
       "       [0.44721356, 1.        , 0.44721356],\n",
       "       [0.9999999 , 0.44721356, 1.        ]], dtype=float32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np_covs = np.array([cov0,cov1,cov2],dtype=np.float32)\n",
    "get_space_align(np_covs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4007a9c1-8470-4c35-b2cd-dad9d13f8f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cov_traces(covs:Union[np.ndarray, torch.Tensor])->np.ndarray:\n",
    "    '''\n",
    "    covs: is a N*D*D matrix contains N covariance matrices\n",
    "    sim[i,j] = trace(covs[i]*covs[j])/(|covs[i]|*|covs[j]|)\n",
    "    return a distance matrix sim, sim[i,j] is the cosine similarity between cov i and j\n",
    "    '''\n",
    "    if isinstance(covs,torch.Tensor):\n",
    "        covs = covs.detach()\n",
    "        trace = torch.einsum(\"ijj->i\",covs)\n",
    "        return trace.detach().cpu().numpy()\n",
    "    elif isinstance(covs,np.ndarray):\n",
    "        trace = np.einsum(\"ijj->i\",covs)\n",
    "        return trace\n",
    "    else:\n",
    "        raise TypeError(\"input must be an ndarray or tensor\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "533366fc-45e0-4c29-937f-9f2adcf149f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2., 6., 4.], dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_cov_traces(torch_covs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ec71f86-5dd5-4775-bd8c-f92be81ef995",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
