{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B-U52DpD4CUn"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pd46RtkK2HOU"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import plotly.express as px\n",
        "import plotly.graph_objects as go\n",
        "\n",
        "from scipy import signal"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "if torch.cuda.is_available():\n",
        "  device = \"cuda\"\n",
        "else:\n",
        "  device = \"cpu\""
      ],
      "metadata": {
        "id": "hPaUVKkItsVa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GwiAU45V2Sq_"
      },
      "source": [
        "# Signal generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Aga4Xcm_2RPI"
      },
      "outputs": [],
      "source": [
        "# signal generator\n",
        "def generate_signal(x, type, frequency, amplitude):\n",
        "  match type:\n",
        "    case 'sawtooth':\n",
        "      sig = amplitude * signal.sawtooth(frequency * 2.0 * np.pi * x, width=0.5)\n",
        "    case 'square':\n",
        "      sig = amplitude * signal.square(frequency * 2.0 * np.pi * x)\n",
        "    case 'sine':\n",
        "      sig = amplitude * np.sin(frequency * 2.0 * np.pi * x)\n",
        "\n",
        "  return torch.tensor(sig, dtype=torch.float)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# signal generator wrapper\n",
        "def get_signal_generator(type):\n",
        "  def signal_generator(x, frequency, amplitude):\n",
        "    return generate_signal(x, type, frequency, amplitude)\n",
        "  return signal_generator"
      ],
      "metadata": {
        "id": "W0ddOVlagR2F"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x7rKdbqC4O7G"
      },
      "outputs": [],
      "source": [
        "# plot signal\n",
        "def plot_signal(x, sig, xlim=[-2*np.pi, 2*np.pi]):\n",
        "  plt.plot(x, sig)\n",
        "  plt.xlim(xlim)\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SoUDlFdB5jdD"
      },
      "source": [
        "# Architecture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zNyfhFO569NW"
      },
      "outputs": [],
      "source": [
        "# hosc\n",
        "class HOSC(nn.Module):\n",
        "    def __init__(self, frequency=1.0, amplitude=1.0):\n",
        "        super(HOSC, self).__init__()\n",
        "        self.frequency=frequency\n",
        "        self.amplitude=amplitude\n",
        "\n",
        "    def forward(self, x):\n",
        "        return torch.tanh(self.amplitude * torch.sin(self.frequency * x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J61IsI_e7Sf8"
      },
      "outputs": [],
      "source": [
        "# siren\n",
        "class SIREN(nn.Module):\n",
        "    def __init__(self, frequency=1.0):\n",
        "        super(SIREN, self).__init__()\n",
        "        self.frequency=frequency\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.frequency * torch.sin(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x5vdHJ8R5jxC"
      },
      "outputs": [],
      "source": [
        "# mlp\n",
        "class TestMLP(nn.Module):\n",
        "  def __init__(self,  activation=nn.ReLU(), in_size=1, out_size=1, hidden_size=2, hidden_width=128):\n",
        "    super(TestMLP, self).__init__()\n",
        "\n",
        "    layers = []\n",
        "    layers.append(nn.Linear(in_size, hidden_width))\n",
        "\n",
        "    for _ in range(hidden_size):\n",
        "      layers.append(nn.Linear(hidden_width, hidden_width))\n",
        "      layers.append(activation)\n",
        "\n",
        "    layers.append(nn.Linear(hidden_width, out_size))\n",
        "    self.layers = nn.Sequential(*layers)\n",
        "\n",
        "  def forward(self, x):\n",
        "    return self.layers(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pfg9_AcGzqEP"
      },
      "outputs": [],
      "source": [
        "# number of parameters\n",
        "def number_params(model):\n",
        "    with torch.no_grad():\n",
        "        return sum(np.fromiter((p.numel() for p in model.parameters() if p.requires_grad), int))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3jXQG6Wa7aUr"
      },
      "source": [
        "# Training procedure"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ykf3qq_h7bqG"
      },
      "outputs": [],
      "source": [
        "def train(net,\n",
        "          lossf,\n",
        "          optimizer,\n",
        "          x_train,\n",
        "          y_train,\n",
        "          x_test,\n",
        "          y_test,\n",
        "          num_epochs,\n",
        "          batch_size,\n",
        "          train_loss_avg,\n",
        "          test_loss_avg,\n",
        "          print_epoch_stats=True):\n",
        "\n",
        "  x_train = x_train.reshape([x_train.size(0), 1])\n",
        "  x_test = x_test.reshape([x_test.size(0), 1])\n",
        "  y_train = y_train.reshape([y_train.size(0), 1])\n",
        "  y_test = y_test.reshape([y_test.size(0), 1])\n",
        "\n",
        "  dataset = torch.utils.data.TensorDataset(x_train, y_train)\n",
        "  dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "  net = net.to(device)\n",
        "\n",
        "  for epoch in range(num_epochs):\n",
        "    total_train_loss = 0.0\n",
        "\n",
        "    net.train()\n",
        "    for x_batch, y_batch in dataloader:\n",
        "      y_pred_train = net(x_batch.to(device))\n",
        "      train_loss = lossf(y_batch.to(device), y_pred_train.to(device))\n",
        "      optimizer.zero_grad()\n",
        "      train_loss.backward()\n",
        "      optimizer.step()\n",
        "\n",
        "      total_train_loss += train_loss.item() * x_batch.size(0)\n",
        "\n",
        "    train_loss_avg.append(total_train_loss / len(x_train))\n",
        "\n",
        "    with torch.no_grad():\n",
        "        net.eval()\n",
        "        y_pred_test = net(x_test.to(device))\n",
        "        test_loss = lossf(y_test.to(device), y_pred_test.to(device))\n",
        "        test_loss_avg.append(test_loss.item())\n",
        "\n",
        "    if print_epoch_stats and  epoch % 100 == 99:\n",
        "      print(f'Epoch [{epoch+1}/{num_epochs}] train loss: {train_loss_avg[-1]:.4f}, test loss: {test_loss_avg[-1]:.4f}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zSH41N0Oz4P0"
      },
      "outputs": [],
      "source": [
        "def plot_loss(train_loss_avg, test_loss_avg):\n",
        "    fig = plt.figure(figsize=(12, 6))\n",
        "    plt.plot(train_loss_avg)\n",
        "    plt.plot(test_loss_avg)\n",
        "    plt.legend(['train', 'valid'])\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FROrQRxL7pY4"
      },
      "source": [
        "# Evaluation logic"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nNI72u9R7qev"
      },
      "outputs": [],
      "source": [
        "def evaluate(net, input, target, xlim=[-2*np.pi, 2*np.pi], lossf=nn.MSELoss(reduction='mean')):\n",
        "  net.eval()\n",
        "  with torch.no_grad():\n",
        "    input = input.reshape([input.size()[0], 1])\n",
        "    target = target.reshape([target.size()[0], 1])\n",
        "\n",
        "    pred = net(input)\n",
        "    loss = lossf(target, pred)\n",
        "\n",
        "    plt.plot(input, target)\n",
        "    plt.plot(input, pred)\n",
        "\n",
        "    plt.xlim(xlim)\n",
        "    plt.show()\n",
        "\n",
        "    return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pRaHlatiivUA"
      },
      "source": [
        "# Experiment design"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ewpHL-fVi06J"
      },
      "source": [
        "## Experiment #1\n",
        "\n",
        "**Task.** Compare the fidelty of signal representation obtained using **HOSC** and **SIREN** across different signal frequencies and different values of the frequency parameter of the activation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DJf9dd1qiw7I"
      },
      "outputs": [],
      "source": [
        "# trains the mlps on varying frequency signals with varying frequency parameter in the actiavtion\n",
        "def frequency_study(\n",
        "    activation_frequency,\n",
        "    signal_frequency,\n",
        "    x_train,\n",
        "    x_test,\n",
        "    num_epochs,\n",
        "    batch_size,\n",
        "    activation,\n",
        "    signal_generator,\n",
        "    lossf=nn.MSELoss(reduction='mean'),\n",
        "    optimizer_type=torch.optim.Adam\n",
        "    ):\n",
        "\n",
        "  models = []\n",
        "\n",
        "  num_models = len(activation_frequency) * len(signal_frequency)\n",
        "  n = 1\n",
        "\n",
        "  for af in activation_frequency:\n",
        "    for sf in signal_frequency:\n",
        "      print(f'[{n} / {num_models}] Training MLP with ACTIVATION FREQUENCY = {af:.2f}, SIGNAL FREQUENCY = {sf:.2f}')\n",
        "      n += 1\n",
        "\n",
        "      y_train = signal_generator(x_train, sf, 1.0)\n",
        "      y_test = signal_generator(x_test, sf, 1.0)\n",
        "\n",
        "      mlp = TestMLP(activation=activation(frequency=af))\n",
        "      optimizer = optimizer_type(params=mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
        "\n",
        "      train_loss_avg = []\n",
        "      test_loss_avg = []\n",
        "\n",
        "      train(\n",
        "          net=mlp,\n",
        "          optimizer=optimizer,\n",
        "          lossf=lossf,\n",
        "          x_train=x_train,\n",
        "          x_test=x_test,\n",
        "          y_test=y_test,\n",
        "          y_train=y_train,\n",
        "          num_epochs=num_epochs,\n",
        "          batch_size=batch_size,\n",
        "          train_loss_avg=train_loss_avg,\n",
        "          test_loss_avg=test_loss_avg,\n",
        "          print_epoch_stats=False)\n",
        "\n",
        "      models.append((mlp, [af, sf]))\n",
        "\n",
        "  return models"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# evaluates the mlps, gathers the results into a dataframe and plots them\n",
        "def get_results(\n",
        "    models,\n",
        "    x_train,\n",
        "    signal_generator,\n",
        "    lossf=nn.MSELoss(reduction='mean')\n",
        "  ):\n",
        "\n",
        "  loss_data_structure = dict(activation=[], signal=[], mse=[])\n",
        "  loss_data = pd.DataFrame(loss_data_structure)\n",
        "\n",
        "  for mlp, [af, sf] in models:\n",
        "    mlp = mlp.cuda()\n",
        "    y_train = signal_generator(x_train, sf, 1.0)\n",
        "    mlp.eval()\n",
        "    with torch.no_grad():\n",
        "      input = x_train.reshape([x_train.size()[0], 1])\n",
        "      target = y_train.reshape([y_train.size()[0], 1])\n",
        "\n",
        "      pred = mlp(input.cuda())\n",
        "      loss = lossf(target.cuda(), pred)\n",
        "    loss_data.loc[loss_data.shape[0]] = [af, sf, loss.item()]\n",
        "\n",
        "  return loss_data.pivot(index='signal', columns='activation', values='mse')\n"
      ],
      "metadata": {
        "id": "pbM1iFhFV0Po"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_results(\n",
        "    loss_data,\n",
        "    xaxis_label,\n",
        "    xaxis_ticktext,\n",
        "    show_xticks,\n",
        "    yaxis_label,\n",
        "    yaxis_ticktext,\n",
        "    show_yticks,\n",
        "    title_fontsize,\n",
        "    tick_fontsize,\n",
        "    fontsize,\n",
        "    range_color=[0,1],\n",
        "    size=[800,600],\n",
        "    show_scale=False,\n",
        "  ):\n",
        "\n",
        "  data = np.matrix(loss_data.round(3))\n",
        "  rows, cols = data.shape\n",
        "\n",
        "  fig = px.imshow(data, color_continuous_scale='Viridis', range_color=range_color, text_auto=False)\n",
        "\n",
        "  if show_xticks:\n",
        "    x_tickvals = [i for i in range(cols)]\n",
        "  else:\n",
        "    x_tickvals = []\n",
        "\n",
        "  fig.update_xaxes(\n",
        "      title_text=xaxis_label,\n",
        "      tickvals=x_tickvals,\n",
        "      ticktext=xaxis_ticktext,\n",
        "      title_font=dict(size=title_fontsize),\n",
        "      tickfont=dict(size=tick_fontsize)\n",
        "  )\n",
        "\n",
        "  if show_yticks:\n",
        "    y_tickvals = [i for i in range(rows)]\n",
        "  else:\n",
        "    y_tickvals = []\n",
        "\n",
        "  fig.update_yaxes(\n",
        "      title_text=yaxis_label,\n",
        "      tickvals=y_tickvals,\n",
        "      ticktext=yaxis_ticktext,\n",
        "      title_font=dict(size=title_fontsize),\n",
        "      tickfont=dict(size=tick_fontsize)\n",
        "  )\n",
        "\n",
        "  fig.update_layout(\n",
        "      font=dict(size=fontsize),\n",
        "      width=size[0],\n",
        "      height=size[1],\n",
        "      coloraxis_colorbar=dict(x=0.9, tickfont=dict(size=24))\n",
        "  )\n",
        "\n",
        "  fig.update_layout(coloraxis_showscale=show_scale)\n",
        "\n",
        "  fig.show()"
      ],
      "metadata": {
        "id": "H4BMBtl3jOUX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# set parameters\n",
        "learning_rate = 1e-3\n",
        "weight_decay = 0e-5\n",
        "batch_size = 2 ** 10\n",
        "\n",
        "activation_frequency = np.linspace(1, 31, 7)\n",
        "signal_frequency = np.linspace(1, 31, 7)\n",
        "\n",
        "x_train = torch.linspace(-1, 1, 1000)\n",
        "x_test = torch.linspace(-1, 1, 1000)\n",
        "\n",
        "num_epochs = 1000\n",
        "\n",
        "sawtooth_gen = get_signal_generator('sawtooth')\n",
        "\n",
        "def run_experiment_1(activation):\n",
        "   return frequency_study(\n",
        "      activation_frequency=activation_frequency,\n",
        "      signal_frequency=signal_frequency,\n",
        "      x_train=x_train,\n",
        "      x_test=x_test,\n",
        "      num_epochs=num_epochs,\n",
        "      batch_size=batch_size,\n",
        "      activation=activation,\n",
        "      signal_generator=sawtooth_gen\n",
        "  )"
      ],
      "metadata": {
        "id": "FsRQpOHjWFI_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "models_hosc = run_experiment_1(HOSC)"
      ],
      "metadata": {
        "id": "Qqf2QYBRXBuO",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 334
        },
        "outputId": "938a55bd-3b7c-44fa-ff87-951923512b39"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[1 / 49] Training MLP with ACTIVATION FREQUENCY = 1.00, SIGNAL FREQUENCY = 1.00\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-17-0d501c962efc>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodels_hosc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_experiment_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mHOSC\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m<ipython-input-16-1ed08d34515d>\u001b[0m in \u001b[0;36mrun_experiment_1\u001b[0;34m(activation)\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_experiment_1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mactivation\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m    return frequency_study(\n\u001b[0m\u001b[1;32m     18\u001b[0m       \u001b[0mactivation_frequency\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mactivation_frequency\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m       \u001b[0msignal_frequency\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msignal_frequency\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-13-13bd8a5c9e0e>\u001b[0m in \u001b[0;36mfrequency_study\u001b[0;34m(activation_frequency, signal_frequency, x_train, x_test, num_epochs, batch_size, activation, signal_generator, lossf, optimizer_type)\u001b[0m\n\u001b[1;32m     32\u001b[0m       \u001b[0mtest_loss_avg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m       train(\n\u001b[0m\u001b[1;32m     35\u001b[0m           \u001b[0mnet\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmlp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m           \u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-10-176da5a046a7>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(net, lossf, optimizer, x_train, y_train, x_test, y_test, num_epochs, batch_size, train_loss_avg, test_loss_avg, print_epoch_stats)\u001b[0m\n\u001b[1;32m     20\u001b[0m   \u001b[0mdataloader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m   \u001b[0mnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m   \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1143\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1145\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1146\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1147\u001b[0m     def register_full_backward_pre_hook(\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    795\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    796\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 797\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    798\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    799\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    795\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    796\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 797\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    798\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    799\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    818\u001b[0m             \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    819\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 820\u001b[0;31m                 \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    821\u001b[0m             \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    822\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m   1141\u001b[0m                 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,\n\u001b[1;32m   1142\u001b[0m                             non_blocking, memory_format=convert_to_format)\n\u001b[0;32m-> 1143\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1145\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_hosc = get_results(models_hosc, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "7ehtHN12cGd9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_hosc,\n",
        "    xaxis_label=\"HOSC Frequency\",\n",
        "    xaxis_ticktext=list(activation_frequency),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"Signal Frequency\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=True,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[800,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "zQ-LwPpgjNLM",
        "outputId": "04cd0adf-be49-4a49-9703-d2c3379c3e19"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"7477315b-28e2-4953-ae6e-52820f47bfea\" class=\"plotly-graph-div\" style=\"height:600px; width:800px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"7477315b-28e2-4953-ae6e-52820f47bfea\")) {                    Plotly.newPlot(                        \"7477315b-28e2-4953-ae6e-52820f47bfea\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.0,0.0,0.0,0.0,0.285,0.329],[0.334,0.007,0.0,0.0,0.371,0.341,0.359],[0.334,0.084,0.001,0.013,0.0,0.307,0.305],[0.334,0.312,0.003,0.008,0.003,0.303,0.324],[0.334,0.331,0.029,0.078,0.128,0.367,0.34],[0.334,0.334,0.07,0.159,0.205,0.109,0.332],[0.334,0.334,0.293,0.273,0.282,0.27,0.335]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"HOSC Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5,6],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"Signal Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5,6],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":800,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('7477315b-28e2-4953-ae6e-52820f47bfea');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models_siren = run_experiment_1(SIREN)"
      ],
      "metadata": {
        "id": "BRxr_PSlXT2V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_siren = get_results(models_siren, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "DOstqYHbcHVu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_siren,\n",
        "    xaxis_label=\"SIREN Frequency\",\n",
        "    xaxis_ticktext=list(activation_frequency),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[800,600],\n",
        "    show_scale=True\n",
        ")"
      ],
      "metadata": {
        "id": "L1VN6HIxcH9-",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "outputId": "9bc5dc2a-72e4-448a-84f4-eb07ecf7c4fe"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"96633d59-7990-4107-88d1-a51712e375da\" class=\"plotly-graph-div\" style=\"height:600px; width:800px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"96633d59-7990-4107-88d1-a51712e375da\")) {                    Plotly.newPlot(                        \"96633d59-7990-4107-88d1-a51712e375da\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.001,0.004,0.004,0.003,0.001,0.002,0.001],[0.334,0.333,0.331,0.323,0.313,0.309,0.017],[0.334,0.334,0.334,0.333,0.337,0.328,0.531],[0.334,0.334,0.334,0.334,0.333,0.331,14.425],[0.334,0.334,0.334,0.334,0.334,0.333,3.793],[0.334,0.334,0.334,0.334,0.339,0.333,1.446],[0.334,0.334,0.334,0.334,0.335,0.334,0.334]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"SIREN Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5,6],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":true},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":800,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('96633d59-7990-4107-88d1-a51712e375da');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5OJGdilTGhUg"
      },
      "source": [
        "## Experiment #2\n",
        "\n",
        "**Task**. Analyse the effectivness of **HOSC** across different values of the amplitude parameter $b$ and different frequency signals."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vi_ESspYGsMH"
      },
      "outputs": [],
      "source": [
        "def hosc_amplitude_study(\n",
        "    hosc_frequency,\n",
        "    hosc_amplitude,\n",
        "    signal_frequency,\n",
        "    x_train,\n",
        "    x_test,\n",
        "    num_epochs,\n",
        "    batch_size,\n",
        "    signal_generator,\n",
        "    lossf=nn.MSELoss(reduction='mean'),\n",
        "    optimizer_type=torch.optim.Adam\n",
        "    ):\n",
        "\n",
        "  models = []\n",
        "\n",
        "  num_models = len(hosc_amplitude) * len(signal_frequency)\n",
        "  n = 1\n",
        "\n",
        "  for ha in hosc_amplitude:\n",
        "    for sf in signal_frequency:\n",
        "      print(f'[{n} / {num_models}] Training MLP with HOSC AMPLITUDE = {ha:.2f}, SIGNAL FREQUENCY = {sf:.2f}')\n",
        "      n += 1\n",
        "\n",
        "      y_train = signal_generator(x_train, sf, 1.0)\n",
        "      y_test = signal_generator(x_test, sf, 1.0)\n",
        "\n",
        "      mlp = TestMLP(activation=HOSC(amplitude=ha, frequency=hosc_frequency))\n",
        "      optimizer = optimizer_type(params=mlp.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
        "\n",
        "      train_loss_avg = []\n",
        "      test_loss_avg = []\n",
        "\n",
        "      train(\n",
        "          net=mlp,\n",
        "          optimizer=optimizer,\n",
        "          lossf=lossf,\n",
        "          x_train=x_train,\n",
        "          x_test=x_test,\n",
        "          y_test=y_test,\n",
        "          y_train=y_train,\n",
        "          num_epochs=num_epochs,\n",
        "          batch_size=batch_size,\n",
        "          train_loss_avg=train_loss_avg,\n",
        "          test_loss_avg=test_loss_avg,\n",
        "          print_epoch_stats=False)\n",
        "\n",
        "      models.append((mlp, [ha, sf, hosc_frequency]))\n",
        "\n",
        "  return models"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# evaluates the mlps, gathers the results into a dataframe and plots them\n",
        "def get_results(\n",
        "    models,\n",
        "    x_train,\n",
        "    signal_generator,\n",
        "    lossf=nn.MSELoss(reduction='mean'),\n",
        "    plot=True):\n",
        "\n",
        "  loss_data_structure = dict(hosc=[], signal=[], mse=[])\n",
        "  loss_data = pd.DataFrame(loss_data_structure)\n",
        "\n",
        "  for mlp, [ha, sf, hf] in models:\n",
        "    mlp = mlp.to(device)\n",
        "    y_train = signal_generator(x_train, sf, 1.0)\n",
        "    mlp.eval()\n",
        "    with torch.no_grad():\n",
        "      input = x_train.reshape([x_train.size()[0], 1]).to(device)\n",
        "      target = y_train.reshape([y_train.size()[0], 1]).to(device)\n",
        "\n",
        "      pred = mlp(input)\n",
        "      loss = lossf(target, pred)\n",
        "    loss_data.loc[loss_data.shape[0]] = [ha, sf, loss.item()]\n",
        "\n",
        "  loss_data = loss_data.pivot(index='signal', columns='hosc', values='mse')\n",
        "\n",
        "  return loss_data"
      ],
      "metadata": {
        "id": "0qIBSRyqJ2nY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# set parameters\n",
        "learning_rate = 1e-3\n",
        "weight_decay = 0e-5\n",
        "batch_size = 2 ** 10\n",
        "\n",
        "hosc_frequency = 2.5\n",
        "hosc_amplitude = np.linspace(1, 26, 6)\n",
        "signal_frequency = np.linspace(1, 36, 8)\n",
        "\n",
        "x_train = torch.linspace(-1, 1, 1000)\n",
        "x_test = torch.linspace(-1, 1, 1000)\n",
        "\n",
        "sawtooth_gen = get_signal_generator('sawtooth')\n",
        "\n",
        "lossf = nn.MSELoss(reduction='mean')\n",
        "\n",
        "num_epochs = 1000\n",
        "\n",
        "def run_experiment_2(hosc_frequency):\n",
        "   return hosc_amplitude_study(\n",
        "      hosc_frequency=hosc_frequency,\n",
        "      hosc_amplitude=hosc_amplitude,\n",
        "      signal_frequency=signal_frequency,\n",
        "      x_train=x_train,\n",
        "      x_test=x_test,\n",
        "      num_epochs=num_epochs,\n",
        "      batch_size=batch_size,\n",
        "      signal_generator=sawtooth_gen,\n",
        "      lossf=lossf,\n",
        "      optimizer_type=torch.optim.Adam\n",
        "  )"
      ],
      "metadata": {
        "id": "2mfHYn1tZ8ws"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=1)"
      ],
      "metadata": {
        "id": "em-OedIo6jxH",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b7c85218-8112-43f3-d06a-a13442bf74fe"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[1 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 1.00\n",
            "[2 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 6.00\n",
            "[3 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 11.00\n",
            "[4 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 16.00\n",
            "[5 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 21.00\n",
            "[6 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 26.00\n",
            "[7 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 31.00\n",
            "[8 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 36.00\n",
            "[9 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 1.00\n",
            "[10 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 6.00\n",
            "[11 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 11.00\n",
            "[12 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 16.00\n",
            "[13 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 21.00\n",
            "[14 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 26.00\n",
            "[15 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 31.00\n",
            "[16 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 36.00\n",
            "[17 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 1.00\n",
            "[18 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 6.00\n",
            "[19 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 11.00\n",
            "[20 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 16.00\n",
            "[21 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 21.00\n",
            "[22 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 26.00\n",
            "[23 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 31.00\n",
            "[24 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 36.00\n",
            "[25 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 1.00\n",
            "[26 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 6.00\n",
            "[27 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 11.00\n",
            "[28 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 16.00\n",
            "[29 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 21.00\n",
            "[30 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 26.00\n",
            "[31 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 31.00\n",
            "[32 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 36.00\n",
            "[33 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 1.00\n",
            "[34 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 6.00\n",
            "[35 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 11.00\n",
            "[36 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 16.00\n",
            "[37 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 21.00\n",
            "[38 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 26.00\n",
            "[39 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 31.00\n",
            "[40 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 36.00\n",
            "[41 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 1.00\n",
            "[42 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 6.00\n",
            "[43 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 11.00\n",
            "[44 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 16.00\n",
            "[45 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 21.00\n",
            "[46 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 26.00\n",
            "[47 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 31.00\n",
            "[48 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 36.00\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_1 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "0ZRMpaBC6l5O"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_1,\n",
        "    xaxis_label=\"HOSC Sharpness\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"Signal Frequen cy\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=True,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "g-qDTESSQnJe",
        "outputId": "e3991743-4980-4c14-e34a-3d15560267a6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"e2854517-c2a8-4b2a-b087-1bf662ca1c3d\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"e2854517-c2a8-4b2a-b087-1bf662ca1c3d\")) {                    Plotly.newPlot(                        \"e2854517-c2a8-4b2a-b087-1bf662ca1c3d\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.0,0.0,0.0,0.0,0.0],[0.333,0.002,0.01,0.004,0.022,0.006],[0.334,0.275,0.114,0.1,0.034,0.108],[0.334,0.327,0.261,0.201,0.142,0.178],[0.334,0.333,0.289,0.242,0.236,0.188],[0.334,0.334,0.319,0.29,0.278,0.274],[0.334,0.334,0.332,0.319,0.309,0.291],[0.334,0.334,0.334,0.332,0.311,0.286]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"HOSC Sharpness\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"Signal Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5,6,7],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('e2854517-c2a8-4b2a-b087-1bf662ca1c3d');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=2.5)"
      ],
      "metadata": {
        "id": "fsAJckhbZ_ZK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_2 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "fGYHS5PKaBFI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_2,\n",
        "    xaxis_label=\"\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "l_ATlVBVU3qx",
        "outputId": "0eb564bc-bf69-4067-ecd4-a7e01d2d16ea"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"278dfa5d-bc12-48c4-838c-87adc735fbed\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"278dfa5d-bc12-48c4-838c-87adc735fbed\")) {                    Plotly.newPlot(                        \"278dfa5d-bc12-48c4-838c-87adc735fbed\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.0,0.0,0.0,0.0,0.0],[0.102,0.005,0.003,0.012,0.006,0.052],[0.333,0.038,0.014,0.058,0.133,0.109],[0.334,0.073,0.118,0.112,0.228,0.193],[0.334,0.248,0.158,0.189,0.249,0.297],[0.334,0.333,0.213,0.257,0.253,0.287],[0.334,0.32,0.29,0.295,0.303,0.323],[0.334,0.332,0.307,0.299,0.322,0.387]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('278dfa5d-bc12-48c4-838c-87adc735fbed');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=5)"
      ],
      "metadata": {
        "id": "-s3wCnX8beYa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_3 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "fRcZaRidbhtN"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_3,\n",
        "    xaxis_label=\"\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "-YR8-8R3fkG9",
        "outputId": "b8b5fcb5-5ff1-4ac3-af47-0e8dab4efad4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"1ee18d84-3887-4b63-b1fa-d63820188f46\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"1ee18d84-3887-4b63-b1fa-d63820188f46\")) {                    Plotly.newPlot(                        \"1ee18d84-3887-4b63-b1fa-d63820188f46\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.0,0.048,0.0,0.0,0.0],[0.014,0.354,0.315,0.0,0.002,0.0],[0.175,0.059,0.008,0.182,0.007,0.027],[0.329,0.037,0.643,0.299,0.127,0.204],[0.333,0.204,0.198,0.319,0.32,0.286],[0.334,0.302,0.285,0.327,0.326,0.335],[0.334,0.323,0.319,0.304,0.31,0.331],[0.334,0.334,0.322,0.333,0.316,0.346]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('1ee18d84-3887-4b63-b1fa-d63820188f46');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=7.5)"
      ],
      "metadata": {
        "id": "e73_XPBe14Vl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_4 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "pfPBq1EW16Kd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_4,\n",
        "    xaxis_label=\"\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "pQicc5FYfhH2",
        "outputId": "aa16c6b3-19ba-456f-b242-46895c3dd10f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"3c0a25a7-b56e-49fd-b3e7-998a866004b5\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"3c0a25a7-b56e-49fd-b3e7-998a866004b5\")) {                    Plotly.newPlot(                        \"3c0a25a7-b56e-49fd-b3e7-998a866004b5\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.0,0.002,0.0,0.0,0.013],[0.002,0.0,0.0,0.015,0.0,0.0],[0.003,0.0,0.001,0.114,0.216,0.315],[0.26,0.075,0.011,0.21,0.201,0.318],[0.297,0.212,0.124,0.238,0.344,0.343],[0.324,0.248,0.296,0.323,0.304,0.325],[0.332,0.316,0.319,0.323,0.337,0.325],[0.333,0.332,0.345,0.319,0.328,0.336]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('3c0a25a7-b56e-49fd-b3e7-998a866004b5');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=10)"
      ],
      "metadata": {
        "id": "MRG5xNR4hsMW",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "73d1cbf3-f9d6-4cd7-9648-cdc81f933d4f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[1 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 1.00\n",
            "[2 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 6.00\n",
            "[3 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 11.00\n",
            "[4 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 16.00\n",
            "[5 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 21.00\n",
            "[6 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 26.00\n",
            "[7 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 31.00\n",
            "[8 / 48] Training MLP with HOSC AMPLITUDE = 1.00, SIGNAL FREQUENCY = 36.00\n",
            "[9 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 1.00\n",
            "[10 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 6.00\n",
            "[11 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 11.00\n",
            "[12 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 16.00\n",
            "[13 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 21.00\n",
            "[14 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 26.00\n",
            "[15 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 31.00\n",
            "[16 / 48] Training MLP with HOSC AMPLITUDE = 6.00, SIGNAL FREQUENCY = 36.00\n",
            "[17 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 1.00\n",
            "[18 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 6.00\n",
            "[19 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 11.00\n",
            "[20 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 16.00\n",
            "[21 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 21.00\n",
            "[22 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 26.00\n",
            "[23 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 31.00\n",
            "[24 / 48] Training MLP with HOSC AMPLITUDE = 11.00, SIGNAL FREQUENCY = 36.00\n",
            "[25 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 1.00\n",
            "[26 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 6.00\n",
            "[27 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 11.00\n",
            "[28 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 16.00\n",
            "[29 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 21.00\n",
            "[30 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 26.00\n",
            "[31 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 31.00\n",
            "[32 / 48] Training MLP with HOSC AMPLITUDE = 16.00, SIGNAL FREQUENCY = 36.00\n",
            "[33 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 1.00\n",
            "[34 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 6.00\n",
            "[35 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 11.00\n",
            "[36 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 16.00\n",
            "[37 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 21.00\n",
            "[38 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 26.00\n",
            "[39 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 31.00\n",
            "[40 / 48] Training MLP with HOSC AMPLITUDE = 21.00, SIGNAL FREQUENCY = 36.00\n",
            "[41 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 1.00\n",
            "[42 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 6.00\n",
            "[43 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 11.00\n",
            "[44 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 16.00\n",
            "[45 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 21.00\n",
            "[46 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 26.00\n",
            "[47 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 31.00\n",
            "[48 / 48] Training MLP with HOSC AMPLITUDE = 26.00, SIGNAL FREQUENCY = 36.00\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_5 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "-W798BPNhEUv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_5,\n",
        "    xaxis_label=\"\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=False\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "t8G4n-J2k2_x",
        "outputId": "1cad7353-7c3e-4ede-d2e8-aac0f425efd9"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"69844eb6-4fa7-41e5-968f-a88380ea205c\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"69844eb6-4fa7-41e5-968f-a88380ea205c\")) {                    Plotly.newPlot(                        \"69844eb6-4fa7-41e5-968f-a88380ea205c\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.385,0.0,0.0,0.199,0.0,0.0],[0.001,0.0,0.026,0.0,0.264,0.272],[0.002,0.0,0.0,0.284,0.238,0.262],[0.009,0.272,0.198,0.317,0.276,0.31],[0.062,0.249,0.251,0.296,0.303,0.334],[0.343,0.309,0.304,0.324,0.333,0.329],[0.243,0.288,0.344,0.332,0.335,0.342],[0.337,0.343,0.336,0.331,0.337,0.338]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":false},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('69844eb6-4fa7-41e5-968f-a88380ea205c');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "models = run_experiment_2(hosc_frequency=20)"
      ],
      "metadata": {
        "id": "xTUSDvCXxJph"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "loss_data_6 = get_results(models, x_train, sawtooth_gen)"
      ],
      "metadata": {
        "id": "ARYZlUvGxMEZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plot_results(\n",
        "    loss_data=loss_data_6,\n",
        "    xaxis_label=\"\",\n",
        "    xaxis_ticktext=list(hosc_amplitude),\n",
        "    show_xticks=True,\n",
        "    yaxis_label=\"\",\n",
        "    yaxis_ticktext=list(signal_frequency),\n",
        "    show_yticks=False,\n",
        "    title_fontsize=36,\n",
        "    tick_fontsize=36,\n",
        "    fontsize=26,\n",
        "    range_color=[0,0.5],\n",
        "    size=[600,600],\n",
        "    show_scale=True\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "nsx1KNgMsVE2",
        "outputId": "acdaca4f-b72b-4e3b-f3f6-39f4f6c9f37a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"4b9364fb-1d03-4ecf-9394-b5064c08d302\" class=\"plotly-graph-div\" style=\"height:600px; width:600px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"4b9364fb-1d03-4ecf-9394-b5064c08d302\")) {                    Plotly.newPlot(                        \"4b9364fb-1d03-4ecf-9394-b5064c08d302\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"z\":[[0.0,0.319,0.34,0.341,0.333,0.339],[0.0,0.32,0.343,0.328,0.332,0.338],[0.0,0.342,0.345,0.339,0.336,0.33],[0.16,0.319,0.334,0.33,0.33,0.332],[0.165,0.35,0.334,0.333,0.344,0.332],[0.194,0.335,0.33,0.33,0.336,0.341],[0.242,0.339,0.339,0.334,0.337,0.336],[0.344,0.337,0.338,0.331,0.337,0.335]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2,3,4,5],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"\"},\"tickfont\":{\"size\":36},\"tickvals\":[],\"ticktext\":[1.0,6.0,11.0,16.0,21.0,26.0,31.0,36.0]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9},\"showscale\":true},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":600,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('4b9364fb-1d03-4ecf-9394-b5064c08d302');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Methodology\n",
        "Our aim is to rigorously evaluate the performance of *HOSC* on 1D-domain both against *SIREN* and in isolation, across varying values of hyperparameters and signal complexities. To this end, we run three experiments, where we assess performance of a two-layer MLP (with the hidden layer width equal to $128$) on synthetic signals. The first two experiments are conducted twice, using a sawtooth and a square wave of varying frequency in the $[-1,1]$ domain as the input signal. The third experiment employs as input a set of randomly generated signals, and the performance metric is aggregated to obtain an averaged outcome.\n",
        "\n",
        " **Experiment 1.**  In the first experiment, we train the MLPs on varying signal and activation frequencies (keeping the amplitude parameter constant in the *HOSC*'s case). Signal frequencies range from $1$ to $31$ (incremented by $5$), and activation frequencies range from $1$ to $26$ (also incremented by $5$). Consequently, the outcome are two $7 \\times 6$  matrices, one for *HOSC* and one for *SIREN*, whose values are the mean-squared error (MSE) of the trained MLP, represented as heatmaps.\n",
        "\n",
        "**Experiment 2.** In the second experiment, we analyze the effect of the amplitude parameter $b$ of *HOSC* on its performance.  Signal frequencies again range from $1$ to $31$ (incremented by $5$), and the amplitude ranges from $1$ to $26$ (also incremented by $5$). The frequency is kept constant, which yields $7 \\times 6$ outcome MSE heatmap. We repeat the experiment for different values of the *HOSC* frequency parameter: $[1, 2.5, 5, 7.5, 10, 20]$.   \n",
        "\n",
        "**Experiment 3.**  This experiment assesses *HOSC*'s performance on random signals $\\xi$ constructed as a probabilistically weighted sum of sines with random frequencies $\\omega_i$ and a random offset $\\pi_i$. Formally\n",
        "$$\n",
        "\\xi(x) = \\sum_{i=1}^{n} \\lambda_{i}\\sin (\\omega_{i} x + \\pi_{i})\\, ,\n",
        "$$\n",
        "where $\\| \\lambda \\|_{1}= 1$, and $\\lambda_{i} > 0$. Moreover, $\\pi_{i}\\in [0, 2\\pi]$. In our experiment, we let $n=100$ and $\\omega_{i} \\in [0, 50]$.\n",
        "\n",
        "We let the *HOSC* amplitude and frequency range from $1$ to $26$ (incremented by $5$), and sample $100$ random signals. We train a separate MLP on each one of them, and average the loss to create a $6 \\times 6$ output MSE heatmap.\n",
        "\n"
      ],
      "metadata": {
        "id": "Sk-RJPzXlGYQ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "data = {\n",
        "  'Column1': [1.1, 2.2, 3.3],\n",
        "  'Column2': [4.4, 5.5, 6.6],\n",
        "  'Column3': [7.7, 8.8, 9.9]\n",
        "}\n",
        "\n",
        "data = np.matrix(pd.DataFrame(data))\n",
        "\n",
        "plot_results(\n",
        "    data,\n",
        "    \"HOSC Frequency\",\n",
        "    [\"a\", \"b\", \"c\"],\n",
        "    \"Signal Frequency\",\n",
        "    [1, 2, 3],\n",
        "    36,\n",
        "    36,\n",
        "    26,\n",
        "    range_color=[0,10]\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 617
        },
        "id": "hfxCTRoxj_dh",
        "outputId": "8b0252b6-4cd5-4ff3-8f99-15310036f8a0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"f41d2480-49cf-4713-8bcc-df932a04d1ec\" class=\"plotly-graph-div\" style=\"height:600px; width:800px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"f41d2480-49cf-4713-8bcc-df932a04d1ec\")) {                    Plotly.newPlot(                        \"f41d2480-49cf-4713-8bcc-df932a04d1ec\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"texttemplate\":\"%{z}\",\"z\":[[1.1,4.4,7.7],[2.2,5.5,8.8],[3.3,6.6,9.9]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003ey: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        {\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"HOSC Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2],\"ticktext\":[\"a\",\"b\",\"c\"]},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"font\":{\"size\":36},\"text\":\"Signal Frequency\"},\"tickfont\":{\"size\":36},\"tickvals\":[0,1,2],\"ticktext\":[1,2,3]},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":10,\"colorbar\":{\"tickfont\":{\"size\":24},\"x\":0.9}},\"margin\":{\"t\":60},\"font\":{\"size\":26},\"width\":800,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('f41d2480-49cf-4713-8bcc-df932a04d1ec');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}