{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gEGVhOh_FCaN"
      },
      "source": [
        "If the notebook is run on Google Colab, the following dependencies need to be installed (uncomment the lines below):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VVybMbgjFCNU"
      },
      "outputs": [],
      "source": [
        "#%pip --quiet install objax\n",
        "#%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git\n",
        "#%pip --quiet install pot\n",
        "#%pip install -U kaleido"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "A connection to a copy of the repository on Google Drive might also be required:"
      ],
      "metadata": {
        "id": "V1NF0oeEWaJN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#from google.colab import drive\n",
        "#drive.mount('/content/drive')\n",
        "\n",
        "#from google.colab import output\n",
        "#output.enable_custom_widget_manager()"
      ],
      "metadata": {
        "id": "VmI3A1WiWghw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wChyLedgFI2T"
      },
      "source": [
        "# Importing necessary packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xbR2TAy_E6f6"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "from functools import reduce\n",
        "\n",
        "path_to_repo = \"/../\" #'drive/MyDrive/invariant-mean-field-neural-networks'\n",
        "\n",
        "project_path = os.path.abspath(path_to_repo)\n",
        "sys.path.insert(1, project_path)\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import vmap\n",
        "\n",
        "import objax\n",
        "from emlp.reps import V,sparsify_basis\n",
        "from emlp.groups import SO,O,S,Z\n",
        "\n",
        "import src.modules\n",
        "from src.visualization import vis, particle_plot, plot_losses, particle_plot_animation\n",
        "from src.theory_utils import equivariance_err, Wasserstein_Distance, rel_measure_distance\n",
        "from src.modules import ShallowMLPNoLinearOut, FA_Model, SGD\n",
        "from src.train_eval_utils import random_compare, training_loop\n",
        "from src.utils import ExpData, CumData"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HCc8j1pwG5OW"
      },
      "source": [
        "# Definition of key group elements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ichY9Fs_E9dx"
      },
      "outputs": [],
      "source": [
        "# Symmetry group S(2); its base representation V(G) is used for both input and output.\n",
        "G = S(2)\n",
        "repin = V(G)\n",
        "repout = V(G)\n",
        "\n",
        "# Representation on the linear maps repin -> repout, i.e. on the layer parameters.\n",
        "rep_params = repin >> repout\n",
        "P_params = rep_params.equivariant_projector()  # projector onto the equivariant parameters\n",
        "base_params = rep_params.equivariant_basis()   # basis of the equivariant parameters\n",
        "\n",
        "# Dense action on parameter space of the first (for S(2), the only) discrete generator.\n",
        "G_generator = rep_params.rho_dense(G.discrete_generators[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can visualize the basis of the space of equivariant linear maps:"
      ],
      "metadata": {
        "id": "ADRGIeD-LzrY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# base_params maps coordinates in the equivariant basis to parameter space (one column per basis element).\n",
        "print(f\"Equivariant basis has shape {base_params.shape}\")\n",
        "vis(repin, repout)"
      ],
      "metadata": {
        "id": "w9s1Xx9fL2fj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also define vectorized (row-wise) versions of the corresponding maps:"
      ],
      "metadata": {
        "id": "VcI7ZUh9L36Q"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1lL4i0RhFxEw"
      },
      "outputs": [],
      "source": [
        "# Vectorized application of maps\n",
        "vP_params = vmap(lambda x: P_params@x)\n",
        "vbase_params = vmap(lambda x: base_params@x)\n",
        "vbase_paramsT = vmap(lambda x: base_params.T@x)\n",
        "\n",
        "# The orbit maps are, unfortunately, restricted to this specific case of S(2)\n",
        "vG_generator = vmap(lambda x: jnp.dot(G_generator,x))\n",
        "vorbit = (lambda x: jnp.vstack([x, vG_generator(x)])) # This generates an array of \"double the amount of particles\", but with the complete orbit of each point.\n",
        "\n",
        "scale_factor = 0.5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MCi6gw4hGFAr"
      },
      "source": [
        "We identify the space of $2\\times 2$ matrices with $\\mathbb{R}^{2\\times 2} \\cong \\mathbb{R}^4$, which we can visualize in $\\mathbb{R}^3$ with an extra \"color\" axis for the fourth coordinate. The subspace of equivariant parameters, $\\mathcal{E}^G$, has dimension 2, and we can visualize it inside this space."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KRhA4N9RGA44"
      },
      "outputs": [],
      "source": [
        "# Corners of a square in the coordinates of the equivariant basis, mapped to parameter\n",
        "# space: they outline a patch of the 2-dimensional equivariant subspace.\n",
        "square_corners = jnp.array([[1, 1], [1, -1], [-1, 1], [-1, -1]])\n",
        "equivariant_space_points = vbase_params(1.6 * scale_factor * square_corners)\n",
        "particle_plot(None, None, equivariant_space_points, double=False, title=\"Equivariant Space\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c37wQjrlGRa7"
      },
      "source": [
        "We can also observe the \"orbits\" of points in this space, as well as their \"projected\" versions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Icjd7FdGKzF"
      },
      "outputs": [],
      "source": [
        "# Three arbitrary points of parameter space.\n",
        "example_points = scale_factor * jnp.array([\n",
        "    [-0.5, 0, 1, -0.5],\n",
        "    [0.5, 0, 0.5, 1],\n",
        "    [-0.5, 0, -0.5, 1],\n",
        "])\n",
        "# Plot the full orbit of each point together with its projection onto the equivariant subspace.\n",
        "orbit_points = vorbit(example_points)\n",
        "projected_points = vP_params(example_points)\n",
        "point_array = jnp.vstack([orbit_points, projected_points])\n",
        "color_limit = 2 * scale_factor\n",
        "particle_plot(None, point_array, equivariant_space_points, c_lims=(-color_limit, color_limit),\n",
        "              double=False, with_lineplots=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YNNYFvjV3e-l"
      },
      "source": [
        "# Reading Data from Files"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We add a function to create the fixed teacher model used for the experiments of the main paper:"
      ],
      "metadata": {
        "id": "ZBYfe8lsO1fP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def create_model(N_t, activation_fn, mode=\"free\", fixed_init = None):\n",
        "  \"\"\"Build a ShallowMLPNoLinearOut (used for the fixed teacher) and set its particles.\n",
        "\n",
        "  Args:\n",
        "    N_t: number of base particles.\n",
        "    activation_fn: activation function passed to the model.\n",
        "    mode: \"strong\"/\"strong-equivariant\" -> equivariant model with N_t particles;\n",
        "      \"weak\"/\"weak-equivariant\" -> non-equivariant model with 2*N_t particles, namely\n",
        "      N_t base particles plus their images under the S(2) generator (via vorbit);\n",
        "      any other value (e.g. \"free\") -> non-equivariant model with N_t particles.\n",
        "    fixed_init: optional N_t base particles; if None, the model's own initial\n",
        "      particles are used.\n",
        "\n",
        "  Returns:\n",
        "    The model, with its particles set.\n",
        "  \"\"\"\n",
        "  is_strong = mode in [\"strong\", \"strong-equivariant\"]\n",
        "  is_weak = mode in [\"weak\", \"weak-equivariant\"]\n",
        "  n_particles = 2*N_t if is_weak else N_t\n",
        "  model = ShallowMLPNoLinearOut(N=n_particles, rep_in=repin, rep_out=repout, activation=activation_fn,\n",
        "                                alternative=True, use_bias=False, equivariant=is_strong)\n",
        "  init_particles = model.get_particles() if fixed_init is None else fixed_init\n",
        "  if is_weak:\n",
        "    if fixed_init is None:\n",
        "      # The model holds 2*N_t particles and vorbit doubles the row count, so only the\n",
        "      # first N_t are orbited; orbiting all of them would yield 4*N_t particles.\n",
        "      # NOTE(review): assumes get_particles() returns one particle per row.\n",
        "      init_particles = init_particles[:N_t]\n",
        "    init_particles = vorbit(init_particles)\n",
        "  model.set_particles(init_particles)\n",
        "  return model"
      ],
      "metadata": {
        "id": "EF3qU-_bM8yc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also use a function to produce all the default hyperparameters:"
      ],
      "metadata": {
        "id": "MFqjLUhwO9_O"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iFPaoAxq3nNK"
      },
      "outputs": [],
      "source": [
        "def get_default_train_params(N_p, TEACHER_MODE, EQUIV_INIT):\n",
        "  \"\"\"Return the default training hyperparameters together with the fixed teacher network.\n",
        "\n",
        "  Args:\n",
        "    N_p: number of student particles; not used by this function (kept so that all\n",
        "      call sites share the same signature).\n",
        "    TEACHER_MODE: \"strong\"/\"strong-equivariant\", \"weak\"/\"weak-equivariant\" or \"free\"\n",
        "      (the same aliases create_model accepts).\n",
        "    EQUIV_INIT: if truthy, P_MAP is the equivariant projection vP_params, otherwise None.\n",
        "\n",
        "  Returns:\n",
        "    (train_params, teacher_network).\n",
        "  \"\"\"\n",
        "  N_t = 5  # number of (base) teacher particles\n",
        "  N_reps = 10\n",
        "  train_params = dict(\n",
        "      BATCH_SIZE = 20,\n",
        "      LR = 50,\n",
        "      T_EPOCHS = 20,\n",
        "      TAU = 1e-4, # This is the \"norm\" regularization parameter\n",
        "      BETA = 1e-6, # This is the \"noise\" parameter\n",
        "      DATA_STD = 4,\n",
        "      GRANULAR = 5,\n",
        "      EQUIV_INIT=EQUIV_INIT,\n",
        "      TEACHER_MODE=TEACHER_MODE,\n",
        "      N_reps = N_reps)\n",
        "  train_params[\"P_MAP\"] = vP_params if EQUIV_INIT else None\n",
        "  # Fixed teacher. A strongly-equivariant teacher is parametrized by coordinates in the\n",
        "  # 2-dimensional equivariant basis; the other modes use full 4-dimensional parameters.\n",
        "  # Match every alias create_model treats as strong; otherwise \"strong-equivariant\"\n",
        "  # would receive 4-dimensional particles for a 2-dimensional equivariant model.\n",
        "  if TEACHER_MODE in [\"strong\", \"strong-equivariant\"]:\n",
        "    fixed_teacher_particles = (scale_factor)*jnp.array([[1,0],[0.5,1],[-0.5,0.3],[0,-1], [0.7, 0.7]])\n",
        "  else:\n",
        "    fixed_teacher_particles = (scale_factor)*jnp.array([[-1,0,0,0.5],[0.5,1,0,1],[-0.5,0.3,1,0],[0,-1,-0.5,1], [0.7, -0.7,0.5,0.7]])\n",
        "  activation_fn = objax.functional.sigmoid\n",
        "  teacher_network = create_model(N_t, activation_fn, mode=TEACHER_MODE, fixed_init=fixed_teacher_particles)\n",
        "  train_params[\"TEACHER_FIXED\"] = fixed_teacher_particles is not None\n",
        "  return train_params, teacher_network"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "For instance, we read the results for $N=100$ particles:"
      ],
      "metadata": {
        "id": "PLjU4bEUPDBP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Configuration of the experiment whose saved results are loaded below.\n",
        "N_p = 100                 # number of student particles\n",
        "TEACHER_MODE = \"free\"     # alternatives: \"weak\", \"strong\"\n",
        "EQUIV_INIT = True         # True or False\n",
        "path_read = \"\"            # folder with the saved results, e.g. \"drive/MyDrive/Results/\"\n",
        "train_params, teacher_network = get_default_train_params(\n",
        "    N_p, TEACHER_MODE, EQUIV_INIT)"
      ],
      "metadata": {
        "id": "xwM-zL9nMr4N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "metric_keys = [\"eq_errors\", \"dist_to_teacher\", \"dist_to_sym_teacher\", \"train_losses\", \"particles\"]\n",
        "distance_keys = [\"comparisons_RMDs\", \"comparisons_WDs\"]\n",
        "out_dict = ExpData(metric_keys, N_p, train_params)\n",
        "dists_dict = ExpData(distance_keys, N_p, train_params)"
      ],
      "metadata": {
        "id": "J2GxVSzPMPSA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the stored training metrics and measure distances from `path_read`.\n",
        "metrics_file = path_read + \"metrics\"\n",
        "distances_file = path_read + \"measure_distances\"\n",
        "out_dict.load(metrics_file)\n",
        "dists_dict.load(distances_file)"
      ],
      "metadata": {
        "id": "4lPGmSHkNKmR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Visualizing the particles at the final iteration"
      ],
      "metadata": {
        "id": "j3BTg2NmPWAw"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Rd_uefnOULE"
      },
      "outputs": [],
      "source": [
        "repetition = 0  # index of the repeated run to inspect\n",
        "particle_positions_teacher = teacher_network.get_particles()\n",
        "saved_particles = out_dict[\"particles\"]\n",
        "particles_vanilla, particles_DA, particles_FA, particles_EA = (\n",
        "    saved_particles[scheme][repetition] for scheme in (\"vanilla\", \"DA\", \"FA\", \"EA\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "For instance, let's observe the \"vanilla\" particles:"
      ],
      "metadata": {
        "id": "QvSU-AzFPnKP"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zr6PFfjEOULF"
      },
      "outputs": [],
      "source": [
        "# Last stored snapshot of the vanilla student's particles.\n",
        "final_particles_vanilla = particles_vanilla[-1]\n",
        "particle_plot(particle_positions_teacher, final_particles_vanilla, equivariant_space_points,\n",
        "              title=\"Student Particles (vanilla)\", c_lims=(-0.65, 0.65), eq_style=(0.35, 0.7),\n",
        "              with_lineplots=True, t_style=(3, \"diamond\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can actually provide an animation of the entire training process:"
      ],
      "metadata": {
        "id": "DR4a0gkgPt0f"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Same styling as the static plot, applied to the whole sequence of stored snapshots.\n",
        "vanilla_plot_kwargs = dict(title=\"Student Particles (vanilla)\", c_lims=(-0.65, 0.65), eq_style=(0.35, 0.7),\n",
        "                           with_lineplots=True, t_style=(3, \"diamond\"))\n",
        "particle_plot_animation(particle_positions_teacher, particles_vanilla, equivariant_space_points,\n",
        "                        **vanilla_plot_kwargs)"
      ],
      "metadata": {
        "id": "0GjXT7mVPxm3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The interested reader can repeat the same visualization for the other training schemes (DA, FA and EA)."
      ],
      "metadata": {
        "id": "zCXV-yVPQijP"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Clt_xTi3-f6Y"
      },
      "source": [
        "# Tests Over Increasing N"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "To see how these networks behave as the number of particles $N$ grows, we load the results for several values of $N$. We only consider $N \\in \\{5, 10, 50, 100, 500\\}$ to keep the compute time of the upcoming sections low:"
      ],
      "metadata": {
        "id": "tsNy9a69QuXH"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yYedFXrtAiBx"
      },
      "outputs": [],
      "source": [
        "# Load the saved metrics for each number of particles N in Ns.\n",
        "Ns = [5, 10, 50, 100, 500]  # full sweep: [5, 10, 50, 100, 500, 1000, 5000]\n",
        "out_dicts = []\n",
        "for N_p in Ns:\n",
        "  train_params, teacher_network = get_default_train_params(N_p, TEACHER_MODE, EQUIV_INIT)\n",
        "  out_dictk = ExpData([\"eq_errors\", \"dist_to_teacher\", \"dist_to_sym_teacher\", \"train_losses\", \"particles\"],\n",
        "                      N_p, train_params)\n",
        "  out_dictk.load(path_read + \"metrics\")\n",
        "  out_dicts += [out_dictk]"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## $L^2$ distances"
      ],
      "metadata": {
        "id": "Yr7CgsvuTFtr"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can use the following function to obtain some of the plots comparing the $L^2$ distances:"
      ],
      "metadata": {
        "id": "30LGeUEORYrl"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eiGrxBTN-hTj"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import colorcet as cc\n",
        "\n",
        "def get_list_dfs(out_dictks, key1, key2):\n",
        "  df_pre = [pd.DataFrame(out_dictt[key1][key2]) for out_dictt in out_dictks]\n",
        "  return df_pre\n",
        "\n",
        "def multi_N_boxplot(list_dfs, N_ps, title=\"value\"):\n",
        "  \"\"\"Box-plot an L2-distance metric against the number of particles (log-scale y-axis).\n",
        "\n",
        "  Args:\n",
        "    list_dfs: one DataFrame per entry of N_ps (e.g. from get_list_dfs); each column holds\n",
        "      the values of one model and becomes one hue of the box plot.\n",
        "    N_ps: numbers of particles, aligned with list_dfs.\n",
        "    title: figure title; with spaces removed, also the name of the saved PDF file.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if list_dfs and N_ps have different lengths.\n",
        "  \"\"\"\n",
        "  if len(list_dfs) != len(N_ps):\n",
        "    raise ValueError(f\"Got {len(list_dfs)} DataFrames for {len(N_ps)} particle counts.\")\n",
        "  data_pre = [df.assign(Particles=N_p) for df, N_p in zip(list_dfs, N_ps)]\n",
        "  cdf = pd.concat(data_pre)\n",
        "  # var_name must be a scalar: recent pandas raises ValueError for a list when the\n",
        "  # columns are not a MultiIndex.\n",
        "  mdf = pd.melt(cdf, id_vars=['Particles'], var_name='Model')\n",
        "  # One colorblind-friendly color per model instead of a hard-coded count of 4.\n",
        "  palette = sns.color_palette(\"colorblind\", n_colors=mdf[\"Model\"].nunique())\n",
        "  fig, ax = plt.subplots()\n",
        "  sns.boxplot(x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf, palette=palette, ax=ax)\n",
        "  ax.set_yscale(\"log\")\n",
        "  ax.set(xlabel='Number of Particles ($N$)', ylabel=\"$L^2$ Distance (estimated)\", title=title)\n",
        "  fig.tight_layout()\n",
        "  fig.savefig(title.replace(\" \", \"\")+\".pdf\", format=\"pdf\")\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Q9_TorsH1UA"
      },
      "outputs": [],
      "source": [
        "# Final (\"end\") distance of each model to its symmetrized version, for every N.\n",
        "eq_error_dfs = get_list_dfs(out_dicts, \"eq_errors\", \"end\")\n",
        "multi_N_boxplot(eq_error_dfs, Ns, title=\"Distance to Symmetrized Model\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "axmYPbILDbCi"
      },
      "outputs": [],
      "source": [
        "# Final (\"end\") distance of each model to the teacher, for every N.\n",
        "teacher_dist_dfs = get_list_dfs(out_dicts, \"dist_to_teacher\", \"end\")\n",
        "multi_N_boxplot(teacher_dist_dfs, Ns, title=\"Distance to Teacher\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0IFywGY9D8BI"
      },
      "outputs": [],
      "source": [
        "# Final (\"end\") distance of each model to the symmetrized teacher, for every N.\n",
        "sym_teacher_dist_dfs = get_list_dfs(out_dicts, \"dist_to_sym_teacher\", \"end\")\n",
        "multi_N_boxplot(sym_teacher_dist_dfs, Ns, title=\"Distance to Symmetrized Teacher\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## RMDs"
      ],
      "metadata": {
        "id": "PbN2P3_8TJje"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Useful Functions"
      ],
      "metadata": {
        "id": "sx6Kt6K3TLCt"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We introduce the functions used to calculate some of our quantities of interest:"
      ],
      "metadata": {
        "id": "YcxSMVgQSNh1"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s1ATVN5rHEmq"
      },
      "outputs": [],
      "source": [
        "def new_rel_measure_distance(m1, m2, p=2, root = False, mode = \"trace\", return_Wasserstein=False):\n",
        "    \"\"\"Relative measure distance (RMD) between two empirical particle measures.\n",
        "\n",
        "    RMD = 2 * W / (M_p(m1) + M_p(m2)), where W = Wasserstein_Distance(m1, m2, p=p, root=root)\n",
        "    and M_p(m) is the mean over the rows x of m of ||x||^p (Euclidean norm). With root=True\n",
        "    the denominator becomes (M_p(m1) + M_p(m2))**(1/p).\n",
        "\n",
        "    Args:\n",
        "      m1, m2: particle arrays, one particle per row.\n",
        "      p: order of the Wasserstein distance and of the moments.\n",
        "      root: forwarded to Wasserstein_Distance; also takes the p-th root of the denominator.\n",
        "      mode: unused; kept for backward compatibility.\n",
        "      return_Wasserstein: if True, also return W.\n",
        "\n",
        "    Returns:\n",
        "      The RMD as a Python float, or the tuple (RMD, W) if return_Wasserstein.\n",
        "    \"\"\"\n",
        "    W_dist = Wasserstein_Distance(m1, m2, p=p, root=root)\n",
        "    # ||x||^p = (sum_i x_i^2)^(p/2). Summing x_i**p per coordinate only agrees with this\n",
        "    # for p=2 (for odd p it can even be negative); for p=2 the value is unchanged.\n",
        "    mom1 = ((m1**2).sum(axis=1) ** (p/2)).mean()\n",
        "    mom2 = ((m2**2).sum(axis=1) ** (p/2)).mean()\n",
        "    total_variation = (mom1 + mom2) if not root else jnp.power((mom1 + mom2), 1/p)\n",
        "    rmd = (2*W_dist/total_variation).item()\n",
        "    return (rmd, W_dist) if return_Wasserstein else rmd"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vUomKf9AFyYz"
      },
      "outputs": [],
      "source": [
        "def distances(particles1, particles2):\n",
        "    \"\"\"RMD and Wasserstein distance between aligned pairs of particle sets.\n",
        "\n",
        "    Compares particles1[i] with particles2[i] (zip semantics: stops at the shorter input)\n",
        "    and returns the two lists (RMDs, Wasserstein distances).\n",
        "    \"\"\"\n",
        "    results = [new_rel_measure_distance(first, second, root=False, mode=\"trace\", return_Wasserstein=True)\n",
        "               for first, second in zip(particles1, particles2)]\n",
        "    RM = [rmd for rmd, _ in results]\n",
        "    WD = [wd for _, wd in results]\n",
        "    return RM, WD"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y282Kl9MI8O8"
      },
      "outputs": [],
      "source": [
        "def RMD_comparisons_first_last(particles, equiv_init, teacher_mode):\n",
        "    \"\"\"Distances between training schemes at their first and last stored snapshot.\n",
        "\n",
        "    Args:\n",
        "      particles: mapping with keys \"vanilla\", \"DA\", \"FA\" and \"EA\"; each value holds one\n",
        "        entry per repetition, and each entry is a sequence of particle snapshots.\n",
        "      equiv_init: if True, each scheme is compared with its projection P(.) (vP_params);\n",
        "        otherwise with its orbit G(.) (vorbit), and P(FA) is compared with EA.\n",
        "      teacher_mode: unused; kept for backward compatibility.\n",
        "\n",
        "    Returns:\n",
        "      (L_comparisons_RMD, L_comparisons_WD): one dict per repetition mapping a comparison\n",
        "      label to [value at first snapshot, value at last snapshot]. Labels are inserted in a\n",
        "      fixed order, which sets the column (and plot hue) order downstream.\n",
        "    \"\"\"\n",
        "    # Batched maps over a stack of snapshots, built once instead of once per repetition.\n",
        "    project_batch = vmap(vP_params)\n",
        "    orbit_batch = vmap(vorbit)\n",
        "\n",
        "    def project(snapshots):\n",
        "      return project_batch(jnp.array(snapshots))\n",
        "\n",
        "    def orbit(snapshots):\n",
        "      return orbit_batch(jnp.array(snapshots))\n",
        "\n",
        "    L_comparisons_RMD = []\n",
        "    L_comparisons_WD = []\n",
        "    runs = zip(particles[\"vanilla\"], particles[\"DA\"], particles[\"FA\"], particles[\"EA\"])\n",
        "    for run_vanilla, run_DA, run_FA, run_EA in runs:\n",
        "      # Keep only the first and the last snapshot of each scheme.\n",
        "      van, da, fa, ea = ([run[0], run[-1]] for run in (run_vanilla, run_DA, run_FA, run_EA))\n",
        "      if equiv_init:\n",
        "        pairs = [\n",
        "            (\"V vs. P(V)\", van, project(van)),\n",
        "            (\"DA vs. P(DA)\", da, project(da)),\n",
        "            (\"FA vs. P(FA)\", fa, project(fa)),\n",
        "            (\"V vs. DA\", van, da),\n",
        "            (\"V vs. FA\", van, fa),\n",
        "            (\"V vs. EA\", van, ea),\n",
        "            (\"DA vs. FA\", da, fa),\n",
        "            (\"DA vs. EA\", da, ea),\n",
        "            (\"FA vs. EA\", fa, ea),\n",
        "        ]\n",
        "      else:\n",
        "        pairs = [\n",
        "            (\"V vs. G(V)\", van, orbit(van)),\n",
        "            (\"DA vs. G(DA)\", da, orbit(da)),\n",
        "            (\"FA vs. G(FA)\", fa, orbit(fa)),\n",
        "            (\"V vs. DA\", van, da),\n",
        "            (\"V vs. FA\", van, fa),\n",
        "            (\"DA vs. FA\", da, fa),\n",
        "            (\"P(FA) vs. EA\", project(fa), ea),\n",
        "        ]\n",
        "      comparisons_RMD = {}\n",
        "      comparisons_WD = {}\n",
        "      for label, first, second in pairs:\n",
        "        comparisons_RMD[label], comparisons_WD[label] = distances(first, second)\n",
        "      L_comparisons_RMD.append(comparisons_RMD)\n",
        "      L_comparisons_WD.append(comparisons_WD)\n",
        "    return L_comparisons_RMD, L_comparisons_WD"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R-iF4iEoLGPB"
      },
      "outputs": [],
      "source": [
        "def get_dict_from_RMD_dict(data):\n",
        "  # Initialize dictionaries for start and end dataframes\n",
        "  start_dict = {}\n",
        "  end_dict = {}\n",
        "\n",
        "  # Iterate through each dictionary in the list\n",
        "  for d in data:\n",
        "      for key, value in d.items():\n",
        "          if key not in start_dict:\n",
        "              start_dict[key] = []\n",
        "              end_dict[key] = []\n",
        "          start_dict[key].append(value[0])\n",
        "          end_dict[key].append(value[1])\n",
        "\n",
        "  return {\"start\":start_dict, \"end\":end_dict}"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Visualization"
      ],
      "metadata": {
        "id": "x5K6ZnflTNdd"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HJSsX9g2auyO"
      },
      "outputs": [],
      "source": [
        "# Calculate it\n",
        "data_pre = [pd.DataFrame(get_dict_from_RMD_dict(RMD_comparisons_first_last(out_dicts[j][\"particles\"], EQUIV_INIT, TEACHER_MODE)[0])[\"end\"]).assign(Particles=N_p) for j,N_p in enumerate(Ns)]\n",
        "pre_cdf = pd.concat(data_pre)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Comparison to the symmetrized/projected version:"
      ],
      "metadata": {
        "id": "ETDvPWYST9TX"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xHHU8hsfjUHY"
      },
      "outputs": [],
      "source": [
        "# Distance of each scheme to its projected (EQUIV_INIT) or symmetrized version; the\n",
        "# other comparison columns are dropped.\n",
        "if EQUIV_INIT:\n",
        "  dropped_columns = [\"V vs. EA\", \"DA vs. EA\", \"FA vs. EA\", \"V vs. DA\", \"V vs. FA\", \"DA vs. FA\"]\n",
        "  plot_title, pdf_name = \"Distance to Projected Version\", \"DistanceToEquivWithN.pdf\"\n",
        "else:\n",
        "  dropped_columns = [\"V vs. DA\", \"V vs. FA\", \"DA vs. FA\", \"P(FA) vs. EA\"]\n",
        "  plot_title, pdf_name = \"Distance to Symmetrized Version\", \"DistanceToWeakWithN.pdf\"\n",
        "cdf = pre_cdf.drop(columns=dropped_columns)\n",
        "# var_name must be a scalar: recent pandas raises ValueError for a list with flat columns.\n",
        "mdf = pd.melt(cdf, id_vars=['Particles'], var_name='Model')\n",
        "# Setting color palette to colorblind-friendly\n",
        "palette = sns.color_palette(\"colorblind\")\n",
        "fig, ax = plt.subplots()\n",
        "sns.boxplot(x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf, palette=palette[:3], ax=ax)\n",
        "ax.set_yscale(\"log\")\n",
        "ax.set(xlabel='Number of Particles ($N$)', ylabel=\"Relative Measure Distance\", title=plot_title)\n",
        "fig.tight_layout()\n",
        "fig.savefig(pdf_name, format=\"pdf\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Comparison between techniques:"
      ],
      "metadata": {
        "id": "AKcYdi-2UCbB"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gimSP7xujbfO"
      },
      "outputs": [],
      "source": [
        "# Pairwise distances between the training schemes.\n",
        "# Define the palette in this cell so it does not depend on state left by earlier cells.\n",
        "palette = sns.color_palette(\"colorblind\")\n",
        "if EQUIV_INIT:\n",
        "  dropped_columns = [\"V vs. EA\", \"DA vs. EA\", \"FA vs. EA\", \"DA vs. P(DA)\", \"FA vs. P(FA)\", \"V vs. P(V)\"]\n",
        "  colors = palette[:3]\n",
        "  plot_title = \"Comparison between vanilla, DA and FA schemes\"\n",
        "else:\n",
        "  dropped_columns = [\"V vs. G(V)\", \"DA vs. G(DA)\", \"FA vs. G(FA)\"]\n",
        "  colors = palette[3:7]  # four remaining comparisons, including P(FA) vs. EA\n",
        "  plot_title = \"Distance between SL techniques after training\"\n",
        "cdf = pre_cdf.drop(columns=dropped_columns)\n",
        "# var_name must be a scalar: recent pandas raises ValueError for a list with flat columns.\n",
        "mdf = pd.melt(cdf, id_vars=['Particles'], var_name='Model')\n",
        "fig, ax = plt.subplots()\n",
        "sns.boxplot(x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf, palette=colors, ax=ax)\n",
        "ax.set_yscale(\"log\")\n",
        "ax.set(xlabel='Number of Particles ($N$)', ylabel=\"Relative Measure Distance\", title=plot_title)\n",
        "fig.tight_layout()\n",
        "fig.savefig(\"DistanceBetweenTechniquesWithN.pdf\", format=\"pdf\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Comparison with EA:"
      ],
      "metadata": {
        "id": "56stG_fnUF5K"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "if EQUIV_INIT:\n",
        "  # Keep only the comparisons against the EA scheme.\n",
        "  cdf = pre_cdf.drop(columns=[\"V vs. DA\", \"DA vs. FA\", \"V vs. FA\", \"DA vs. P(DA)\", \"FA vs. P(FA)\", \"V vs. P(V)\"])\n",
        "  # var_name must be a scalar: recent pandas raises ValueError for a list with flat columns.\n",
        "  mdf = pd.melt(cdf, id_vars=['Particles'], var_name='Model')\n",
        "  palette = sns.color_palette(\"colorblind\")\n",
        "  fig, ax = plt.subplots()\n",
        "  sns.boxplot(x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf, palette=palette[:3], ax=ax)\n",
        "  ax.set_yscale(\"log\")\n",
        "  ax.set(xlabel='Number of Particles ($N$)', ylabel=\"Relative Measure Distance\", title=\"Comparison between EA and other techniques\")\n",
        "  fig.tight_layout()\n",
        "  fig.savefig(\"DistanceToEAWithN.pdf\", format=\"pdf\")\n",
        "  plt.show()\n",
        "else:\n",
        "  # Without EQUIV_INIT, EA only appears in the single \"P(FA) vs. EA\" comparison.\n",
        "  print(\"No Equivalent plot!\")"
      ],
      "metadata": {
        "id": "PIBRXc4DTrXY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Dual Plot:"
      ],
      "metadata": {
        "id": "fIcq45MpUI03"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kyI_6FwutnZG"
      },
      "outputs": [],
      "source": [
        "# Side-by-side figure: (left) distance of each scheme to its projected/symmetrized\n",
        "# version, (right) distances between the vanilla, DA and FA schemes.\n",
        "if EQUIV_INIT:\n",
        "  left_dropped = [\"V vs. EA\", \"DA vs. EA\", \"FA vs. EA\", \"V vs. DA\", \"V vs. FA\", \"DA vs. FA\"]\n",
        "  right_dropped = [\"V vs. EA\", \"DA vs. EA\", \"FA vs. EA\", \"DA vs. P(DA)\", \"FA vs. P(FA)\", \"V vs. P(V)\"]\n",
        "  left_title = \"Distance to Projected Version\"\n",
        "else:\n",
        "  left_dropped = [\"V vs. DA\", \"V vs. FA\", \"DA vs. FA\", \"P(FA) vs. EA\"]\n",
        "  right_dropped = [\"V vs. G(V)\", \"DA vs. G(DA)\", \"FA vs. G(FA)\", \"P(FA) vs. EA\"]\n",
        "  left_title = \"Distance to Symmetrized Version\"\n",
        "\n",
        "# var_name must be a scalar: recent pandas raises ValueError for a list with flat columns.\n",
        "cdf1 = pre_cdf.drop(columns=left_dropped)\n",
        "mdf1 = pd.melt(cdf1, id_vars=['Particles'], var_name='Model')\n",
        "cdf2 = pre_cdf.drop(columns=right_dropped)\n",
        "mdf2 = pd.melt(cdf2, id_vars=['Particles'], var_name='Model')\n",
        "\n",
        "# Setting color palette to colorblind-friendly\n",
        "palette = sns.color_palette(\"colorblind\")\n",
        "fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(12, 5))\n",
        "\n",
        "sns.boxplot(ax=ax1, x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf1, palette=palette[:3])\n",
        "ax1.set_yscale(\"log\")\n",
        "ax1.set(xlabel='Number of Particles ($N$)', ylabel=\"Relative Measure Distance (RMD)\", title=left_title)\n",
        "\n",
        "sns.boxplot(ax=ax2, x=\"Particles\", y=\"value\", hue=\"Model\", data=mdf2, palette=palette[3:6])\n",
        "ax2.set_yscale(\"log\")\n",
        "ax2.set(xlabel='Number of Particles ($N$)', title=\"Comparison between vanilla, DA and FA schemes\")\n",
        "\n",
        "fig.tight_layout()\n",
        "fig.savefig(\"CombinedPlots.pdf\", format=\"pdf\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Saving Particle Plots (optional)"
      ],
      "metadata": {
        "id": "6zgMw2_nURxu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import plotly.graph_objects as go\n",
        "import plotly.io as pio\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def create_plot_list(teacher_coords=None, particle_coords=None, equivariant_coords=None, c_lims=(-2, 2), t_style=(4, \"diamond\"), p_style=(1.5, \"circle\"), eq_style=(0.35, 0.8), with_lineplots=True, showlegend=True, with_colorbar=True):\n",
        "    \"\"\"Build the plotly traces of a 4D particle plot (three spatial axes plus colour).\n",
        "\n",
        "    Each coordinate argument is either None (the trace is omitted) or four\n",
        "    coordinate arrays (z1, z2, z3, z4): z1-z3 give the 3D position and z4 is\n",
        "    mapped onto the colour scale shared by all traces.\n",
        "\n",
        "    Args:\n",
        "        teacher_coords: teacher-particle coordinates, drawn as markers in ``t_style``.\n",
        "        particle_coords: student-particle coordinates, drawn as markers in ``p_style``.\n",
        "        equivariant_coords: vertices of the equivariant-space mesh (Mesh3d).\n",
        "        c_lims: (cmin, cmax) limits of the colour scale.\n",
        "        t_style: (marker size, marker symbol) of the teacher particles.\n",
        "        p_style: (marker size, marker symbol) of the student particles.\n",
        "        eq_style: (mesh opacity, length of the dashed teacher lines).\n",
        "        with_lineplots: draw a dashed line from the origin towards each teacher particle.\n",
        "        showlegend: show legend entries for the teacher and student traces.\n",
        "        with_colorbar: show the colour bar.\n",
        "\n",
        "    Returns:\n",
        "        List of plotly traces in drawing order: mesh, teacher, students, lines\n",
        "        (each only when its input is given / enabled).\n",
        "    \"\"\"\n",
        "    cmin, cmax = c_lims\n",
        "    plots = []\n",
        "    color_scale = \"Viridis\"  # one scale for every trace so z4 colours stay comparable\n",
        "    showscale = with_colorbar\n",
        "\n",
        "    if equivariant_coords is not None:\n",
        "        ex1, ex2, ex3, ex4 = equivariant_coords\n",
        "        equivariant_space_mesh = go.Mesh3d(\n",
        "            x=ex1, y=ex2, z=ex3, opacity=eq_style[0], intensity=ex4,\n",
        "            colorscale=color_scale, cmax=cmax, cmin=cmin,\n",
        "            name='y', showscale=showscale, hoverinfo='none', showlegend=False)\n",
        "        plots.append(equivariant_space_mesh)\n",
        "\n",
        "    if teacher_coords is not None:\n",
        "        gtx1, gtx2, gtx3, gtx4 = teacher_coords\n",
        "        scatter_gt = go.Scatter3d(\n",
        "            name='Teacher Particles', x=gtx1, y=gtx2, z=gtx3, mode='markers',\n",
        "            marker=dict(size=t_style[0], color=gtx4, cmax=cmax, cmin=cmin,\n",
        "                        colorscale=color_scale, showscale=showscale, symbol=t_style[1]),\n",
        "            hovertemplate='<b>Teacher Particle</b><extra></extra>',\n",
        "            showlegend=showlegend)\n",
        "        plots.append(scatter_gt)\n",
        "\n",
        "    if particle_coords is not None:\n",
        "        x1, x2, x3, x4 = particle_coords\n",
        "        scatter_particles = go.Scatter3d(\n",
        "            name='Student Particles', x=x1, y=x2, z=x3, mode='markers',\n",
        "            marker=dict(size=p_style[0], color=x4, cmax=cmax, cmin=cmin,\n",
        "                        colorscale=color_scale, showscale=showscale, symbol=p_style[1]),\n",
        "            hoverinfo='none', showlegend=showlegend)\n",
        "        plots.append(scatter_particles)\n",
        "\n",
        "    if with_lineplots and teacher_coords is not None:\n",
        "        # Unit directions of the teacher particles, normalised over all four\n",
        "        # coordinates. A particle at the origin keeps a zero direction instead\n",
        "        # of turning its whole line into NaNs through a division by zero.\n",
        "        directions = jnp.asarray(teacher_coords).T  # one row per teacher particle\n",
        "        norms = jnp.linalg.norm(directions, axis=1, keepdims=True)\n",
        "        unit_directions = directions / jnp.where(norms == 0, 1.0, norms)\n",
        "        # lines[i, k] = t_k * unit_directions[i] for 50 evenly spaced t_k in [0, eq_style[1]]\n",
        "        lines = jnp.einsum('a,bc->bac', jnp.linspace(0, eq_style[1], 50), unit_directions)\n",
        "        for line in lines:\n",
        "            lx1, lx2, lx3, _ = line.T  # the colour coordinate is not used for lines\n",
        "            plots.append(go.Scatter3d(\n",
        "                x=lx1, y=lx2, z=lx3, marker=dict(size=0.01),\n",
        "                line=dict(color='black', width=0.5, dash='dash'),\n",
        "                showlegend=False, hoverinfo=\"none\"))\n",
        "\n",
        "    return plots\n",
        "\n",
        "def get_transpose(x):\n",
        "    \"\"\"Return ``x.T``, passing ``None`` through unchanged.\"\"\"\n",
        "    if x is None:\n",
        "        return None\n",
        "    return x.T\n",
        "\n",
        "def _particle_plot_layout(title, show_title):\n",
        "    \"\"\"Return the plotly layout used by ``particle_plot_save``.\n",
        "\n",
        "    Hiding the title also switches to a shorter figure and a different camera angle.\n",
        "    \"\"\"\n",
        "    return dict(\n",
        "        title=dict(text=title if show_title else \"\", x=0.5, font=dict(size=20), automargin=False),\n",
        "        scene=dict(\n",
        "            xaxis_title=\"z1\",\n",
        "            yaxis_title=\"z2\",\n",
        "            zaxis_title=\"z3\",\n",
        "            xaxis_showspikes=False,\n",
        "            yaxis_showspikes=False,\n",
        "            zaxis_showspikes=False,\n",
        "            camera=dict(\n",
        "                up=dict(x=0, y=0, z=1),\n",
        "                center=dict(x=0, y=0, z=0),\n",
        "                eye=dict(x=1.25, y=-1.25, z=1.25) if show_title else dict(x=1.2, y=1.2, z=1.2)\n",
        "            ),\n",
        "            xaxis=dict(showgrid=True, gridcolor='lightgray'),\n",
        "            yaxis=dict(showgrid=True, gridcolor='lightgray'),\n",
        "            zaxis=dict(showgrid=True, gridcolor='lightgray')\n",
        "        ),\n",
        "        width=800,\n",
        "        height=760 if show_title else 720,\n",
        "        autosize=False,\n",
        "        legend=dict(yanchor=\"bottom\", y=0.99, xanchor=\"left\", x=0.01),\n",
        "    )\n",
        "\n",
        "\n",
        "def _particle_plot_filename(title, show_title):\n",
        "    \"\"\"Return the PDF file name ``<title without spaces>p1.pdf`` (``p2`` when untitled).\n",
        "\n",
        "    Path separators in the title are replaced by \"_\" so the title cannot point\n",
        "    the output into another (possibly non-existent) directory.\n",
        "    \"\"\"\n",
        "    stem = title.replace(\" \", \"\").replace(\"/\", \"_\").replace(\"\\\\\", \"_\")\n",
        "    return stem + (\"p1\" if show_title else \"p2\") + \".pdf\"\n",
        "\n",
        "\n",
        "def particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"4D plot\", c_lims=(-2, 2), t_style=(4, \"diamond\"), p_style=(1.5, \"circle\"), eq_style=(0.35, 0.8), with_lineplots=True, showlegend=True, with_colorbar=True, show_title=True):\n",
        "    \"\"\"Display a 4D particle plot and save it as a PDF (path relative to the working directory).\n",
        "\n",
        "    The three array arguments are transposed before plotting (``None`` skips that\n",
        "    trace), so presumably they hold one particle / mesh vertex per row with four\n",
        "    coordinates each -- TODO confirm against callers.\n",
        "\n",
        "    Args:\n",
        "        particle_positions_teacher: teacher particle positions, or None.\n",
        "        particle_positions_model: student particle positions, or None.\n",
        "        equivariant_space_limits: vertices of the equivariant-space mesh, or None.\n",
        "        title: figure title; also determines the output file name.\n",
        "        show_title: show the title (file suffix \"p1\") or hide it (file suffix \"p2\").\n",
        "        c_lims, t_style, p_style, eq_style, with_lineplots, showlegend, with_colorbar:\n",
        "            forwarded unchanged to ``create_plot_list``.\n",
        "    \"\"\"\n",
        "    plots = create_plot_list(\n",
        "        get_transpose(particle_positions_teacher),\n",
        "        get_transpose(particle_positions_model),\n",
        "        get_transpose(equivariant_space_limits),\n",
        "        c_lims=c_lims, t_style=t_style, p_style=p_style, eq_style=eq_style,\n",
        "        with_lineplots=with_lineplots, showlegend=showlegend, with_colorbar=with_colorbar)\n",
        "\n",
        "    fig = go.Figure()\n",
        "    for plot in plots:\n",
        "        fig.add_trace(plot)\n",
        "    fig.update_layout(_particle_plot_layout(title, show_title))\n",
        "\n",
        "    if with_colorbar:\n",
        "        # Label the colour bar with the coordinate it encodes.\n",
        "        fig.add_annotation(text=\"z4\", x=1.07, y=1.02, showarrow=False,\n",
        "                           xref=\"paper\", yref=\"paper\", font=dict(size=12))\n",
        "\n",
        "    fig.show()\n",
        "    pio.write_image(fig, _particle_plot_filename(title, show_title), format='pdf', validate=True)\n",
        "\n",
        "# Example usage:\n",
        "# particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"Plot 1\", with_colorbar=True, show_title=True, showlegend=True)\n",
        "# particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"Plot 2\", with_colorbar=False, show_title=False, showlegend=False)\n"
      ],
      "metadata": {
        "id": "RTPoYEi2fbiJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Index of the repetition whose particles are plotted below.\n",
        "repetition = 0\n",
        "particle_positions_teacher = teacher_network.get_particles()\n",
        "# Stored student particles of that repetition, one entry per scheme.\n",
        "particles_vanilla, particles_DA, particles_FA, particles_EA = (\n",
        "    out_dict[\"particles\"][scheme][repetition]\n",
        "    for scheme in (\"vanilla\", \"DA\", \"FA\", \"EA\")\n",
        ")"
      ],
      "metadata": {
        "id": "VjZ5PAmFfwCd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Change the keyword arguments below (colour limits, marker and mesh styles, dashed lines, colour bar, title and legend) to produce the different panels of our particle plots:"
      ],
      "metadata": {
        "id": "9OnQPnLCUoFs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "particle_plot_save(\n",
        "    particle_positions_teacher,\n",
        "    particles_EA[-1],  # last stored entry of the EA particles (presumably the final snapshot)\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (EA)\",\n",
        "    show_title=True,\n",
        "    showlegend=True,\n",
        "    with_colorbar=True,\n",
        "    c_lims=(-0.65, 0.65),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "EzSPl0wqFx_E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "QkE4ZhNXVvE2"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "6zgMw2_nURxu"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}