{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gEGVhOh_FCaN"
      },
      "source": [
        "In case the notebook is run on Google Colab, the following dependencies need to be installed:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VVybMbgjFCNU"
      },
      "outputs": [],
      "source": [
        "#%pip --quiet install objax\n",
        "#%pip --quiet install git+https://github.com/mfinzi/equivariant-MLP.git\n",
        "#%pip --quiet install pot\n",
        "#%pip install -U kaleido"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Also, access to a copy of the repository on Google Drive might be required (mount the drive below):"
      ],
      "metadata": {
        "id": "V1NF0oeEWaJN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#from google.colab import drive\n",
        "#drive.mount('/content/drive')\n",
        "\n",
        "#from google.colab import output\n",
        "#output.enable_custom_widget_manager()"
      ],
      "metadata": {
        "id": "VmI3A1WiWghw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wChyLedgFI2T"
      },
      "source": [
        "# Importing necessary packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xbR2TAy_E6f6"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "import os\n",
        "from functools import reduce\n",
        "\n",
        "# Repository root (the directory containing `src/`), relative to the kernel's working\n",
        "# directory (by default the notebook's folder). An absolute path such as \"/../\" would\n",
        "# resolve to the filesystem root \"/\".\n",
        "# On Colab use e.g. 'drive/MyDrive/invariant-mean-field-neural-networks' instead.\n",
        "path_to_repo = \"../\"\n",
        "\n",
        "project_path = os.path.abspath(path_to_repo)\n",
        "if project_path not in sys.path:  # keep sys.path free of duplicates when the cell is re-run\n",
        "    sys.path.insert(1, project_path)\n",
        "\n",
        "import jax.numpy as jnp\n",
        "from jax import vmap\n",
        "\n",
        "import objax\n",
        "from emlp.reps import V,sparsify_basis\n",
        "from emlp.groups import SO,O,S,Z\n",
        "\n",
        "import src.modules\n",
        "from src.visualization import vis, particle_plot, plot_losses, particle_plot_animation\n",
        "from src.theory_utils import equivariance_err, Wasserstein_Distance, rel_measure_distance\n",
        "from src.modules import ShallowMLPNoLinearOut, FA_Model, SGD\n",
        "from src.train_eval_utils import random_compare, training_loop\n",
        "from src.utils import ExpData, CumData"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HCc8j1pwG5OW"
      },
      "source": [
        "# Definition of key group elements"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ichY9Fs_E9dx"
      },
      "outputs": [],
      "source": [
        "# Symmetric group on two elements\n",
        "G = S(2)\n",
        "# Input and output representations: both are the base representation V of G\n",
        "repin = V(G)\n",
        "repout = V(G)\n",
        "# Representation on the linear maps repin -> repout (the model parameters)\n",
        "rep_params = repin >> repout\n",
        "# Projector onto, and basis of, the subspace of G-equivariant parameters\n",
        "P_params = rep_params.equivariant_projector()\n",
        "base_params = rep_params.equivariant_basis()\n",
        "# Dense matrix of the first discrete generator of G acting on the parameter space\n",
        "G_generator = rep_params.rho_dense(G.discrete_generators[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can visualize the basis of this space:"
      ],
      "metadata": {
        "id": "ADRGIeD-LzrY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Shape of the equivariant basis: one column per basis element of the equivariant subspace\n",
        "print(f\"Equivariant basis has shape {base_params.shape}\")\n",
        "vis(repin,repout)"
      ],
      "metadata": {
        "id": "w9s1Xx9fL2fj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also define batched (vectorized) versions of the corresponding maps:"
      ],
      "metadata": {
        "id": "VcI7ZUh9L36Q"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1lL4i0RhFxEw"
      },
      "outputs": [],
      "source": [
        "# Batched versions of the parameter-space maps: vmap applies each matrix product\n",
        "# to every row (particle) of the input array.\n",
        "def _project_to_equivariant(x):\n",
        "    return P_params @ x\n",
        "\n",
        "def _basis_to_params(x):\n",
        "    return base_params @ x\n",
        "\n",
        "def _params_to_basis(x):\n",
        "    return base_params.T @ x\n",
        "\n",
        "vP_params = vmap(_project_to_equivariant)\n",
        "vbase_params = vmap(_basis_to_params)\n",
        "vbase_paramsT = vmap(_params_to_basis)\n",
        "\n",
        "# Orbit maps. These only cover S(2), where the orbit of a point x is {x, G_generator x}.\n",
        "def _apply_generator(x):\n",
        "    return jnp.dot(G_generator, x)\n",
        "\n",
        "vG_generator = vmap(_apply_generator)\n",
        "\n",
        "def vorbit(x):\n",
        "    \"\"\"Stack the particles x on top of their images under G_generator.\n",
        "\n",
        "    The result has twice as many rows as x and contains the complete orbit of each particle.\n",
        "    \"\"\"\n",
        "    return jnp.vstack([x, vG_generator(x)])\n",
        "\n",
        "scale_factor = 0.5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MCi6gw4hGFAr"
      },
      "source": [
        "We can identify the space of $2\\times 2$ matrices with $\\mathbb{R}^{2\\times 2} \\cong \\mathbb{R}^4$, which we visualize in $\\mathbb{R}^3$ with the fourth coordinate encoded as \"color\". In this space, the equivariant parameters $\\mathcal{E}^G$ form a subspace of dimension 2."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KRhA4N9RGA44"
      },
      "outputs": [],
      "source": [
        "# Corners of a square in the 2-d equivariant coordinates, mapped into parameter space\n",
        "square_corners = jnp.array([[1, 1], [1, -1], [-1, 1], [-1, -1]])\n",
        "equivariant_space_points = vbase_params((1.6 * scale_factor) * square_corners)\n",
        "particle_plot(None, None, equivariant_space_points, double=False, title=\"Equivariant Space\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c37wQjrlGRa7"
      },
      "source": [
        "We can also observe the \"orbits\" of points in this space, as well as their \"projected\" versions:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Icjd7FdGKzF"
      },
      "outputs": [],
      "source": [
        "example_points = scale_factor*jnp.array([[-0.5,0,1,-0.5], [0.5,0,0.5,1], [-0.5,0,-0.5,1]])\n",
        "# Each example point with its full orbit, followed by its projection onto the equivariant subspace\n",
        "orbit_points = vorbit(example_points)\n",
        "projected_points = vP_params(example_points)\n",
        "point_array = jnp.vstack([orbit_points, projected_points])\n",
        "particle_plot(None, point_array, equivariant_space_points,\n",
        "              c_lims=(-2*scale_factor, 2*scale_factor), double=False, with_lineplots=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YNNYFvjV3e-l"
      },
      "source": [
        "# Reading Data from File"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We add a function to create the fixed teacher model used for the experiments of the main paper:"
      ],
      "metadata": {
        "id": "ZBYfe8lsO1fP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def create_model(N_t, activation_fn, mode=\"free\", fixed_init = None):\n",
        "  \"\"\"Build a ShallowMLPNoLinearOut model and set its initial particles according to `mode`.\n",
        "\n",
        "  Args:\n",
        "    N_t: number of base particles.\n",
        "    activation_fn: activation function passed to the model.\n",
        "    mode: \"strong\"/\"strong-equivariant\": equivariant model with N_t particles;\n",
        "      \"weak\"/\"weak-equivariant\": unconstrained model with 2*N_t particles, namely\n",
        "      N_t base particles together with their images under the group (see `vorbit`);\n",
        "      anything else (e.g. \"free\"): unconstrained model with N_t particles.\n",
        "    fixed_init: optional N_t base particles used instead of the model's own\n",
        "      initialization (in weak mode they are completed to their orbits).\n",
        "\n",
        "  Returns:\n",
        "    The model, after `set_particles` has been called with the chosen particles.\n",
        "  \"\"\"\n",
        "  if mode in [\"strong\", \"strong-equivariant\"]:\n",
        "    n_particles, equivariant, complete_orbits = N_t, True, False\n",
        "  elif mode in [\"weak\", \"weak-equivariant\"]:\n",
        "    n_particles, equivariant, complete_orbits = 2*N_t, False, True\n",
        "  else:\n",
        "    n_particles, equivariant, complete_orbits = N_t, False, False\n",
        "\n",
        "  model = ShallowMLPNoLinearOut(N=n_particles, rep_in=repin, rep_out=repout, activation=activation_fn, alternative=True, use_bias=False, equivariant=equivariant)\n",
        "\n",
        "  if fixed_init is not None:\n",
        "    base_particles = fixed_init\n",
        "  elif complete_orbits:\n",
        "    # vorbit doubles the particle count, so only N_t of the model's 2*N_t particles are\n",
        "    # used as base particles; orbit-completing all of them would hand 4*N_t particles\n",
        "    # to a model of size 2*N_t.\n",
        "    # NOTE(review): assumes get_particles() stacks particles along axis 0, as vorbit does.\n",
        "    base_particles = model.get_particles()[:N_t]\n",
        "  else:\n",
        "    base_particles = model.get_particles()\n",
        "\n",
        "  init_particles = vorbit(base_particles) if complete_orbits else base_particles\n",
        "  model.set_particles(init_particles)\n",
        "  return model"
      ],
      "metadata": {
        "id": "EF3qU-_bM8yc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We also use a function to produce all the default hyperparameters:"
      ],
      "metadata": {
        "id": "MFqjLUhwO9_O"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iFPaoAxq3nNK"
      },
      "outputs": [],
      "source": [
        "def get_default_train_params(N_p, TEACHER_MODE):\n",
        "  \"\"\"Return the default hyperparameters of the Heuristic together with the fixed teacher.\n",
        "\n",
        "  Args:\n",
        "    N_p: presumably the number of student particles. Not used inside this function;\n",
        "      kept so that existing callers keep working.\n",
        "    TEACHER_MODE: mode passed to `create_model` for the teacher (\"strong\", \"weak\",\n",
        "      \"free\", or the aliases \"strong-equivariant\"/\"weak-equivariant\").\n",
        "\n",
        "  Returns:\n",
        "    (train_params, teacher_network): the hyperparameter dict and the teacher model.\n",
        "  \"\"\"\n",
        "  # These are the default parameters for the Heuristic!\n",
        "  N_t = 5  # number of teacher base particles (must match the rows of the arrays below)\n",
        "  N_reps = 10\n",
        "  train_params = dict(\n",
        "      BATCH_SIZE = 20,\n",
        "      LR = 20,\n",
        "      T_EPOCHS = 20,\n",
        "      TAU = 1e-4, # This is the \"norm\" regularization parameter\n",
        "      BETA = 1e-6, # This is the \"noise\" parameter\n",
        "      DATA_STD = 4,\n",
        "      GRANULAR = 5,\n",
        "      EQUIV_INIT=False,\n",
        "      TEACHER_MODE=TEACHER_MODE,\n",
        "      N_reps = N_reps)\n",
        "  train_params[\"P_MAP\"] = None\n",
        "  # Fixed teacher particles: 2 coordinates per particle for the strong (equivariant)\n",
        "  # teacher, 4 otherwise. Accept the same mode aliases as `create_model`, so that\n",
        "  # \"strong-equivariant\" does not receive 4-coordinate particles.\n",
        "  if TEACHER_MODE in [\"strong\", \"strong-equivariant\"]:\n",
        "    fixed_teacher_particles = (scale_factor)*jnp.array([[1,0],[0.5,1],[-0.5,0.3],[0,-1], [0.7, 0.7]])\n",
        "  else:\n",
        "    fixed_teacher_particles = (scale_factor)*jnp.array([[-1,0,0,0.5],[0.5,1,0,1],[-0.5,0.3,1,0],[0,-1,-0.5,1], [0.7, -0.7,0.5,0.7]])\n",
        "  activation_fn = objax.functional.sigmoid\n",
        "  teacher_network = create_model(N_t, activation_fn, mode=TEACHER_MODE, fixed_init=fixed_teacher_particles)\n",
        "  train_params[\"TEACHER_FIXED\"] = fixed_teacher_particles is not None\n",
        "  return train_params, teacher_network"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "For instance, we read the results for $N=1000$ particles:"
      ],
      "metadata": {
        "id": "PLjU4bEUPDBP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Which saved run to read: number of particles N, teacher type and results folder\n",
        "N_p = 1000\n",
        "TEACHER_MODE = \"weak\"  # one of \"weak\", \"strong\", \"free\"\n",
        "path_read = \"\"  # folder prefix of the saved results, e.g. \"drive/MyDrive/Results/\"\n",
        "train_params, teacher_network = get_default_train_params(N_p, TEACHER_MODE)"
      ],
      "metadata": {
        "id": "xwM-zL9nMr4N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Metrics to read from the saved results (the cells below use \"comparisons_RMD\" and \"particles\")\n",
        "tracked_keys = [\"eq_errors\", \"dist_to_teacher\", \"comparisons_RMD\", \"comparisons_WD\",\n",
        "                \"train_losses\", \"particles\"]\n",
        "out_dict = ExpData(tracked_keys, N_p, train_params)"
      ],
      "metadata": {
        "id": "J2GxVSzPMPSA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the results stored under the name \"Heuristic\" (prefixed by path_read)\n",
        "results_name = path_read + \"Heuristic\"\n",
        "out_dict.load(results_name)"
      ],
      "metadata": {
        "id": "4lPGmSHkNKmR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Visualizing the steps of the Heuristic via RMDs"
      ],
      "metadata": {
        "id": "j3BTg2NmPWAw"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We consider the following useful function:"
      ],
      "metadata": {
        "id": "iF-yMWoUgZ-6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "def get_df_from_dict(out_dict, n_reps=None):\n",
        "  \"\"\"Collect the final RMD values of every Heuristic phase and repetition into a DataFrame.\n",
        "\n",
        "  For each phase key (\"fase0_s\", \"fase0_e\", ..., \"fase2_e\") and repetition j, the entry\n",
        "  out_dict[\"comparisons_RMD\"][phase][j] is converted to a DataFrame and its last row (the\n",
        "  last recorded timestep) provides the \"V vs. P_0(V)\" and \"V vs. P_E(V)\" values.\n",
        "\n",
        "  Args:\n",
        "    out_dict: loaded results (e.g. an ExpData) indexable as out_dict[\"comparisons_RMD\"][phase][j].\n",
        "    n_reps: number of repetitions to read per phase; None (default) reads every stored\n",
        "      repetition of that phase instead of assuming a fixed count.\n",
        "\n",
        "  Returns:\n",
        "    DataFrame with columns \"Phase\", \"Repetition\", \"V vs. P_0(V)\", \"V vs. P_E(V)\", where\n",
        "    the phase keys are replaced by readable labels such as \"Start of Step 0\".\n",
        "  \"\"\"\n",
        "  # Phase keys in plotting order, mapped to their display labels\n",
        "  phase_labels = {\n",
        "      \"fase0_s\": \"Start of Step 0\", \"fase0_e\": \"End of Step 0\",\n",
        "      \"fase1_s\": \"Start of Step 1\", \"fase1_e\": \"End of Step 1\",\n",
        "      \"fase2_s\": \"Start of Step 2\", \"fase2_e\": \"End of Step 2\",\n",
        "  }\n",
        "  metrics = [\"V vs. P_0(V)\", \"V vs. P_E(V)\"]\n",
        "\n",
        "  records = []\n",
        "  for phase, label in phase_labels.items():\n",
        "      runs = out_dict[\"comparisons_RMD\"][phase]\n",
        "      n_runs = len(runs) if n_reps is None else n_reps\n",
        "      for j in range(n_runs):\n",
        "          last_row = pd.DataFrame(runs[j]).iloc[-1]  # values at the last recorded timestep\n",
        "          record = {\"Phase\": label, \"Repetition\": j}\n",
        "          record.update({metric: last_row[metric] for metric in metrics})\n",
        "          records.append(record)\n",
        "  return pd.DataFrame(records, columns=[\"Phase\", \"Repetition\", *metrics])"
      ],
      "metadata": {
        "id": "dF_TBF5ia3BS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We visualize the steps of the heuristic, together with the relevant RMDs:"
      ],
      "metadata": {
        "id": "cHXenfNtgY56"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "data_df = get_df_from_dict(out_dict)\n",
        "\n",
        "# Long format (one row per phase, repetition and comparison) for seaborn's hue grouping\n",
        "melted_df = pd.melt(data_df, id_vars=[\"Phase\", \"Repetition\"],\n",
        "                    value_vars=[\"V vs. P_0(V)\", \"V vs. P_E(V)\"],\n",
        "                    var_name=\"Comparison\", value_name=\"Value\")\n",
        "\n",
        "palette = sns.color_palette(\"colorblind\")\n",
        "RMD_REFERENCE_LEVEL = 1e-3  # height of the red horizontal reference line\n",
        "\n",
        "# Create the boxplot on an explicit figure/axes pair\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "sns.boxplot(x=\"Phase\", y=\"Value\", hue=\"Comparison\", data=melted_df, palette=palette[:2], ax=ax)\n",
        "ax.axhline(y=RMD_REFERENCE_LEVEL, color='r', linestyle='-')\n",
        "ax.set_yscale(\"log\")\n",
        "ax.set(xlabel='Heuristic Step', ylabel=\"Relative Measure Distance (RMD)\", title=\"Distance to Original Subspace when applying the Heuristic\")\n",
        "ax.legend(title=\"RMD Comparison\")\n",
        "ax.tick_params(axis=\"x\", labelrotation=45)\n",
        "fig.tight_layout()\n",
        "fig.savefig(\"HeuristicSteps.pdf\", format=\"pdf\")\n",
        "plt.show()\n"
      ],
      "metadata": {
        "id": "9hfvsk9ma5us"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Particle Visualization on each step of the Heuristic"
      ],
      "metadata": {
        "id": "w7yCVSqLgWKF"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Rd_uefnOULE"
      },
      "outputs": [],
      "source": [
        "# Trajectories of the student particles of one repetition, one per step of the Heuristic\n",
        "repetition = 0\n",
        "particle_positions_teacher = teacher_network.get_particles()\n",
        "stored_particles = out_dict[\"particles\"]\n",
        "particles0, particles1, particles2 = (stored_particles[step][repetition] for step in (\"fase0\", \"fase1\", \"fase2\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can visualize the results at the end of step 0:"
      ],
      "metadata": {
        "id": "fmRAqBdLbxZ7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Student particles at the last recorded iterate of step 0, against the teacher\n",
        "particle_plot(\n",
        "    teacher_network.get_particles(),\n",
        "    particles0[-1],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (end of step 0)\",\n",
        "    c_lims=(-0.55, 0.55),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "Y52MXENSa28C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "At the start of step 1:"
      ],
      "metadata": {
        "id": "BoNDEkW5cA1y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Student particles at the first recorded iterate of step 1, against the teacher\n",
        "particle_plot(\n",
        "    teacher_network.get_particles(),\n",
        "    particles1[0],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (start of step 1)\",\n",
        "    c_lims=(-0.55, 0.55),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "S-W11Wk9cAeM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "End of Step 1:"
      ],
      "metadata": {
        "id": "Az7kFqmjcWn8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Student particles at the last recorded iterate of step 1, against the teacher\n",
        "particle_plot(\n",
        "    teacher_network.get_particles(),\n",
        "    particles1[-1],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (end of step 1)\",\n",
        "    c_lims=(-0.55, 0.55),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "ci68v6bTcJpK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Start of Step 2:"
      ],
      "metadata": {
        "id": "SoHb2UL8cZAa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Student particles at the first recorded iterate of step 2, against the teacher\n",
        "particle_plot(\n",
        "    teacher_network.get_particles(),\n",
        "    particles2[0],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (start of step 2)\",\n",
        "    c_lims=(-0.55, 0.55),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "_RTCdyghcaht"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "End of Step 2:"
      ],
      "metadata": {
        "id": "ARf4217Lcaxq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Student particles at the last recorded iterate of step 2, against the teacher\n",
        "particle_plot(\n",
        "    teacher_network.get_particles(),\n",
        "    particles2[-1],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (end of step 2)\",\n",
        "    c_lims=(-0.55, 0.55),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        ")"
      ],
      "metadata": {
        "id": "TVJLWxjFccoC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Saving Particle Plots (optional)"
      ],
      "metadata": {
        "id": "6zgMw2_nURxu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import plotly.graph_objects as go\n",
        "import plotly.io as pio\n",
        "import jax.numpy as jnp\n",
        "\n",
        "def create_plot_list(teacher_coords=None, particle_coords=None, equivariant_coords=None, c_lims=(-2, 2), t_style=(4, \"diamond\"), p_style=(1.5, \"circle\"), eq_style=(0.35, 0.8), with_lineplots=True, showlegend=True, with_colorbar=True):\n",
        "    \"\"\"Build the plotly traces of a 4-D particle plot (three spatial axes plus colour).\n",
        "\n",
        "    Every coordinate argument is given column-wise, i.e. it unpacks into four sequences\n",
        "    (z1, z2, z3, z4); z4 is rendered as colour on the common range c_lims = (cmin, cmax).\n",
        "\n",
        "    Args:\n",
        "        teacher_coords: teacher particles, markers with (size, symbol) = t_style.\n",
        "        particle_coords: student particles, markers with (size, symbol) = p_style.\n",
        "        equivariant_coords: vertices of the equivariant subspace, drawn as a mesh\n",
        "            with opacity eq_style[0].\n",
        "        with_lineplots: if True and teacher_coords is given, add a dashed segment of\n",
        "            length eq_style[1] from the origin along each teacher particle's direction.\n",
        "        showlegend: whether the particle traces appear in the legend.\n",
        "        with_colorbar: whether the colour scales are shown.\n",
        "\n",
        "    Returns:\n",
        "        List of traces, ordered: mesh, teacher markers, student markers, segments.\n",
        "    \"\"\"\n",
        "    cmin, cmax = c_lims\n",
        "    # Colour settings shared by every coloured trace\n",
        "    colour_kwargs = dict(colorscale=\"Viridis\", cmin=cmin, cmax=cmax, showscale=with_colorbar)\n",
        "    traces = []\n",
        "\n",
        "    if equivariant_coords is not None:\n",
        "        ex1, ex2, ex3, ex4 = equivariant_coords\n",
        "        traces.append(go.Mesh3d(\n",
        "            x=ex1, y=ex2, z=ex3, intensity=ex4, opacity=eq_style[0],\n",
        "            name='y', hoverinfo='none', showlegend=False, **colour_kwargs))\n",
        "\n",
        "    if teacher_coords is not None:\n",
        "        tx1, tx2, tx3, tx4 = teacher_coords\n",
        "        traces.append(go.Scatter3d(\n",
        "            name='Teacher Particles', x=tx1, y=tx2, z=tx3, mode='markers',\n",
        "            marker=dict(size=t_style[0], symbol=t_style[1], color=tx4, **colour_kwargs),\n",
        "            hovertemplate='<b>Teacher Particle</b><extra></extra>',\n",
        "            showlegend=showlegend))\n",
        "\n",
        "    if particle_coords is not None:\n",
        "        sx1, sx2, sx3, sx4 = particle_coords\n",
        "        traces.append(go.Scatter3d(\n",
        "            name='Student Particles', x=sx1, y=sx2, z=sx3, mode='markers',\n",
        "            marker=dict(size=p_style[0], symbol=p_style[1], color=sx4, **colour_kwargs),\n",
        "            hoverinfo='none', showlegend=showlegend))\n",
        "\n",
        "    if with_lineplots and teacher_coords is not None:\n",
        "        # teacher_coords is (4, N): dividing by the per-particle norms normalises each column\n",
        "        directions = (teacher_coords / jnp.linalg.norm(teacher_coords.T, axis=1)).T\n",
        "        radii = jnp.linspace(0, eq_style[1], 50)\n",
        "        # segments[i, k] = radii[k] * directions[i], shape (N, 50, 4)\n",
        "        segments = jnp.einsum('a,bc->bac', radii, directions)\n",
        "        for segment in segments:\n",
        "            lx1, lx2, lx3, lx4 = segment.T\n",
        "            traces.append(go.Scatter3d(\n",
        "                x=lx1, y=lx2, z=lx3, marker=dict(size=0.01),\n",
        "                line=dict(color='black', width=0.5, dash='dash'),\n",
        "                showlegend=False, hoverinfo=\"none\"))\n",
        "\n",
        "    return traces\n",
        "\n",
        "def get_transpose(x):\n",
        "    \"\"\"Return x.T, or None when x is None (so optional coordinate arrays pass through).\"\"\"\n",
        "    if x is None:\n",
        "        return None\n",
        "    return x.T\n",
        "\n",
        "def particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"4D plot\", c_lims=(-2, 2), t_style=(4, \"diamond\"), p_style=(1.5, \"circle\"), eq_style=(0.35, 0.8), with_lineplots=True, showlegend=True, with_colorbar=True, show_title=True):\n",
        "    \"\"\"Show a 4-D particle plot and save it as a PDF in the working directory.\n",
        "\n",
        "    The three position arrays hold one point per row with four coordinates (z1, z2, z3 on\n",
        "    the axes, z4 as colour); any of them may be None to omit that layer. The remaining\n",
        "    arguments are forwarded to `create_plot_list`; `show_title` also switches the camera\n",
        "    position and the figure height.\n",
        "\n",
        "    The file name is `title` without spaces plus \"p1\" (title shown) or \"p2\" (title hidden),\n",
        "    e.g. \"StudentParticles(endofstep1)p1.pdf\". Writing the PDF needs plotly's static image\n",
        "    export backend (kaleido).\n",
        "    \"\"\"\n",
        "    teacher_coords = get_transpose(particle_positions_teacher)\n",
        "    particle_coords = get_transpose(particle_positions_model)\n",
        "    equivariant_coords = get_transpose(equivariant_space_limits)\n",
        "    plots = create_plot_list(teacher_coords, particle_coords, equivariant_coords, c_lims=c_lims, t_style=t_style, p_style=p_style, eq_style=eq_style, with_lineplots=with_lineplots, showlegend=showlegend, with_colorbar=with_colorbar)\n",
        "\n",
        "    fig = go.Figure()\n",
        "    for plot in plots:\n",
        "        fig.add_trace(plot)\n",
        "\n",
        "    layout_config = dict(\n",
        "        title=dict(text=title if show_title else \"\", x=0.5, font=dict(size=20), automargin=False),\n",
        "        scene=dict(\n",
        "            xaxis_title=\"z1\",\n",
        "            yaxis_title=\"z2\",\n",
        "            zaxis_title=\"z3\",\n",
        "            xaxis_showspikes=False,\n",
        "            yaxis_showspikes=False,\n",
        "            zaxis_showspikes=False,\n",
        "            camera=dict(\n",
        "                up=dict(x=0, y=0, z=1),\n",
        "                center=dict(x=0, y=0, z=0),\n",
        "                eye=dict(x=1.25, y=-1.25, z=1.25) if show_title else dict(x=1.2, y=1.2, z=1.2)\n",
        "            ),\n",
        "            xaxis=dict(showgrid=True, gridcolor='lightgray'),\n",
        "            yaxis=dict(showgrid=True, gridcolor='lightgray'),\n",
        "            zaxis=dict(showgrid=True, gridcolor='lightgray')\n",
        "        ),\n",
        "        width=800,\n",
        "        height=760 if show_title else 720,\n",
        "        autosize=False,\n",
        "        legend=dict(\n",
        "            yanchor=\"bottom\",\n",
        "            y=0.99,\n",
        "            xanchor=\"left\",\n",
        "            x=0.01\n",
        "        )\n",
        "    )\n",
        "\n",
        "    fig.update_layout(layout_config)\n",
        "    if with_colorbar:\n",
        "        # Label the colour bar, which encodes the fourth coordinate\n",
        "        fig.add_annotation(\n",
        "            text=\"z4\",\n",
        "            x=1.07,\n",
        "            y=1.02,\n",
        "            showarrow=False,\n",
        "            xref=\"paper\",\n",
        "            yref=\"paper\",\n",
        "            font=dict(size=12)\n",
        "        )\n",
        "    fig.show()\n",
        "    # To display in a separate browser tab instead: pio.renderers.default = 'browser'\n",
        "\n",
        "    # Save the figure as a PDF\n",
        "    suffix = \"p1\" if show_title else \"p2\"\n",
        "    file_stem = title.replace(\" \", \"\") + suffix\n",
        "    pio.write_image(fig, file_stem + '.pdf', format='pdf', validate=True)\n",
        "\n",
        "# Example usage:\n",
        "# particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"Plot 1\", with_colorbar=True, show_title=True, showlegend=True)\n",
        "# particle_plot_save(particle_positions_teacher, particle_positions_model, equivariant_space_limits, title=\"Plot 2\", with_colorbar=False, show_title=False, showlegend=False)\n"
      ],
      "metadata": {
        "id": "RTPoYEi2fbiJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Select the repetition to export (same selection as in the previous section, repeated so\n",
        "# that this section can be run on its own)\n",
        "repetition = 0\n",
        "particle_positions_teacher = teacher_network.get_particles()\n",
        "phase_trajectories = [out_dict[\"particles\"][step][repetition] for step in [\"fase0\", \"fase1\", \"fase2\"]]\n",
        "particles0, particles1, particles2 = phase_trajectories"
      ],
      "metadata": {
        "id": "VjZ5PAmFfwCd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The different hyperparameters might be changed to obtain the different parts of our particle plots:"
      ],
      "metadata": {
        "id": "9OnQPnLCUoFs"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Export the student particles at the end of step 1 (with title, legend and colour bar)\n",
        "particle_plot_save(\n",
        "    particle_positions_teacher,\n",
        "    particles1[-1],\n",
        "    equivariant_space_points,\n",
        "    title=\"Student Particles (end of step 1)\",\n",
        "    c_lims=(-0.65, 0.65),\n",
        "    eq_style=(0.35, 0.6),\n",
        "    t_style=(3, \"diamond\"),\n",
        "    with_lineplots=True,\n",
        "    with_colorbar=True,\n",
        "    show_title=True,\n",
        "    showlegend=True,\n",
        ")"
      ],
      "metadata": {
        "id": "EzSPl0wqFx_E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "zhVwtdK0gw6R"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "w7yCVSqLgWKF",
        "6zgMw2_nURxu"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}