{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bec1f172",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/xyuqing/Dropbox (MIT)/MIT/PhD/Notebooks/Groups Symm Break/2-layer-student-teacher\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "\n",
    "from src.models import *\n",
    "from src.loss_ReLU import *\n",
    "from src.equi_test import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b78be3c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture shared by student and teacher.\n",
    "symm_blocks = [(2, 1), (3, 2)]\n",
    "invar_nodes = 3\n",
    "d = 8\n",
    "\n",
    "# Seed log for symm_blocks = [(2,1),(3,2)], invar_nodes = 3:\n",
    "#   56: equi spurious minimum holds even when relaxed\n",
    "#   54: oscillatory behaviour, multiple equi breaks, non-equi final local minimum\n",
    "#   44: spurious non-equi minimum, single equi break\n",
    "#   30: double-peaked equi error, fairly bad final local minimum\n",
    "#   20: ends in spurious non-equi minimum with multiple equi breaks\n",
    "#   18: spurious minimum if equi; one set of 3 basically averaged, learned switches averaged set of 3\n",
    "#   15: spurious minimum if equi; multiple equi breaks when relaxed, ends in spurious non-equi minimum\n",
    "#    8: spurious minimum if equi; permutes 3,3 when relaxed\n",
    "#    7: bad loss if equi; almost perfect when relaxed, but looks very different\n",
    "#    2: spurious minimum if equi; permutes 3 when relaxed\n",
    "torch.manual_seed(15)\n",
    "\n",
    "# Constructed right after seeding, student first and then teacher, so the seed log above stays reproducible.\n",
    "student = EquiBlocks(d, symm_blocks, invar_nodes)\n",
    "teacher = EquiBlocks(d, symm_blocks, invar_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ec9de669",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xyuqing/envs/e3nn/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n",
      "  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(8.8075, grad_fn=<AddBackward0>)\n",
      "tensor(0.1233, grad_fn=<AddBackward0>)\n",
      "tensor(0.1110, grad_fn=<AddBackward0>)\n",
      "tensor(0.1076, grad_fn=<AddBackward0>)\n",
      "tensor(0.1038, grad_fn=<AddBackward0>)\n",
      "tensor(0.0985, grad_fn=<AddBackward0>)\n",
      "tensor(0.0947, grad_fn=<AddBackward0>)\n",
      "tensor(0.0936, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n",
      "tensor(0.0934, grad_fn=<AddBackward0>)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.0934, grad_fn=<AddBackward0>),\n",
       " tensor(0.0085, grad_fn=<MeanBackward0>))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit the student to the fixed teacher by SGD on F.\n",
    "# The teacher matrix never changes, so build it once and detach it: this avoids\n",
    "# rebuilding it every iteration and (if the teacher's parameters require grad)\n",
    "# stops backward() from accumulating unused gradients on them.\n",
    "V = teacher.create_W().detach()\n",
    "opt = torch.optim.SGD(student.parameters(), lr=1e-1)\n",
    "\n",
    "iters = 20000\n",
    "log_every = 500\n",
    "for i in range(iters):\n",
    "    opt.zero_grad()\n",
    "    loss = F(student.create_W(), V)\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    if i % log_every == 0:\n",
    "        print(f'iter {i:>5d}  loss {loss.item():.4f}')\n",
    "\n",
    "# Final loss and mean absolute gradient magnitude at the trained student.\n",
    "W_student = student.create_W()\n",
    "F(W_student, V), torch.mean(torch.abs(grad_F(W_student, V)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0201a343",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x717eb01a00a0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAFhCAYAAADQncj9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAXxUlEQVR4nO3df4xW9b0n8M8wAw+gw9QfZYCAOtvLBhW16KjXYqrGia7x525i6i4mhGa1q0ORkm2VxB/rz1GjdqJ1QblRaapi9w9/1I0YLq1w3YoIlEZi44/oraN0hpq1M/zYDtyZs3/cMNdRbJl6nu8zz3ler+T8MWcOfD4nhHfenOcZnrosy7IAAEhkTKUXAABqi/IBACSlfAAASSkfAEBSygcAkJTyAQAkpXwAAEk1VHqBzxscHIzt27dHY2Nj1NXVVXodqElZlsXOnTtj2rRpMWZMdfwbRXZAZY0kN0Zd+di+fXvMmDGj0msAEdHV1RXTp0+v9BoHRXbA6HAwuTHqykdjY2NERPzz//nfMenQQ5LMvPb1JGOGdF4yK+1AGKGdO3fGzJkzh/4+VoP9u2596+1ke6898awkc/a75Hf/lHQejMRIcmPUlY/9j0snHXpITGo8NMnMcROTjBkyadKktAPhb1RNL1/s37WxsTEaE/0dmzimPsmc/WQH1eBgcqM6XswFAApD+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACCpspWPhx9+OI455pgYP358nH766bFx48ZyjQIKQm5AbShL+XjmmWdiyZIlccstt8SWLVvipJNOivPPPz927NhRjnFAAcgNqB1lKR8PPPBAXHXVVbFgwYI47rjjYvny5TFx4sR47LHHvnBtf39/9PX1DTuA2jOS3IiQHVDNci8fe/fujc2bN0dbW9u/DRkzJtra2uK11177wvUdHR3R1NQ0dPhUSqg9I82NCNkB1Sz38vHJJ5/EwMBANDc3Dzvf3Nwc3d3dX7h+6dKl0dvbO3R0dXXlvRIwyo00NyJkB1Szin+qbalUilKpVOk1gCojO6B65f7k48gjj4z6+vro6ekZdr6npyemTJmS9zigAOQG1Jbcy8e4cePilFNOibVr1w6dGxwcjLVr18YZZ5yR9zigAOQG1JayvOyyZMmSmD9/frS2tsZpp50WnZ2dsXv37liwYEE5xgEFIDegdpSlfHznO9+JP/7xj3HzzTdHd3d3fPOb34zVq1d/4c1kAPvJDagdZXvD6cKFC2PhwoXl+u2BApIbUBt8tgsAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJFXxz3b5Mte+HjFuYppZj30rzZz99qYdBzVl7YlnxcQx9Ulmnff+piRzoGg8+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AIKmGSi/wZTovmRWTJk1KMmtvkin/Zsxr/yvxxIjpnT3JZ6b04TMLK70Co8Qlv/unZNmRWsP//X3ymddv2pd8Zkp3nvd3lV6hJnnyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJKV8AABJKR8AQFLKBwCQlPIBACSlfAAASeVePjo6OuLUU0+NxsbGmDx5clx22WXx9ttv5z0GKBC5AbUl9/Kxbt26aG9vjw0bNsSaNWti3759cd5558Xu3bvzHgUUhNyA2pL7B8utXr162NdPPPFETJ48OTZv3hzf/va38x4HFIDcgNpS9k+17e3tjYiIww8//IDf7+/vj/7+/qGv+/r6yr0SMMr9tdyIkB1Qzcr6htPBwcFYvHhxzJ07N2bPnn3Aazo6OqKpqWnomDFjRjlXAka5g8mNCNkB1ays5aO9vT22bdsWq1at+tJr
li5dGr29vUNHV1dXOVcCRrmDyY0I2QHVrGwvuyxcuDBefPHFWL9+fUyfPv1LryuVSlEqlcq1BlBFDjY3ImQHVLPcy0eWZfH9738/nn322XjllVeipaUl7xFAwcgNqC25l4/29vZ46qmn4vnnn4/Gxsbo7u6OiIimpqaYMGFC3uOAApAbUFtyf8/HsmXLore3N84+++yYOnXq0PHMM8/kPQooCLkBtaUsL7sAjITcgNris10AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AIKmyf6otXzS9syf5zI8WNyedV4l7hKK7ftO+5DPvaR2bdF4l7pH0PPkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASKqh0guQxvTOnqTzPlrcnHTeYNJpUDuu37Qv6bx7WscmnfcvSaexnycfAEBSygcAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJKV8AABJKR8AQFJlLx9333131NXVxeLFi8s9CigIuQHFVtby8cYbb8QjjzwSJ554YjnHAAUiN6D4ylY+du3aFfPmzYsVK1bEYYcdVq4xQIHIDagNZSsf7e3tceGFF0ZbW9tfvK6/vz/6+vqGHUBtOtjciJAdUM3K8qm2q1atii1btsQbb7zxV6/t6OiIW2+9tRxrAFVkJLkRITugmuX+5KOrqyuuu+66ePLJJ2P8+PF/9fqlS5dGb2/v0NHV1ZX3SsAoN9LciJAdUM1yf/KxefPm2LFjR5x88slD5wYGBmL9+vXxk5/8JPr7+6O+vn7oe6VSKUqlUt5rAFVkpLkRITugmuVePs4999x48803h51bsGBBzJo1K66//vovBAiA3IDaknv5aGxsjNmzZw87d8ghh8QRRxzxhfMAEXIDao3/4RQASKosP+3yea+88kqKMUCByA0oLk8+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASCrJ//PBcK+u+F7ymS1j9ySfCeTrjsF/TD5zxeRLk877n/+cdFxcfXjaefwrTz4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEiqodIL1KKWsXuSz/xg38Sk8ypxj1B0KyZfmnzmVTueTzqvEvdIep58AABJKR8AQFLKBwCQlPIBACSlfAAASSkfAEBSygcAkJTyAQAkpXwAAEkpHwBAUmUpHx9//HFceeWVccQRR8SECRPihBNOiE2bNpVjFFAQcgNqR+6f7fLpp5/G3Llz45xzzomXXnopvv71r8e7774bhx12WN6jgIKQG1Bbci8f99xzT8yYMSMef/zxoXMtLS1fen1/f3/09/cPfd3X15f3SsAoN9LciJAdUM1yf9nlhRdeiNbW1rj88stj8uTJMWfOnFixYsWXXt/R0RFNTU1Dx4wZM/JeCRjlRpobEbIDqlnu5eP999+PZcuWxcyZM+Pll1+Oa665JhYtWhQrV6484PVLly6N3t7eoaOrqyvvlYBRbqS5ESE7oJrl/rLL4OBgtLa2xl133RUREXPmzIlt27bF8uXLY/78+V+4vlQqRalUynsNoIqMNDciZAdUs9yffEydOjWOO+64YeeOPfbY+PDDD/MeBRSE3IDaknv5mDt3brz99tvDzr3zzjtx9NFH5z0KKAi5AbUl9/Lxgx/8IDZs2BB33XVXvPfee/HUU0/Fo48+Gu3t7XmPAgpCbkBtyb18nHrqqfHss8/G008/HbNnz47bb789Ojs7Y968eXmP
AgpCbkBtyf0NpxERF110UVx00UXl+K2BgpIbUDt8tgsAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJFWWH7Vl9GkZuyfpvA/2TUw676gJScdBzVgx+dKk867a8XzSeRH/LfE8Ijz5AAASUz4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBISvkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgqYZKL/Bl3vmv/yUOHTtq1/tKTvjxvZVeoey+kXjeRw3/Lum8plJ90nkREfV1dclm7RtMNip3i57/XYydcGil1yiL3j37Kr1C2f1j/H3Seaueuz/pvLFHz0o6LyLiH+pak8z5f7t3HvS1nnwAAEkpHwBAUsoHAJCU8gEAJKV8AABJKR8AQFLKBwCQlPIBACSlfAAASSkfAEBSuZePgYGBuOmmm6KlpSUmTJgQ3/jGN+L222+PLMvyHgUUhNyA2pL7h6fcc889sWzZsli5cmUcf/zxsWnTpliwYEE0NTXFokWL8h4HFIDcgNqSe/n49a9/HZdeemlceOGFERFxzDHHxNNPPx0bN27MexRQEHIDakvuL7t861vfirVr18Y777wTERG//e1v49VXX40LLrjggNf39/dHX1/fsAOoLSPNjQjZAdUs9ycfN9xwQ/T19cWsWbOivr4+BgYG4s4774x58+Yd8PqOjo649dZb814DqCIjzY0I2QHVLPcnHz//+c/jySefjKeeeiq2bNkSK1eujPvuuy9Wrlx5wOuXLl0avb29Q0dXV1feKwGj3EhzI0J2QDXL/cnHD3/4w7jhhhviiiuuiIiIE044IX7/+99HR0dHzJ8//wvXl0qlKJVKea8BVJGR5kaE7IBqlvuTjz179sSYMcN/2/r6+hgcHMx7FFAQcgNqS+5PPi6++OK4884746ijjorjjz8+fvOb38QDDzwQ3/3ud/MeBRSE3IDaknv5eOihh+Kmm26Ka6+9Nnbs2BHTpk2L733ve3HzzTfnPQooCLkBtSX38tHY2BidnZ3R2dmZ928NFJTcgNris10AgKSUDwAgKeUDAEhK+QAAklI+AICklA8AICnlAwBIKvf/5yMvu+/4h4jGxkqvURb/cfW7lV6hcB7/z2l79LhfPJB0XkTEYfe+n2xWNrA32ay83Xb+v4/GSZMqvUZZvLF9Z6VXKJz7tl+edN6N0/YknRcR0Xn/1iRzBvYe/L158gEAJKV8AABJKR8AQFLKBwCQlPIBACSlfAAASSkfAEBSygcAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJKV8AABJKR8AQFLKBwCQlPIBACSlfAAASSkfAEBSygcAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJKV8AABJKR8AQFINlV7gy/yP57ZFw/hDKr1GWSy8b2GlVyicP274u6Tzem58LOm8iIjrbutONqt/z6546Ionk83L0yX3ro/60sRKr1EW3/7xokqvUDj//ft/n3TeC1f/OOm8iIj/9B9KSeb079kVnY8e3LWefAAASSkfAEBSygcAkJTyAQAkpXwAAEkpHwBAUsoHAJCU8gEAJKV8AABJjbh8rF+/Pi6++OKYNm1a1NXVxXPPPTfs+1mWxc033xxTp06NCRMmRFtbW7z77rt57QtUIbkBfNaIy8fu3bvjpJNOiocffviA37/33nvjwQcfjOXLl8frr78ehxxySJx//vnx5z//+SsvC1QnuQF81og/2+WCCy6ICy644IDfy7IsOjs748Ybb4xLL700IiJ++tOfRnNzczz33HNxxRVXfLVtgaokN4DPyvU9Hx988EF0d3dHW1vb0LmmpqY4/fTT47XXXjvgr+nv74++vr5hB1A7/pbciJAdUM1yLR/d3f/6qZvNzc3Dzjc3Nw997/M6Ojqiqalp
6JgxY0aeKwGj3N+SGxGyA6pZxX/aZenSpdHb2zt0dHV1VXoloArIDqheuZaPKVOmRERET0/PsPM9PT1D3/u8UqkUkyZNGnYAteNvyY0I2QHVLNfy0dLSElOmTIm1a9cOnevr64vXX389zjjjjDxHAQUhN6D2jPinXXbt2hXvvffe0NcffPBBbN26NQ4//PA46qijYvHixXHHHXfEzJkzo6WlJW666aaYNm1aXHbZZXnuDVQRuQF81ojLx6ZNm+Kcc84Z+nrJkiURETF//vx44okn4kc/+lHs3r07rr766vjTn/4UZ555ZqxevTrGjx+f39ZAVZEbwGeNuHycffbZkWXZl36/rq4ubrvttrjtttu+0mJAccgN4LMq/tMuAEBtUT4AgKSUDwAgKeUDAEhK+QAAklI+AICkRvyjtuW2/8fx/uXPeyq8SfnsGRyo9AqFs2vvvqTzdu/cmXReRET/nl3JZu3dszsi4i/+eOxos3/Xgb3FzY69MVjpFQpnZ+Ls2LOruNmxf87B5EZdNsrS5aOPPvLplDBKdHV1xfTp0yu9xkGRHTA6HExujLryMTg4GNu3b4/Gxsaoq6s76F/X19cXM2bMiK6ursJ+wFTR77Ho9xdRPfeYZVns3Lkzpk2bFmPGVMers7LjwIp+fxHFv8dqub+R5Maoe9llzJgxX+lfWrXw6ZZFv8ei319EddxjU1NTpVcYEdnxlxX9/iKKf4/VcH8HmxvV8U8aAKAwlA8AIKnClI9SqRS33HJLlEqlSq9SNkW/x6LfX0Rt3GO1KfqfSdHvL6L491jE+xt1bzgFAIqtME8+AIDqoHwAAEkpHwBAUsoHAJCU8gEAJFWY8vHwww/HMcccE+PHj4/TTz89Nm7cWOmVctHR0RGnnnpqNDY2xuTJk+Oyyy6Lt99+u9Jrlc3dd98ddXV1sXjx4kqvkquPP/44rrzyyjjiiCNiwoQJccIJJ8SmTZsqvVbNK2puRMiOoihqdhSifDzzzDOxZMmSuOWWW2LLli1x0kknxfnnnx87duyo9Gpf2bp166K9vT02bNgQa9asiX379sV5550Xu3fvrvRquXvjjTfikUceiRNPPLHSq+Tq008/jblz58bYsWPjpZdeirfeeivuv//+OOywwyq9Wk0rcm5EyI4iKHR2ZAVw2mmnZe3t7UNfDwwMZNOmTcs6OjoquFV57NixI4uIbN26dZVeJVc7d+7MZs6cma1ZsyY766yzsuuuu67SK+Xm+uuvz84888xKr8Hn1FJuZJnsqEZFzo6qf/Kxd+/e2Lx5c7S1tQ2dGzNmTLS1tcVrr71Wwc3Ko7e3NyIiDj/88Apvkq/29va48MILh/05FsULL7wQra2tcfnll8fkyZNjzpw5sWLFikqvVdNqLTciZEc1KnJ2VH35+OSTT2JgYCCam5uHnW9ubo7u7u4KbVUeg4ODsXjx4pg7d27Mnj270uvkZtWqVbFly5bo6Oio9Cpl8f7778eyZcti5syZ8fLLL8c111wTixYtipUrV1Z6tZpVS7kRITuqVZGzo6HSC3Dw2tvbY9u2bfHqq69WepXcdHV1xXXXXRdr1qyJ8ePHV3qdshgcHIzW1ta46667IiJizpw5sW3btli+fHnMnz+/wttRC2RHdSpydlT9k48jjzwy6uvro6enZ9j5np6emDJlSoW2yt/ChQvjxRdfjF/96lcxffr0Sq+Tm82bN8eOHTvi5JNPjoaGhmhoaIh169bFgw8+GA0NDTEwMFDpFb+yqVOnxnHHHTfs3LHHHhsffvhhhTaiVnIjQnZUsyJnR9WXj3HjxsUpp5wSa9euHTo3ODgYa9eujTPOOKOCm+Ujy7JYuHBhPPvss/HLX/4yWlpaKr1Srs4999x48803Y+vWrUNHa2trzJs3L7Zu3Rr19fWVXvErmzt37hd+xPGdd96Jo48+ukIbUfTciJAdsmOUq/Q7XvOwatWqrFQqZU888UT21ltvZVdffXX2ta99Levu7q70
al/ZNddckzU1NWWvvPJK9oc//GHo2LNnT6VXK5uivWN948aNWUNDQ3bnnXdm7777bvbkk09mEydOzH72s59VerWaVuTcyDLZUQRFzo5ClI8sy7KHHnooO+qoo7Jx48Zlp512WrZhw4ZKr5SLiDjg8fjjj1d6tbIpWoBkWZb94he/yGbPnp2VSqVs1qxZ2aOPPlrplciKmxtZJjuKoqjZUZdlWVaZZy4AQC2q+vd8AADVRfkAAJJSPgCApJQPACAp5QMASEr5AACSUj4AgKSUDwAgKeUDAEhK+QAAklI+AICk/j93GupAeAW/MAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Side-by-side weight matrices: trained student (left) vs. teacher (right).\n",
    "fig, axarr = plt.subplots(1, 2)\n",
    "student.visualize_weights(axes=axarr[0])\n",
    "axarr[0].set_title('student')\n",
    "teacher.visualize_weights(axes=axarr[1])\n",
    "axarr[1].set_title('teacher');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4f2cc761",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x717eaad7cbb0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAGdCAYAAACPRc9NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUD0lEQVR4nO3df2zUdZ7H8Vd/2GnB6SzgtqWhSM9rglhk0QLBml2NjYSgK/cHkTtMCJtbDU4XuiS72kTgRHHUuFwjElAuQs2CsP/AqhcxpLtAOPnZLhuJCWDglgG2rSbuTGnPget874+9HXfkhy18v/OpfT8fyfePfufjfN4ZzTPf6Tjf5nme5wkADMh3PQAA5ArBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2BGoesBvimdTuvChQsKh8PKy8tzPQ6AIc7zPPX09KiyslL5+de/hhtywbtw4YKqqqpcjwHgOyYej2vcuHHXXTPkghcOhyVJ//1f/6nSW0c6neXpQ063lyS1/Hii6xGAIa2np0c1NTWZdlzPkAve397Glt46UqXhW53OUjTC6faSpNLSUtcjAN8JA/kVGB9aADCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMCOw4K1bt04TJkxQcXGxZsyYocOHDwe1FQAMSCDB2759u5YtW6aVK1eqo6NDU6ZM0axZs9Td3R3EdgAwIIEEb82aNfrpT3+qRYsWadKkSdqwYYNGjBiht99+O4jtAGBAfA/epUuX1N7eroaGhq83yc9XQ0ODDhw4cMX6VCqlZDKZdQBAEHwP3hdffKH+/n6Vl5dnnS8vL1dnZ+cV62OxmCKRSObgXngAguL8U9rm5mYlEonMEY/HXY8EYJjy/X54t912mwoKCtTV1ZV1vqurSxUVFVesD4VCCoVCfo8BAFfw/QqvqKhI9957r9ra2jLn0um02traNHPmTL+3A4ABC+SOx8uWLdPChQtVV1en6dOnq6WlRb29vVq0aFEQ2wHAgAQSvMcff1yff/65VqxYoc7OTv3gBz/Qrl27rvggAwByKbC/adHY2KjGxsagnh4ABs35p7QAkCsED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBmBfbXsZj19SCoa4XaGt+9zu78kXXI9ADCMcIUHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8CMQtcDXEvLjyeqtLTU6QyXnO7+V+Mff8P1CEPGuaZy1yNIktIz57keATeIKzwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZvgevFgspmnTpikcDqusrExz587ViRMn/N4GAAbN9+Dt3btX0WhUBw8e1O7du3X58mU9/PDD6u3t9XsrABgU3++Ht2vXrqyfN2/erLKyMrW3t+uHP/yh39sBwIAFfgPQRCIhSRo9evRVH0+lUkqlUpmfk8lk0CMBMCrQDy3S6bSamppUX1+v2traq66JxWKKRCKZo6qqKsiRABgWaPCi0aiOHz+ubdu2XXNNc3OzEolE5ojH40GOBMCwwN7SNjY26oMPPtC+ffs0bty4a64LhUIKhUJBjQEAGb4Hz/M8/exnP9OOHTu0Z88eVVdX+70FANwQ34MXjUa1detW/fa3v1U4HFZnZ6ckKRKJqKSkxO/tAGDAfP8d3vr165VIJPTAAw9o7NixmWP79u1+bwUAgxLIW1oAGIr4Li0AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzAr8BKG7OuaZy1yNIksa1dLkeYUjMIElnZ7qeADeKKzwAZhA8
AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYUuh4A1zeupcv1CJKkc03lrkcYMq8Fvru4wgNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgRuDBe/nll5WXl6empqagtwKA6wo0eEeOHNGbb76pu+++O8htAGBAAgvexYsXtWDBAm3cuFGjRo0KahsAGLDAgheNRjVnzhw1NDRcd10qlVIymcw6ACAIgdzxeNu2bero6NCRI0e+dW0sFtPzzz8fxBgAkMX3K7x4PK6lS5dqy5YtKi4u/tb1zc3NSiQSmSMej/s9EgBICuAKr729Xd3d3brnnnsy5/r7+7Vv3z698cYbSqVSKigoyDwWCoUUCoX8HgMAruB78B566CF98sknWecWLVqkiRMn6plnnsmKHQDkku/BC4fDqq2tzTo3cuRIjRkz5orzAJBLfNMCgBk5+bu0e/bsycU2AHBdXOEBMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwIyfftMCN27/xKdcjSJLybulzPYLOb3Y9wV/1ux4AN4wrPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYUuh4A11d9S5/rESRJZy6PcD3CkHkt8N3FFR4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzAgne+fPn9cQTT2jMmDEqKSnR5MmTdfTo0SC2AoAB8/1uKV9++aXq6+v14IMP6sMPP9T3v/99nTp1SqNGjfJ7KwAYFN+D98orr6iqqkqbNm3KnKuurvZ7GwAYNN/f0r733nuqq6vTvHnzVFZWpqlTp2rjxo3XXJ9KpZRMJrMOAAiC78E7ffq01q9fr5qaGn300UdavHixlixZotbW1quuj8ViikQimaOqqsrvkQBAkpTneZ7n5xMWFRWprq5OH3/8cebckiVLdOTIER04cOCK9alUSqlUKvNzMplUVVWVOjs7VVpa6udo30kF/5NwPYIk7nj89/pLIq5HwN9JJpOqqKhQIpH41mb4foU3duxYTZo0KevcnXfeqbNnz151fSgUUmlpadYBAEHwPXj19fU6ceJE1rmTJ0/q9ttv93srABgU34P385//XAcPHtRLL72kzz77TFu3btVbb72laDTq91YAMCi+B2/atGnasWOH3n33XdXW1uqFF15QS0uLFixY4PdWADAogfyZxkceeUSPPPJIEE8NADeM79ICMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwI5BvWmD4GQq3ZhoKt6iSpPElrifAjeIKD4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBmFrge4lpP/+i+69ZYhO17OTP73V12PMGTc4XqA/3eu8B9cjyBJioQKXI+ggrw81yPocnrga7nCA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYIbvwevv79fy5ctVXV2tkpIS3XHHHXrhhRfkeZ7fWwHAoPh+O5JXXnlF69evV2trq+666y4dPXpUixYtUiQS
0ZIlS/zeDgAGzPfgffzxx3rsscc0Z84cSdKECRP07rvv6vDhw35vBQCD4vtb2vvuu09tbW06efKkJOmPf/yj9u/fr9mzZ191fSqVUjKZzDoAIAi+X+E9++yzSiaTmjhxogoKCtTf36/Vq1drwYIFV10fi8X0/PPP+z0GAFzB9yu83/zmN9qyZYu2bt2qjo4Otba26rXXXlNra+tV1zc3NyuRSGSOeDzu90gAICmAK7xf/OIXevbZZzV//nxJ0uTJk/WnP/1JsVhMCxcuvGJ9KBRSKBTyewwAuILvV3h9fX3Kz89+2oKCAqXTg7jxPAAEwPcrvEcffVSrV6/W+PHjddddd+kPf/iD1qxZo5/85Cd+bwUAg+J78NauXavly5fr6aefVnd3tyorK/XUU09pxYoVfm8FAIPie/DC4bBaWlrU0tLi91MDwE3hu7QAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzPD9q2V+6X3xP6Rw2PUYzv3TrlOuR8A3bPrnoXGdUPT+GtcjaNSrp12PIK//0oDXDo1/cwCQAwQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBmFrge4ln/beVyFxSNdj+Fc42uNrkfAN3x+8B9djyBJ6nrubdcjaOmqTtcjKNV3UWvnbxnQWq7wAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmDHo4O3bt0+PPvqoKisrlZeXp507d2Y97nmeVqxYobFjx6qkpEQNDQ06deqUX/MCwA0bdPB6e3s1ZcoUrVu37qqPv/rqq3r99de1YcMGHTp0SCNHjtSsWbP01Vdf3fSwAHAzBn17qNmzZ2v27NlXfczzPLW0tOi5557TY489Jkl65513VF5erp07d2r+/Pk3Ny0A3ARff4d35swZdXZ2qqGhIXMuEoloxowZOnDgwFX/mVQqpWQymXUAQBB8DV5n519vBlheXp51vry8PPPYN8ViMUUikcxRVVXl50gAkOH8U9rm5mYlEonMEY/HXY8EYJjyNXgVFRWSpK6urqzzXV1dmce+KRQKqbS0NOsAgCD4Grzq6mpVVFSora0tcy6ZTOrQoUOaOXOmn1sBwKAN+lPaixcv6rPPPsv8fObMGR07dkyjR4/W+PHj1dTUpBdffFE1NTWqrq7W8uXLVVlZqblz5/o5NwAM2qCDd/ToUT344IOZn5ctWyZJWrhwoTZv3qxf/vKX6u3t1ZNPPqm//OUvuv/++7Vr1y4VFxf7NzUA3IBBB++BBx6Q53nXfDwvL0+rVq3SqlWrbmowAPCb809pASBXCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMG/U2LoP3tWxz/+1Wf40mGhr50v+sR8A0XL112PYIkqbenx/UISvVddD2CLvX1StJ1vwH2N3neQFbl0Llz57gJKIBBi8fjGjdu3HXXDLngpdNpXbhwQeFwWHl5eTf0HMlkUlVVVYrH4+bvr8drkY3X42vD5bXwPE89PT2qrKxUfv71f0s35N7S5ufnf2ulB4obin6N1yIbr8fXhsNrEYlEBrSODy0AmEHwAJgxLIMXCoW0cuVKhUIh16M4x2uRjdfjaxZfiyH3oQUABGVYXuEBwNUQPABmEDwAZhA8AGYMy+CtW7dOEyZMUHFxsWbMmKHDhw+7HinnYrGYpk2bpnA4rLKyMs2dO1cnTpxwPdaQ8PLLLysvL09NTU2uR3Hi/PnzeuKJJzRmzBiVlJRo8uTJOnr0qOuxcmLYBW/79u1atmyZVq5cqY6ODk2ZMkWzZs1Sd3e369Fyau/evYpGozp48KB2796ty5cv6+GHH1Zv
b6/r0Zw6cuSI3nzzTd19992uR3Hiyy+/VH19vW655RZ9+OGH+vTTT/WrX/1Ko0aNcj1abnjDzPTp071oNJr5ub+/36usrPRisZjDqdzr7u72JHl79+51PYozPT09Xk1Njbd7927vRz/6kbd06VLXI+XcM888491///2ux3BmWF3hXbp0Se3t7WpoaMicy8/PV0NDgw4cOOBwMvcSiYQkafTo0Y4ncScajWrOnDlZ/31Y895776murk7z5s1TWVmZpk6dqo0bN7oeK2eGVfC++OIL9ff3q7y8POt8eXm5Ojs7HU3lXjqdVlNTk+rr61VbW+t6HCe2bdumjo4OxWIx16M4dfr0aa1fv141NTX66KOPtHjxYi1ZskStra2uR8uJIXe3FPgvGo3q+PHj2r9/v+tRnIjH41q6dKl2796t4uJi1+M4lU6nVVdXp5deekmSNHXqVB0/flwbNmzQwoULHU8XvGF1hXfbbbepoKBAXV1dWee7urpUUVHhaCq3Ghsb9cEHH+j3v/+9b7fd+q5pb29Xd3e37rnnHhUWFqqwsFB79+7V66+/rsLCQvX327mr9NixYzVp0qSsc3feeafOnj3raKLcGlbBKyoq0r333qu2trbMuXQ6rba2Ns2cOdPhZLnneZ4aGxu1Y8cO/e53v1N1dbXrkZx56KGH9Mknn+jYsWOZo66uTgsWLNCxY8dUUFDgesScqa+vv+J/Tzp58qRuv/12RxPlmOtPTfy2bds2LxQKeZs3b/Y+/fRT78knn/S+973veZ2dna5Hy6nFixd7kUjE27Nnj/fnP/85c/T19bkebUiw+int4cOHvcLCQm/16tXeqVOnvC1btngjRozwfv3rX7seLSeGXfA8z/PWrl3rjR8/3isqKvKmT5/uHTx40PVIOSfpqsemTZtcjzYkWA2e53ne+++/79XW1nqhUMibOHGi99Zbb7keKWe4PRQAM4bV7/AA4HoIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AM/4P7ziobOUtmtEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "W = student.create_W()\n",
    "n_rows = W.shape[0]\n",
    "\n",
    "# Display permutation: row 4 moves up to position 2 and rows 2-3 shift down by one;\n",
    "# all other rows keep their index. Built from arange so it fits any row count >= 5\n",
    "# instead of assuming exactly 11 rows (a longer W would otherwise lose rows silently).\n",
    "perm = torch.arange(n_rows)\n",
    "perm[2:5] = torch.tensor([4, 2, 3])\n",
    "perm_mat = torch.eye(n_rows, dtype=W.dtype)[perm]\n",
    "vmax = torch.max(torch.abs(W)).item()\n",
    "plt.imshow((perm_mat @ W).detach().numpy(), cmap='RdBu', vmax=vmax, vmin=-vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "14598f74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x717eaa373fd0>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAGdCAYAAACPRc9NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUEElEQVR4nO3df2zUdZ7H8Vd/2GnB6SzgtqWhQM9rglhk0QLBml2JjYSgK/eHkTs8G8ytBqcLXZJVmwicKI4Sl2tEAspGqFkQ9h9Y9SKGdBeIJ+VHuxiJCWDglhG2rSbuTGnPget874+9HXfkhy18Zz6l7+cj+f7R73ydzzsjeeYzHeZLjud5ngDAgFzXAwBAthA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGbkux7gu5LJpM6dO6dgMKicnBzX4wAY4jzPU09Pj8rLy5Wbe/U93JAL3rlz51RRUeF6DAA3mGg0qnHjxl31miEXvGAwKEn67//6TxXfPNLpLE8ddLq8JKn5p5NcjwAMaT09Paqqqkq142qGXPD+9ja2+OaRKg7e7HSWghFOl5ckFRcXux4BuCEM5FdgfGgBwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfAjIwFb/369Zo4caIKCws1c+ZMHTp0KFNLAcCAZCR4O3bs0LJly7Ry5Up1dHRo6tSpmjNnjrq7uzOxHAAMSEaCt3btWv3sZz/TokWLNHnyZG3cuFEjRozQW2+9lYnlAGBAfA/ehQsX1N7errq6um8Xyc1VXV2dDhw4cMn1iURC8Xg87QCATPA9eF999ZX6+/tVWlqadr60tFSdnZ2XXB+JRBQKhVIH98IDkCnOP6VtampSLBZLHdFo1PVIAIYp3++Hd8sttygvL09dXV1p57u6ulRWVnbJ9YFAQIFAwO8xAOASvu/wCgoKdNddd6m1tTV1LplMqrW1VbNmzfJ7OQAYsIzc8XjZsmWqr69XTU2NZsyYoebmZvX29mrRokWZWA4ABiQjwXvkkUf05ZdfasWKFers7NSPfvQj7d69+5IPMgAgmzL2b1o0NDSooaEhU08PAIPm/FNaAMgWggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8CMjH217Ho9dVAqGOF2hrfudru+JF1wPQAwjLDDA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgRr7rAa6k+aeTVFxc7HSGC05X/6vxj7zuegRJUsuax1yPMGTMnuD2zyWuHTs8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGb4HrxIJKLp06crGAyqpKRE8+fP1/Hjx/1eBgAGzffg7du3T+FwWG1tbdqzZ48uXryo+++/X729vX4vBQCD4vv98Hbv3p3285YtW1RSUqL29nb9+Mc/9ns5ABiwjN8ANBaLSZJGjx592ccTiYQSiUTq53g8numRABiV0Q8tksmkGhsbVVtbq+rq6steE4lEFAqFUkdFRUUmRwJgWEaDFw6HdezYMW3fvv2K1zQ1NSkWi6WOaDSayZEAGJaxt7QNDQ16//33tX//fo0bN+6K1wUCAQUCgUyNAQApvgfP8zz9/Oc/186dO7V3715VVlb6vQQAXBPfgxcOh7Vt2zb97ne/UzAYVGdnpyQpFAqpqKjI7+UAYMB8/x3ehg0bFIvFdO+992rs2LGpY8eOHX4vBQCDkpG3tAAwFPFdWgBmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGZk/AaguD4tax5zPYIkqf7pt12PMGReC9y42OEBMIPgATCD
4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATAj3/UAuDG0rHnM9Qiqf/pt1yNIks7saHA9Aq4ROzwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZmQ8eC+//LJycnLU2NiY6aUA4KoyGrzDhw/rjTfe0B133JHJZQBgQDIWvPPnz2vhwoXatGmTRo0alallAGDAMha8cDisefPmqa6u7qrXJRIJxePxtAMAMiEjdzzevn27Ojo6dPjw4e+9NhKJ6Pnnn8/EGACQxvcdXjQa1dKlS7V161YVFhZ+7/VNTU2KxWKpIxqN+j0SAEjKwA6vvb1d3d3duvPOO1Pn+vv7tX//fr3++utKJBLKy8tLPRYIBBQIBPweAwAu4Xvw7rvvPn366adp5xYtWqRJkybpmWeeSYsdAGST78ELBoOqrq5OOzdy5EiNGTPmkvMAkE180wKAGVn5d2n37t2bjWUA4KrY4QEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATAjK9+0wLW7dVSR6xEkSZU39bkeQWe3/KvrESRJ/a4HwDVjhwfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8CMfNcD4Ooqb+pzPYIk6fTFEa5HGDKvBW5c7PAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYkZHgnT17Vo8++qjGjBmjoqIiTZkyRUeOHMnEUgAwYL7fLeXrr79WbW2tZs+erQ8++EA//OEPdfLkSY0aNcrvpQBgUHwP3iuvvKKKigpt3rw5da6ystLvZQBg0Hx/S/vuu++qpqZGDz/8sEpKSjRt2jRt2rTpitcnEgnF4/G0AwAywffgnTp1Shs2bFBVVZU+/PBDLV68WEuWLFFLS8tlr49EIgqFQqmjoqLC75EAQJKU43me5+cTFhQUqKamRh9//HHq3JIlS3T48GEdOHDgkusTiYQSiUTq53g8roqKCnV2dqq4uNjP0W5Ief8Tcz2CJO54/Pf6i0KuR8DficfjKisrUywW+95m+L7DGzt2rCZPnpx27rbbbtOZM2cue30gEFBxcXHaAQCZ4Hvwamtrdfz48bRzJ06c0IQJE/xeCgAGxffg/eIXv1BbW5teeuklff7559q2bZvefPNNhcNhv5cCgEHxPXjTp0/Xzp079c4776i6ulovvPCCmpubtXDhQr+XAoBBycg/0/jAAw/ogQceyMRTA8A147u0AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzMjINy0w/AyFWzMNhVtUSdL4ItcT4FqxwwNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2BGvusBruTEv/2Lbr5pyI6XNVP+Y43rEYaMW10P8P++yP8H1yNIkkKBPNcjKC8nx/UIupgc+LXs8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJjhe/D6+/u1fPlyVVZWqqioSLfeeqteeOEFeZ7n91IAMCi+347klVde0YYNG9TS0qLbb79dR44c0aJF
ixQKhbRkyRK/lwOAAfM9eB9//LEeeughzZs3T5I0ceJEvfPOOzp06JDfSwHAoPj+lvbuu+9Wa2urTpw4IUn65JNP9NFHH2nu3LmXvT6RSCgej6cdAJAJvu/wnn32WcXjcU2aNEl5eXnq7+/X6tWrtXDhwsteH4lE9Pzzz/s9BgBcwvcd3m9/+1tt3bpV27ZtU0dHh1paWvTqq6+qpaXlstc3NTUpFouljmg06vdIACApAzu8X/7yl3r22We1YMECSdKUKVP0pz/9SZFIRPX19ZdcHwgEFAgE/B4DAC7h+w6vr69PubnpT5uXl6dkchA3ngeADPB9h/fggw9q9erVGj9+vG6//Xb98Y9/1Nq1a/X444/7vRQADIrvwVu3bp2WL1+up556St3d3SovL9eTTz6pFStW+L0UAAyK78ELBoNqbm5Wc3Oz308NANeF79ICMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMMP3r5b5pffFX0vBoOsxnPun3Sddj4Dv2PzPQ2OfUPDeWtcjaNSaU65HkNd/YcDXDo3/cwCQBQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBn5rge4kn/fdUz5hSNdj+Fcw6sNrkfAd3zZ9o+uR5AkdT33lusRtHRVp+sRlOg7r3ULtg7oWnZ4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzBh08Pbv368HH3xQ5eXlysnJ0a5du9Ie9zxPK1as0NixY1VUVKS6ujqdPHnSr3kB4JoNOni9vb2aOnWq1q9ff9nH16xZo9dee00bN27UwYMHNXLkSM2ZM0fffPPNdQ8LANdj0LeHmjt3rubOnXvZxzzPU3Nzs5577jk99NBDkqS3335bpaWl2rVrlxYsWHB90wLAdfD1d3inT59WZ2en6urqUudCoZBmzpypAwcOXPa/SSQSisfjaQcAZIKvwevs/OvNAEtLS9POl5aWph77rkgkolAolDoqKir8HAkAUpx/StvU1KRYLJY6otGo65EADFO+Bq+srEyS1NXVlXa+q6sr9dh3BQIBFRcXpx0AkAm+Bq+yslJlZWVqbW1NnYvH4zp48KBmzZrl51IAMGiD/pT2/Pnz+vzzz1M/nz59WkePHtXo0aM1fvx4NTY26sUXX1RVVZUqKyu1fPlylZeXa/78+X7ODQCDNujgHTlyRLNnz079vGzZMklSfX29tmzZoqefflq9vb164okn9Je//EX33HOPdu/ercLCQv+mBoBrMOjg3XvvvfI874qP5+TkaNWqVVq1atV1DQYAfnP+KS0AZAvBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgxqC/aZFpf/sWx/9+0+d4kqGhL9nvegR8x/kLF12PIEnq7elxPYISfeddj6ALfb2SdNVvgP1NjjeQq7Loiy++4CagAAYtGo1q3LhxV71myAUvmUzq3LlzCgaDysnJuabniMfjqqioUDQaNX9/PV6LdLwe3xour4Xneerp6VF5eblyc6/+W7oh95Y2Nzf3eys9UNxQ9Fu8Ful4Pb41HF6LUCg0oOv40AKAGQQPgBnDMniBQEArV65UIBBwPYpzvBbpeD2+ZfG1GHIfWgBApgzLHR4AXA7BA2AGwQNgBsEDYMawDN769es1ceJEFRYWaubMmTp06JDrkbIuEolo+vTpCgaDKikp0fz583X8+HHXYw0JL7/8snJyctTY2Oh6FCfOnj2rRx99VGPGjFFRUZGmTJmiI0eOuB4rK4Zd8Hbs2KFly5Zp5cqV6ujo0NSpUzVnzhx1d3e7Hi2r9u3bp3A4rLa2Nu3Zs0cXL17U/fff
r97eXtejOXX48GG98cYbuuOOO1yP4sTXX3+t2tpa3XTTTfrggw/02Wef6Ve/+pVGjRrlerTs8IaZGTNmeOFwOPVzf3+/V15e7kUiEYdTudfd3e1J8vbt2+d6FGd6enq8qqoqb8+ePd5PfvITb+nSpa5HyrpnnnnGu+eee1yP4cyw2uFduHBB7e3tqqurS53Lzc1VXV2dDhw44HAy92KxmCRp9OjRjidxJxwOa968eWl/Pqx59913VVNTo4cfflglJSWaNm2aNm3a5HqsrBlWwfvqq6/U39+v0tLStPOlpaXq7Ox0NJV7yWRSjY2Nqq2tVXV1tetxnNi+fbs6OjoUiURcj+LUqVOntGHDBlVVVenDDz/U4sWLtWTJErW0tLgeLSuG3N1S4L9wOKxjx47po48+cj2KE9FoVEuXLtWePXtUWFjoehynksmkampq9NJLL0mSpk2bpmPHjmnjxo2qr693PF3mDasd3i233KK8vDx1dXWlne/q6lJZWZmjqdxqaGjQ+++/rz/84Q++3XbrRtPe3q7u7m7deeedys/PV35+vvbt26fXXntN+fn56u+3c1fpsWPHavLkyWnnbrvtNp05c8bRRNk1rIJXUFCgu+66S62tralzyWRSra2tmjVrlsPJss/zPDU0NGjnzp36/e9/r8rKStcjOXPffffp008/1dGjR1NHTU2NFi5cqKNHjyovL8/1iFlTW1t7yV9POnHihCZMmOBooixz/amJ37Zv3+4FAgFvy5Yt3meffeY98cQT3g9+8AOvs7PT9WhZtXjxYi8UCnl79+71/vznP6eOvr4+16MNCVY/pT106JCXn5/vrV692jt58qS3detWb8SIEd5vfvMb16NlxbALnud53rp167zx48d7BQUF3owZM7y2tjbXI2WdpMsemzdvdj3akGA1eJ7nee+9955XXV3tBQIBb9KkSd6bb77peqSs4fZQAMwYVr/DA4CrIXgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMz4P7RepZs4AfKLAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Project the row-permuted student weights onto the equivariant parameterization.\n",
    "# Rebuild from student.create_W() rather than reading the global W (which this\n",
    "# cell overwrites), so re-running the cell does not permute a projected matrix again.\n",
    "total_nodes, equi_params, invar_params = proj_equi_params(perm_mat @ student.create_W(), symm_blocks, invar_nodes, requires_grad=True)\n",
    "W = EquiBlocks.create_W_(d, symm_blocks, invar_nodes, total_nodes, equi_params, invar_params)\n",
    "plt.imshow(W.detach().numpy(), cmap='RdBu', vmax=vmax, vmin=-vmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06edeb30",
   "metadata": {},
   "source": [
    "# Train Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5241cd79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_loss_surface(W_current, W_x, W_y, V,\n",
    "                     x_min = -0.5, x_max = 0.5, y_min = -.5, y_max = .5, x_res = 10, y_res = 10):\n",
    "    \"\"\"Evaluate F on the 2-D slice W_current + x*W_x + y*W_y.\n",
    "\n",
    "    x runs over x_res evenly spaced points in [x_min, x_max] and y over y_res\n",
    "    points in [y_min, y_max]. Returns (xs, ys, zs): xs and ys come from\n",
    "    meshgrid with 'ij' indexing, so zs[i][j] is F at (xs[i, j], ys[i, j]).\n",
    "    Losses are evaluated without autograd tracking.\n",
    "    \"\"\"\n",
    "    xs = torch.linspace(x_min, x_max, x_res)\n",
    "    ys = torch.linspace(y_min, y_max, y_res)\n",
    "    # Values are only inspected/plotted, so do not build an autograd graph\n",
    "    # for every grid point.\n",
    "    with torch.no_grad():\n",
    "        zs = torch.tensor([[F(W_current + x * W_x + y * W_y, V) for y in ys] for x in xs])\n",
    "    # 'ij' is torch's legacy default; passing it explicitly silences the\n",
    "    # meshgrid UserWarning and matches zs's [x][y] layout.\n",
    "    xs, ys = torch.meshgrid(xs, ys, indexing='ij')\n",
    "    return xs, ys, zs\n",
    "\n",
    "def get_loss_line(W_0, W_step, V, range_step = (0, 1), res = 40):\n",
    "    \"\"\"Evaluate F along the line W_0 + step*W_step.\n",
    "\n",
    "    step takes res evenly spaced values in [range_step[0], range_step[1]].\n",
    "    Returns (steps, loss) with one loss entry per step, evaluated without\n",
    "    autograd tracking.\n",
    "    \"\"\"\n",
    "    steps = torch.linspace(range_step[0], range_step[1], res)\n",
    "    with torch.no_grad():\n",
    "        loss = torch.tensor([F(W_0 + step * W_step, V).detach() for step in steps])\n",
    "    return steps, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 291,
   "id": "b5dbe201",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x79032ca29720>"
      ]
     },
     "execution_count": 291,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAGdCAYAAACPRc9NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUEklEQVR4nO3df2zV9b3H8Vd/rKcFT88AbUtDkcY04UeRAQXCSqbEBkLQSXJD5k1NCEvmwk4HHcmmTQYoDo8um+nVEVASpctAWHKDMhNxpBslxgKllUVmLmDwzjPZaUeuO6fUeeD2fO8fdzvuSK0tfM/5lL6fj+T7x/meL/28cyhPvz1fz7d5nud5AgAD8l0PAAC5QvAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmFHoeoDPS6VSunTpkoLBoPLy8lyPA2CM8zxP/f39qqysVH7+8OdwYy54ly5dUlVVlesxANxiotGopk2bNuwxYy54wWBQkrSv4x1NuC3odJbDC5c7XV+SIrEu1yMAY1p/f79qamrS7RjOmAveP3+MnXBbUBMdB69oDLzFWVpa6noE4JYwkrfA3P+LBoAcIXgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMzIWvB27typGTNmqLi4WEuWLNGpU6eytRQAjEhWgnfw4EFt3rxZ27ZtU09Pj+bNm6eVK1eqr68vG8sBwIhkJXjPPvusvvOd72j9+vWaPXu2du/erQkTJuill17KxnIAMCK+B+/q1avq7u5WQ0PDZ4vk56uhoUGdnZ3XHZ9MJpVIJDI2AMgG34N3+fJlDQ4Oqry8PGN/eXm5YrHYdcdHIhGFQqH0xr3wAGSL86u0LS0tisfj6S0ajboeCcA45fv98G6//XYVFBSot7c3Y39vb68qKiquOz4QCCgQCPg9BgBcx/czvKKiIi1cuFDt7e3pfalUSu3t7Vq6dKnfywHAiGXljsebN2/WunXrVFdXp8WLF6u1tVUDAwNav359NpYDgBHJSvC+9a1v6a9//au2bt2qWCymr33tazpy5Mh1FzIAIJey9jstmpqa1NTUlK0vDwCj5vwqLQDkCsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgBsEDYAbBA2AGwQNgRtY+WnazDi9criLHPf63c/ziIWA84QwPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZha4H+CKRWJdKS0tdj+Fc/sn/dD3CmPH4iq2uR5AkbY3/0fUIuEGc4QEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEww/fgRSIRLVq0SMFgUGVlZVqzZo3OnTvn9zIAMGq+B6+jo0PhcFgnTpzQ0aNHde3aNa1YsUIDAwN+LwUAo+L7/fCOHDmS8Xjv3r0qKytTd3e3vvGNb/i9HACMWNZvABqPxyVJkydPHvL5ZDKpZDKZfpxIJLI9EgCjsnrRIpVKqbm5WfX19aqtrR3ymEgkolAolN6qqqqyORIAw7IavHA4rLNnz+rAgQNfeExLS4vi8Xh6i0aj2RwJgGFZ+5G2qalJr7/+uo4fP65p06Z94XGBQECBQCBbYwBAmu/B8zxP3//+93Xo0CEdO3ZM1dXVfi8BADfE9+CFw2Ht379fr732moLBoGKxmCQpFAqppKTE7+UAYMR8fw9v165disfjuvfeezV16tT0dvDgQb+XAoBRycqPtAAwFvFZWgBmEDwAZhA8AGYQPABmEDwAZhA8AGYQPABmEDwAZhA8AGZk/QaguDmPr9jqegRJ0uO/3e56hDExgySlXA+AG8YZHgAzCB4AMwgeADMIHgAz
CB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwpdD4DhPf7b7a5HkCQ9vmKr6xHGzGuBWxdneADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMyHrwnn76aeXl5am5uTnbSwHAsLIavK6uLr3wwgu6++67s7kMAIxI1oJ35coVNTY2as+ePZo0aVK2lgGAEcta8MLhsFavXq2GhoZhj0smk0okEhkbAGRDVu54fODAAfX09Kirq+tLj41EInriiSeyMQYAZPD9DC8ajWrTpk3at2+fiouLv/T4lpYWxePx9BaNRv0eCQAkZeEMr7u7W319fVqwYEF63+DgoI4fP65f/OIXSiaTKigoSD8XCAQUCAT8HgMAruN78O677z69++67GfvWr1+vmTNn6tFHH82IHQDkku/BCwaDqq2tzdg3ceJETZky5br9AJBLfNICgBk5+b20x44dy8UyADAszvAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYkZNPWuDG/fesb7oeQZLU1LvG9Qi67HqAf5jsegDcMM7wAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmFHoegAM77aisfHfpCtXU65HGDOvBW5dfAcBMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwIyvB++ijj/Twww9rypQpKikp0dy5c3X69OlsLAUAI+b73VI+/vhj1dfXa/ny5XrjjTd0xx136MKFC5o0aZLfSwHAqPgevGeeeUZVVVV6+eWX0/uqq6v9XgYARs33H2kPHz6suro6rV27VmVlZZo/f7727Nnzhccnk0klEomMDQCywffgXbx4Ubt27VJNTY3efPNNbdiwQRs3blRbW9uQx0ciEYVCofRWVVXl90gAIEnK8zzP8/MLFhUVqa6uTm+//XZ638aNG9XV1aXOzs7rjk8mk0omk+nHiURCVVVVisViKi0t9XO0W9L/fDroegRJ3PH4X00uLnA9Av5FIpFQRUWF4vH4lzbD9++gqVOnavbs2Rn7Zs2apQ8//HDI4wOBgEpLSzM2AMgG34NXX1+vc+fOZew7f/687rzzTr+XAoBR8T14P/jBD3TixAk99dRTev/997V//369+OKLCofDfi8FAKPie/AWLVqkQ4cO6ZVXXlFtba2efPJJtba2qrGx0e+lAGBUsvJrGu+//37df//92fjSAHDDxsZlLwDIAYIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8CMrHzSAuPPWLg101i4RZXE7aFuZe6/iwEgRwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCB4AMwgeADMIHgAzCl0P8EVmNv6H8gqLXY/hXOzxua5HGDPucD3AP/RU3ON6BEnSggkJ1yMoXuz+b2XgWmrEx3KGB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAzfgzc4OKgtW7aourpaJSUluuuuu/Tkk0/K8zy/lwKAUfH9binPPPOMdu3apba2Ns2ZM0enT5/W+vXrFQqFtHHjRr+XA4AR8z14
b7/9th588EGtXr1akjRjxgy98sorOnXqlN9LAcCo+P4j7de//nW1t7fr/PnzkqQ//OEPeuutt7Rq1aohj08mk0okEhkbAGSD72d4jz32mBKJhGbOnKmCggINDg5qx44damxsHPL4SCSiJ554wu8xAOA6vp/h/frXv9a+ffu0f/9+9fT0qK2tTT/72c/U1tY25PEtLS2Kx+PpLRqN+j0SAEjKwhneD3/4Qz322GN66KGHJElz587Vn/70J0UiEa1bt+664wOBgAKBgN9jAMB1fD/D++STT5Sfn/llCwoKlEqN/L7zAJANvp/hPfDAA9qxY4emT5+uOXPm6J133tGzzz6rb3/7234vBQCj4nvwnn/+eW3ZskXf+9731NfXp8rKSn33u9/V1q1b/V4KAEbF9+AFg0G1traqtbXV7y8NADeFz9ICMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMMP3j5b55ficiwoGilyP4VxB2dB3ioY7CyaMjbtyX2xpdj2COg/9l+sR9HdvcMTHcoYHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8AMggfADIIHwAyCB8CMPM/zPNdD/KtEIqFQKKTLF86oNBh0PY5zx+K3uR4Bn7P9tT+6HkGS9O/Lql2PoLVz7nA9gvoTCd1VVal4PK7S0tJhj+UMD4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBkED4AZBA+AGQQPgBmjDt7x48f1wAMPqLKyUnl5eXr11Vcznvc8T1u3btXUqVNVUlKihoYGXbhwwa95AeCGjTp4AwMDmjdvnnbu3Dnk8z/96U/13HPPaffu3Tp58qQmTpyolStX6tNPP73pYQHgZhSO9g+sWrVKq1atGvI5z/PU2tqqH//4x3rwwQclSb/85S9VXl6uV199VQ899NDNTQsAN8HX9/A++OADxWIxNTQ0pPeFQiEtWbJEnZ2dQ/6ZZDKpRCKRsQFANvgavFgsJkkqLy/P2F9eXp5+7vMikYhCoVB6q6qq8nMkAEhzfpW2paVF8Xg8vUWjUdcjARinfA1eRUWFJKm3tzdjf29vb/q5zwsEAiotLc3YACAbfA1edXW1Kioq1N7ent6XSCR08uRJLV261M+lAGDURn2V9sqVK3r//ffTjz/44AOdOXNGkydP1vTp09Xc3Kyf/OQnqqmpUXV1tbZs2aLKykqtWbPGz7kBYNRGHbzTp09r+fLl6cebN2+WJK1bt0579+7Vj370Iw0MDOiRRx7R3/72Ny1btkxHjhxRcXGxf1MDwA0YdfDuvfdeDfebHfPy8rR9+3Zt3779pgYDAL85v0oLALlC8ACYQfAAmEHwAJhB8ACYQfAAmEHwAJhB8ACYQfAAmDHqT1pk2z8/xdHff8XxJGPDwJUv/lQL3PjfTwdcjyBJ+vtAv+sR1J8IuB5B/f3//zoM9wmwf8rzRnJUDv35z3/mJqAARi0ajWratGnDHjPmgpdKpXTp0iUFg0Hl5eXd0NdIJBKqqqpSNBo1f389XotMvB6fGS+vhed56u/vV2VlpfLzh3+Xbsz9SJufn/+llR4pbij6GV6LTLwenxkPr0UoFBrRcVy0AGAGwQNgxrgMXiAQ0LZt2xQIuL+C5BqvRSZej89YfC3G3EULAMiWcXmGBwBDIXgAzCB4AMwgeADMGJfB27lzp2bMmKHi4mItWbJEp06dcj1SzkUiES1atEjBYFBlZWVas2aNzp0753qsMeHpp59WXl6empubXY/ixEcffaSHH35YU6ZMUUlJiebOnavTp0+7Hisnxl3wDh48qM2bN2vbtm3q6enRvHnztHLlSvX19bkeLac6OjoUDod14sQJHT16VNeuXdOK
FSs0MDA2PvjuSldXl1544QXdfffdrkdx4uOPP1Z9fb2+8pWv6I033tB7772nn//855o0aZLr0XLDG2cWL17shcPh9OPBwUGvsrLSi0QiDqdyr6+vz5PkdXR0uB7Fmf7+fq+mpsY7evSod88993ibNm1yPVLOPfroo96yZctcj+HMuDrDu3r1qrq7u9XQ0JDel5+fr4aGBnV2djqczL14PC5Jmjx5suNJ3AmHw1q9enXG94c1hw8fVl1dndauXauysjLNnz9fe/bscT1Wzoyr4F2+fFmDg4MqLy/P2F9eXq5YLOZoKvdSqZSam5tVX1+v2tpa1+M4ceDAAfX09CgSibgexamLFy9q165dqqmp0ZtvvqkNGzZo48aNamtrcz1aToy5u6XAf+FwWGfPntVbb73lehQnotGoNm3apKNHj6q4uNj1OE6lUinV1dXpqaeekiTNnz9fZ8+e1e7du7Vu3TrH02XfuDrDu/3221VQUKDe3t6M/b29vaqoqHA0lVtNTU16/fXX9fvf/963227darq7u9XX16cFCxaosLBQhYWF6ujo0HPPPafCwkINDg66HjFnpk6dqtmzZ2fsmzVrlj788ENHE+XWuApeUVGRFi5cqPb29vS+VCql9vZ2LV261OFkued5npqamnTo0CH97ne/U3V1teuRnLnvvvv07rvv6syZM+mtrq5OjY2NOnPmjAoKClyPmDP19fXX/e9J58+f15133uloohxzfdXEbwcOHPACgYC3d+9e77333vMeeeQR76tf/aoXi8Vcj5ZTGzZs8EKhkHfs2DHvL3/5S3r75JNPXI82Jli9Snvq1CmvsLDQ27Fjh3fhwgVv37593oQJE7xf/epXrkfLiXEXPM/zvOeff96bPn26V1RU5C1evNg7ceKE65FyTtKQ28svv+x6tDHBavA8z/N+85vfeLW1tV4gEPBmzpzpvfjii65HyhluDwXAjHH1Hh4ADIfgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEwg+ABMIPgATCD4AEw4/8AfV+xVljzEDYAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Reorder the rows of V by a hand-picked permutation and plot the result.\n",
    "perm = torch.tensor([1,0,4,2,3,5,6,7,10,9,8])\n",
    "# torch.eye(n)[perm] is only a square permutation matrix if perm is a permutation of range(n).\n",
    "assert sorted(perm.tolist()) == list(range(V.shape[0])), 'perm must be a permutation of range(V.shape[0])'\n",
    "perm_mat = torch.eye(V.shape[0])[perm]\n",
    "# Symmetric color limits (vmin = -vmax) keep zero at the center of the diverging RdBu colormap.\n",
    "vmax = torch.max(torch.abs(V))\n",
    "plt.imshow((perm_mat@V).detach().numpy(),cmap='RdBu', vmax = vmax, vmin=-vmax);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 292,
   "id": "0acc64ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"100%\"\n",
       "    height=\"545px\"\n",
       "    src=\"iframe_figures/figure_292.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loss along the straight line from W (t = 0) to perm_mat@V (t = 1 when not normalized).\n",
    "NORMALIZE_DIRECTION = False  # set True to step along a unit-norm direction instead\n",
    "Wy = perm_mat@V-W\n",
    "if NORMALIZE_DIRECTION:\n",
    "    Wy = Wy / torch.linalg.norm(Wy)\n",
    "xs, ys = get_loss_line(W,Wy,V,[-0.2,1.2])\n",
    "\n",
    "go.Figure(go.Scatter(x = xs, y = ys)).update_layout(xaxis_title = 't  (point = W + t * Wy)', yaxis_title = 'F(W + t * Wy, V)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3087ddd4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "e3nn",
   "language": "python",
   "name": "e3nn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
