{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "os.chdir(\"/Users/samuel.gruffaz/Documents/PEcollab\")\n",
    "import leaspy.models.utils.OptimB as OptimB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[nan, nan],\n",
      "         [3., 4.]],\n",
      "\n",
      "        [[2., 2.],\n",
      "         [1., 2.]]])\n",
      "tensor([[   nan, 5.0000],\n",
      "        [2.8284, 2.2361]])\n"
     ]
    }
   ],
   "source": [
    "T=[[[float(\"NaN\"),float(\"NaN\")],[3.0,4.0]],[[2.0,2.0],[1.0,2.0]]]\n",
    "V=torch.tensor(T)\n",
    "print(V)\n",
    "print(torch.norm(V,p=2,dim=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[0, 3, 1]\n[2 4]\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "sam=random.sample(range(5),3)\n",
    "print(sam)\n",
    "\n",
    "B=set(np.arange(5))-set(sam)\n",
    "\n",
    "A=np.arange(5)\n",
    "print(A[list(B)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215]\n"
     ]
    }
   ],
   "source": [
    "sig=0.2\n",
    "dim=3\n",
    "\n",
    "nk=int(1/sig)\n",
    "grid1=torch.linspace(-0.5*sig,(nk+0.5)*sig,nk+1)\n",
    "L=[grid1]*dim\n",
    "T=torch.meshgrid(L)\n",
    "shape=list(T[0].shape)\n",
    "shape.append(dim)\n",
    "Per=[]\n",
    "for j in range(len(shape)):\n",
    "    Per.append(j)\n",
    "Per[-1],Per[0]=Per[0],Per[-1]\n",
    "Z=torch.zeros(shape)\n",
    "Per=tuple(Per)\n",
    "Z=Z.permute(Per)\n",
    "for j in range(dim):\n",
    "    Z[j]=T[j]\n",
    "Z=Z.permute(Per)\n",
    "\n",
    "Z=Z.reshape((-1,dim))\n",
    "\n",
    "\n",
    "#for j in range(dim):\n",
    "    #Z[]\n",
    "index=[]\n",
    "for j in range(len(Z)):\n",
    "    A=Z[j]\n",
    "    dist=torch.norm(Z-A,dim=1)\n",
    "    m=dist.min().item()\n",
    "    if m<sig:\n",
    "        index.append(j)\n",
    "X_con=Z[index]\n",
    "print(index)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n        [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])\ntensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 0., 1., 2., 3., 4., 5., 6., 7.,\n        8., 9.])\n"
     ]
    }
   ],
   "source": [
    "A=torch.arange(10)\n",
    "B=torch.zeros(2,10)\n",
    "B[0]=A\n",
    "B[1]=A\n",
    "#B=B.permute(1,0)\n",
    "print(B)\n",
    "print(B.reshape((-1,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([3.0277, 3.0277])"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "torch.std(B,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "array([[1., 1., 0., 0., 0., 0.],\n",
       "       [1., 1., 0., 0., 0., 0.],\n",
       "       [0., 0., 1., 1., 0., 0.],\n",
       "       [0., 0., 1., 1., 0., 0.],\n",
       "       [0., 0., 0., 0., 1., 1.],\n",
       "       [0., 0., 0., 0., 1., 1.]])"
      ]
     },
     "metadata": {},
     "execution_count": 35
    }
   ],
   "source": [
    "from scipy.linalg import block_diag\n",
    "np.kron(np.eye(3,dtype=int),np.ones((2,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "0.17320507764816284\n"
     ]
    }
   ],
   "source": [
    "dis=torch.norm(Z,dim=1)\n",
    "print(dis.min().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "D={\"B\":1.4,\"D\":2.3}\n",
    "\n",
    "print(\"B\" in D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[3., 4.],\n",
       "        [2., 2.],\n",
       "        [1., 2.]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "V[V==V].reshape(3,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "T=[[[1.0,float(\"NaN\")],[3.0,4.0]],[[2.0,2.0],[1.0,2.0]]]\n",
    "V=torch.tensor(T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[False,  True],\n",
      "        [ True,  True]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[3., 4.],\n",
       "        [2., 2.],\n",
       "        [1., 2.]])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z=(V==V).all(axis=2)\n",
    "print(Z)\n",
    "\n",
    "V[Z]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "T=[[[1.0,float(\"NaN\")],[float(\"NaN\"),float(\"NaN\")]],[[2.0,2.0],[1.0,2.0]]]\n",
    "V=torch.tensor(T)\n",
    "XT=V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "output_type": "error",
     "ename": "NameError",
     "evalue": "name 'mask' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-39-acc60b8fac84>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mX1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY1\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mOptimB\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfiltre_nan_inhomogene\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'mask' is not defined"
     ]
    }
   ],
   "source": [
    "X1,Y1=OptimB.filtre_nan_inhomogene(V,V,mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ True, False],\n",
       "        [False, False]])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(XT!=XT).any(axis=2)*(XT==XT).any(axis=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[3. 2.]\n",
      "  [3. 2.]]\n",
      "\n",
      " [[1. 2.]\n",
      "  [1. 2.]]]\n"
     ]
    }
   ],
   "source": [
    "S=[[3.0,2.0],[1.0,2.0]]\n",
    "g=torch.tensor(S).unsqueeze(-2).numpy()\n",
    "\n",
    "Z=torch.ones(2).reshape((-1,1)).numpy()\n",
    "\n",
    "print(np.transpose(np.kron(Z,g),axes=(0,1,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[4., 4.]], dtype=float32)"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sum(g,axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([8.3000, 8.2000])\n"
     ]
    }
   ],
   "source": [
    "D=torch.tensor(S)\n",
    "\n",
    "V=torch.tensor([2.1,2.0])\n",
    "\n",
    "print(torch.matmul(V,D))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def FiltreNanHomogène(XT,Y):\n",
    "    \"\"\"\n",
    "    Prend en entrée XT (nb_patient,nb_visit_max,dim) et retourne X sous la forme (nb_visit,dim)\n",
    "\n",
    "    Si un vecteur contient un Nan dans ses coordonnées on le retire\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "\n",
    "    Select=((XT==XT).all(axis=2))*(Y==Y).all(axis=2)#fonctionne bien voir notebook test pour se convaincre\n",
    "    \n",
    "    return XT[Select],Y[Select]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "X,Y=FiltreNanHomogène(V,V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Sub_sampling(X,k):\n",
    "    \"\"\"\n",
    "    Prend X le tensor (nb_visite,dim) et sélectionne k points bien espacé renvoyé dans un tensor (k,dim)\n",
    "\n",
    "\n",
    "    \"\"\"\n",
    "    Center,index=kmeans_plus_plus(X.numpy(), k)\n",
    "    return torch.from_numpy(Center),index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[3., 4.],\n",
       "         [2., 2.],\n",
       "         [1., 2.]],\n",
       "\n",
       "        [[3., 4.],\n",
       "         [2., 2.],\n",
       "         [1., 2.]],\n",
       "\n",
       "        [[3., 4.],\n",
       "         [2., 2.],\n",
       "         [1., 2.]]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.repeat(3,1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Matrix(X,Xgrand,meta_settings):\n",
    "    \"\"\"\n",
    "    X est la donnée des points de controles un tensor de la forme (k,nb_dim) (k nombre de visite après subsampling), kernelname le nom du noyau à utiliser\n",
    "    cette fonction renvoie la matrice K_X=(k(x_i,x_j)) i <nb_visit+1, j<k+1\n",
    "    On a Xgrand (nb_visit,nb_dim) les points de controle sans subsampling\n",
    "    \"\"\"\n",
    "    kernelname=meta_settings[\"kernelname\"]\n",
    "    sigma=meta_settings[\"sigma\"]\n",
    "    k=len(X)\n",
    "    nb_visit=len(Xgrand)\n",
    "    \n",
    "    \n",
    "    if kernelname==\"RBF\":#le calcul est fait sans approximations\n",
    "        sigma=meta_settings[\"sigma\"]\n",
    "        \n",
    "\n",
    "        PA1=Xgrand.repeat(k,1,1)\n",
    "        PA2=X.repeat(nb_visit,1,1).permute(1,0,2)\n",
    "\n",
    "        PA3=PA1-PA2\n",
    "\n",
    "        K_X=torch.exp(-torch.norm(PA3,dim=2)**2/(2*sigma**2))\n",
    "    else:\n",
    "        raise ValueError(\"Le nom de noyau est mauvais ! \")\n",
    "\n",
    "    return K_X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000, 0.5353],\n",
       "        [0.5353, 1.0000],\n",
       "        [0.3679, 0.8825]], dtype=torch.float64)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Matrix(X,A,{\"sigma\":2.0,\"kernelname\":\"RBF\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[3., 4.],\n",
      "        [2., 2.]], dtype=torch.float64)\n",
      "tensor([[3., 2.],\n",
      "        [4., 2.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "X1=X.repeat(2,1,1).permute(1,0,2)\n",
    "\n",
    "X2=A.repeat(3,1,1)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "print(A)\n",
    "\n",
    "print(A.transpose(0,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist(data, centers):\n",
    "    distance = np.sum((np.array(centers) - data[:, None, :])**2, axis = 2)\n",
    "    return distance\n",
    "def kmeans_plus_plus(X, k):\n",
    "    '''Initialize one point at random.\n",
    "    loop for k - 1 iterations:\n",
    "        Next, calculate for each point the distance of the point from its nearest center. Sample a point with a \n",
    "        probability proportional to the square of the distance of the point from its nearest center.'''\n",
    "    centers = []\n",
    "    index=[]\n",
    "    \n",
    "    # Sample the first point\n",
    "    initial_index = np.random.choice(range(X.shape[0]), )\n",
    "    index.append(initial_index)\n",
    "    centers.append(X[initial_index, :].tolist())\n",
    "    \n",
    "    print('max: ', np.max(np.sum((X - np.array(centers))**2)))\n",
    "    \n",
    "    # Loop and select the remaining points\n",
    "    for i in range(k - 1):\n",
    "        print(i)\n",
    "        distance = dist(X, np.array(centers))\n",
    "        \n",
    "        if i == 0:\n",
    "            pdf = distance/np.sum(distance)\n",
    "            indexcour=np.random.choice(range(X.shape[0]), replace = False, p = pdf.flatten())\n",
    "            centroid_new = X[indexcour]\n",
    "            index.append(indexcour)\n",
    "        else:\n",
    "            # Calculate the distance of each point from its nearest centroid\n",
    "            dist_min = np.min(distance, axis = 1)\n",
    "            \n",
    "            pdf = dist_min/np.sum(dist_min)\n",
    "# Sample one point from the given distribution\n",
    "            indexcour=np.random.choice(range(X.shape[0]), replace = False, p = pdf)\n",
    "            centroid_new = X[indexcour]\n",
    "            index.append(indexcour)\n",
    "            \n",
    "\n",
    "        centers.append(centroid_new.tolist())\n",
    "        \n",
    "    return np.array(centers),index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "max:  13.0\n",
      "0\n",
      "tensor([[3., 4.],\n",
      "        [2., 2.]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "A,B=Sub_sampling(X,2)\n",
    "\n",
    "print(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TransformationB(W,Control,meta_settings):\n",
    "    \"\"\"\n",
    "    Prend en entrée la matrice des poids, et les points de contrôles \"Control\" ainsi que meta_settings pour avoir des\n",
    "    informations sur le noyau, W de forme (nb_controle,nb_features). On renvoie la fonction associée pour update B.\n",
    "\n",
    "    #les poids des features i, c'est la colonne i de W\n",
    "    \"\"\"\n",
    "    if meta_settings[\"kernelname\"]==\"RBF\":\n",
    "        sigma=meta_settings[\"sigma\"]\n",
    "        def function(x):\n",
    "            \"\"\"\n",
    "            x doit être de la forme (n1,n2,n_dim) ou (n_1,n_dim)\n",
    "\n",
    "            \"\"\"\n",
    "            \n",
    "            nb_dim=len(x.shape)\n",
    "            if nb_dim==3:\n",
    "                Control1=Control.repeat(x.shape[0],x.shape[1],1,1)\n",
    "                x1=x.unsqueeze(2)\n",
    "                \n",
    "                \n",
    "                KK=torch.exp(-torch.norm(Control1-x1,dim=nb_dim)**2/(2*sigma**2))\n",
    "                print(KK)\n",
    "                #fonctionne pour W dim 1\n",
    "                W1=W.double()\n",
    "                Fin=torch.matmul(KK,W1)\n",
    "                return Fin\n",
    "            elif nb_dim==2:#Si on prend une liste\n",
    "                Control1=Control.repeat(x.shape[0],1,1)\n",
    "                x1=x.unsqueeze(1)\n",
    "                print(\"x1\")\n",
    "                print(x1)\n",
    "                print(\"Control1\")\n",
    "                print(Control1)\n",
    "                \n",
    "                KK=torch.exp(-torch.norm(Control1-x1,dim=nb_dim)**2/(2*sigma**2))\n",
    "                print(KK)\n",
    "                #à vérifier le reset\n",
    "                W1=W.double()\n",
    "                Fin=torch.matmul(KK,W1)\n",
    "                return Fin# de même dimension que x\n",
    "            else:\n",
    "                raise ValueError(\"La valeur de x a une mauvaise shape, les formats acceptés sont (n,dim) ou (n1,n2 ,dim) \")\n",
    "            \n",
    "\n",
    "\n",
    "           \n",
    "           \n",
    "        return function\n",
    "    else:\n",
    "        raise ValueError(\"Le nom de noyau est mauvais ! \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[1.0000, 2.5000],\n",
      "         [3.0000, 4.0000]],\n",
      "\n",
      "        [[2.0000, 2.0000],\n",
      "         [1.0000, 2.0000]]])\n",
      "tensor([[3., 4.],\n",
      "        [2., 2.]], dtype=torch.float64)\n",
      "tensor([[[0.4578, 0.8553],\n",
      "         [1.0000, 0.5353]],\n",
      "\n",
      "        [[0.5353, 1.0000],\n",
      "         [0.3679, 0.8825]]], dtype=torch.float64)\n",
      "tensor([[[3.9395, 4.3974],\n",
      "         [4.6058, 5.6058]],\n",
      "\n",
      "        [[4.6058, 5.1410],\n",
      "         [3.7511, 4.1190]]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "Control=A\n",
    "meta_settings={\"sigma\":2.0,\"kernelname\":\"RBF\"}\n",
    "W=torch.tensor([[3.0,4.0],[3.0,3.0]])\n",
    "#W=torch.tensor([2.0,3.0])\n",
    "#les poids des features i, c'est la colonne i de W\n",
    "fonc=TransformationB(W,Control,meta_settings)\n",
    "\n",
    "print(V)\n",
    "print(Control)\n",
    "\n",
    "print(fonc(V))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "T=[[[1.0,2.5],[3.0,4.0]],[[2.0,2.0],[1.0,2.0]]]\n",
    "V=torch.tensor(T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[3., 4.],\n",
      "          [2., 2.]],\n",
      "\n",
      "         [[3., 4.],\n",
      "          [2., 2.]]],\n",
      "\n",
      "\n",
      "        [[[3., 4.],\n",
      "          [2., 2.]],\n",
      "\n",
      "         [[3., 4.],\n",
      "          [2., 2.]]]], dtype=torch.float64)\n",
      "tensor([[[[1.0000, 2.5000]],\n",
      "\n",
      "         [[3.0000, 4.0000]]],\n",
      "\n",
      "\n",
      "        [[[2.0000, 2.0000]],\n",
      "\n",
      "         [[1.0000, 2.0000]]]])\n",
      "tensor([[[[ 2.0000,  1.5000],\n",
      "          [ 1.0000, -0.5000]],\n",
      "\n",
      "         [[ 0.0000,  0.0000],\n",
      "          [-1.0000, -2.0000]]],\n",
      "\n",
      "\n",
      "        [[[ 1.0000,  2.0000],\n",
      "          [ 0.0000,  0.0000]],\n",
      "\n",
      "         [[ 2.0000,  2.0000],\n",
      "          [ 1.0000,  0.0000]]]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "C=A.repeat(2,2,1,1)\n",
    "print(C)\n",
    "print(V.unsqueeze(2))\n",
    "print(C-V.unsqueeze(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}