{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gUA_0gdG5ckZ"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fh3d41T-4yuy"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "import os\n",
        "\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import plotly.express as px\n",
        "import plotly.graph_objects as go\n",
        "\n",
        "from scipy import signal"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YSX-h6Jycj9t",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b53f1580-2983-487a-b1f3-ffce351d1b1e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/gdrive\n"
          ]
        }
      ],
      "source": [
        "# Mount Google Drive at /content/gdrive; the result helpers below read and\n",
        "# write CSV files under '/content/gdrive/My Drive/HOSC/experiment_3_SIREN'.\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive', force_remount=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T1kCqMjL5dq6"
      },
      "source": [
        "# Signal generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7dwn3P2W5Va2"
      },
      "outputs": [],
      "source": [
        "# signal generator\n",
        "def random_signal_generator(num_components, fr_range):\n",
        "    a, b = fr_range\n",
        "    frequencies = a + (b - a) * np.random.rand(num_components)\n",
        "    offset = 2 * np.pi * np.random.rand(num_components)\n",
        "    unnormalized = np.random.rand(num_components)\n",
        "    probabilities = unnormalized / unnormalized.sum()\n",
        "\n",
        "    def signal(x):\n",
        "        raw = np.sin(x * frequencies + offset)\n",
        "        return np.dot(probabilities, raw)\n",
        "\n",
        "    signal = np.vectorize(signal)\n",
        "\n",
        "    def random_signal(x):\n",
        "        unnormalized = signal(x)\n",
        "        return torch.tensor(unnormalized / np.linalg.norm(unnormalized, ord=np.inf), dtype=torch.float)\n",
        "\n",
        "    return random_signal\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M4qDTkzK5uQp"
      },
      "outputs": [],
      "source": [
        "# signal generator wrapper\n",
        "def get_signal_generator(num_components, fr_range):\n",
        "  def signal_generator():\n",
        "    return random_signal_generator(num_components, fr_range)\n",
        "  return signal_generator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nBkQurNT-2TQ"
      },
      "outputs": [],
      "source": [
        "def plot_signal(x, y):\n",
        "  \"\"\"Show y against x as an interactive plotly line chart.\"\"\"\n",
        "  line = go.Scatter(x=x, y=y, mode='lines')\n",
        "  go.Figure(data=line).show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qywTFy88_Hrs"
      },
      "source": [
        "# Architecture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g10EeWzL_Jvt"
      },
      "outputs": [],
      "source": [
        "# hosc\n",
        "class HOSC(nn.Module):\n",
        "    def __init__(self, frequency=1.0, amplitude=1.0):\n",
        "        super(HOSC, self).__init__()\n",
        "        self.frequency=frequency\n",
        "        self.amplitude=amplitude\n",
        "\n",
        "    def forward(self, x):\n",
        "        return torch.tanh(self.amplitude * torch.sin(self.frequency * x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RbiiOLRr_LKB"
      },
      "outputs": [],
      "source": [
        "# siren\n",
        "class SIREN(nn.Module):\n",
        "    def __init__(self, frequency=1.0):\n",
        "        super(SIREN, self).__init__()\n",
        "        self.frequency=frequency\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.frequency * torch.sin(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "57PNvfLO_NT4"
      },
      "outputs": [],
      "source": [
        "# mlp\n",
        "class TestMLP(nn.Module):\n",
        "  def __init__(self,  activation=nn.ReLU(), in_size=1, out_size=1, hidden_size=2, hidden_width=128):\n",
        "    super(TestMLP, self).__init__()\n",
        "\n",
        "    layers = []\n",
        "    layers.append(nn.Linear(in_size, hidden_width))\n",
        "\n",
        "    for _ in range(hidden_size):\n",
        "      layers.append(nn.Linear(hidden_width, hidden_width))\n",
        "      layers.append(activation)\n",
        "\n",
        "    layers.append(nn.Linear(hidden_width, out_size))\n",
        "    self.layers = nn.Sequential(*layers)\n",
        "\n",
        "  def forward(self, x):\n",
        "    return self.layers(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WVd9c-mu_Oo5"
      },
      "outputs": [],
      "source": [
        "# number of parameters\n",
        "def number_params(model):\n",
        "    with torch.no_grad():\n",
        "        return sum(np.fromiter((p.numel() for p in model.parameters() if p.requires_grad), int))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TivAa2GR_qpo"
      },
      "source": [
        "# Training procedure"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B7YxLK9H_pkY"
      },
      "outputs": [],
      "source": [
        "def train(net,\n",
        "          lossf,\n",
        "          optimizer,\n",
        "          x_train,\n",
        "          y_train,\n",
        "          x_test,\n",
        "          y_test,\n",
        "          num_epochs,\n",
        "          batch_size,\n",
        "          train_loss_avg,\n",
        "          test_loss_avg,\n",
        "          print_epoch_stats=True):\n",
        "\n",
        "  x_train = x_train.reshape([x_train.size(0), 1])\n",
        "  x_test = x_test.reshape([x_test.size(0), 1])\n",
        "  y_train = y_train.reshape([y_train.size(0), 1])\n",
        "  y_test = y_test.reshape([y_test.size(0), 1])\n",
        "\n",
        "  dataset = torch.utils.data.TensorDataset(x_train, y_train)\n",
        "  dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
        "\n",
        "  for epoch in range(num_epochs):\n",
        "    total_train_loss = 0.0\n",
        "\n",
        "    net.train()\n",
        "    for x_batch, y_batch in dataloader:\n",
        "      y_pred_train = net(x_batch)\n",
        "      train_loss = lossf(y_batch, y_pred_train)\n",
        "      optimizer.zero_grad()\n",
        "      train_loss.backward()\n",
        "      optimizer.step()\n",
        "\n",
        "      total_train_loss += train_loss.item() * x_batch.size(0)\n",
        "\n",
        "    train_loss_avg.append(total_train_loss / len(x_train))\n",
        "\n",
        "    with torch.no_grad():\n",
        "        net.eval()\n",
        "        y_pred_test = net(x_test)\n",
        "        test_loss = lossf(y_test, y_pred_test)\n",
        "        test_loss_avg.append(test_loss.item())\n",
        "\n",
        "    if print_epoch_stats and  epoch % 100 == 99:\n",
        "      print(f'Epoch [{epoch+1}/{num_epochs}] train loss: {train_loss_avg[-1]:.4f}, test loss: {test_loss_avg[-1]:.4f}')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LhsOqSOsAEJ0"
      },
      "outputs": [],
      "source": [
        "def plot_loss(train_loss_avg, test_loss_avg):\n",
        "    \"\"\"Plot the per-epoch train and test ('valid') losses on one 12x6 figure.\"\"\"\n",
        "    plt.figure(figsize=(12, 6))\n",
        "    for curve in (train_loss_avg, test_loss_avg):\n",
        "        plt.plot(curve)\n",
        "    plt.legend(['train', 'valid'])\n",
        "    plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "upEaOAaBAHJ2"
      },
      "source": [
        "# Evaluation logic"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iBkCaCR6AGaG"
      },
      "outputs": [],
      "source": [
        "def evaluate(net, input, target, xlim=(-2*np.pi, 2*np.pi), lossf=nn.MSELoss(reduction='mean')):\n",
        "  \"\"\"Plot net's prediction over target and return the loss.\n",
        "\n",
        "  Args:\n",
        "    net: model mapping (N, 1) inputs to (N, 1) outputs; switched to eval mode.\n",
        "    input: sample locations with first dimension N (reshaped to (N, 1)).\n",
        "    target: values to compare against, first dimension N.\n",
        "    xlim: x-axis limits of the plot (a tuple, so the default is immutable).\n",
        "    lossf: loss function, called as lossf(prediction, target) (PyTorch order).\n",
        "\n",
        "  Returns:\n",
        "    The loss tensor, computed under torch.no_grad().\n",
        "  \"\"\"\n",
        "  net.eval()\n",
        "  with torch.no_grad():\n",
        "    xs = input.reshape([input.size()[0], 1])\n",
        "    ys = target.reshape([target.size()[0], 1])\n",
        "\n",
        "    pred = net(xs)\n",
        "    loss = lossf(pred, ys)\n",
        "\n",
        "    plt.plot(xs, ys, label='target')\n",
        "    plt.plot(xs, pred, label='prediction')\n",
        "    plt.legend()\n",
        "\n",
        "    plt.xlim(xlim)\n",
        "    plt.show()\n",
        "\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1DtQhUqcAZih"
      },
      "source": [
        "# Experiment design"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eIUNk9qtA5Y-"
      },
      "source": [
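        "## Experiment #3\n",
        "\n",
        "Fit an MLP with a SIREN activation to random sums of sinusoids for several SIREN frequencies, and average the final training MSE per frequency over `num_runs` random-signal runs."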
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-pHXDYrpdGl7"
      },
      "outputs": [],
      "source": [
        "# save intermediate results\n",
        "def save_data_frame(df, filename, folder_path='/content/gdrive/My Drive/HOSC/experiment_3_SIREN'):\n",
        "  full_path = os.path.join(folder_path, filename)\n",
        "  df.to_csv(full_path, index=False)\n",
        "\n",
        "# read all dataframes from a directoryt\n",
        "def read_data_frames(folder_path='/content/gdrive/My Drive/HOSC/experiment_3_SIREN'):\n",
        "  files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]\n",
        "  dataframes = []\n",
        "  for f in files:\n",
        "    full_path = os.path.join(folder_path, f)\n",
        "    dataframes.append(pd.read_csv(full_path))\n",
        "  return dataframes\n",
        "\n",
        "# count csv files\n",
        "def count_csv_files(folder_path='/content/gdrive/My Drive/HOSC/experiment_3_SIREN'):\n",
        "    files = os.listdir(folder_path)\n",
        "    csv_count = sum(1 for file in files if file.endswith('.csv'))\n",
        "    return csv_count"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9JoRTRNgEUfy"
      },
      "outputs": [],
      "source": [
        "def _fit_and_score(fr, x_train, y_train, x_test, y_test, num_epochs, batch_size, lossf, optimizer_type, lr, wd):\n",
        "  \"\"\"Train a fresh SIREN TestMLP on one signal; return its final training-set loss as a float.\"\"\"\n",
        "  mlp = TestMLP(activation=SIREN(frequency=fr))\n",
        "  optimizer = optimizer_type(params=mlp.parameters(), lr=lr, weight_decay=wd)\n",
        "\n",
        "  train(\n",
        "      net=mlp,\n",
        "      optimizer=optimizer,\n",
        "      lossf=lossf,\n",
        "      x_train=x_train,\n",
        "      x_test=x_test,\n",
        "      y_test=y_test,\n",
        "      y_train=y_train,\n",
        "      num_epochs=num_epochs,\n",
        "      batch_size=batch_size,\n",
        "      train_loss_avg=[],\n",
        "      test_loss_avg=[],\n",
        "      print_epoch_stats=False)\n",
        "\n",
        "  # score on the training grid, i.e. how well the signal was fitted\n",
        "  mlp.eval()\n",
        "  with torch.no_grad():\n",
        "    xs = x_train.reshape([x_train.size()[0], 1])\n",
        "    ys = y_train.reshape([y_train.size()[0], 1])\n",
        "    return float(lossf(mlp(xs), ys))\n",
        "\n",
        "\n",
        "def random_signal_study(\n",
        "    frequency,\n",
        "    x_train,\n",
        "    x_test,\n",
        "    num_epochs,\n",
        "    batch_size,\n",
        "    signal_generator,\n",
        "    num_runs=10,\n",
        "    lossf=nn.MSELoss(reduction='mean'),\n",
        "    optimizer_type=torch.optim.Adam,\n",
        "    lr=None,\n",
        "    wd=None\n",
        "  ):\n",
        "  \"\"\"Fit one SIREN-activated TestMLP per frequency to random signals.\n",
        "\n",
        "  Each run draws ONE random signal and trains a fresh model for every value in\n",
        "  `frequency` on it, so the frequencies within a run are compared on the same\n",
        "  target. The final training-set loss of each model is saved as one CSV per\n",
        "  run (columns: frequency, mse) via save_data_frame. Runs continue until the\n",
        "  results folder holds `num_runs` CSV files, so an interrupted study resumes\n",
        "  where it stopped (any .csv already in that folder counts as a run).\n",
        "\n",
        "  Args:\n",
        "    frequency: iterable of SIREN frequencies to compare.\n",
        "    x_train, x_test: 1-D sample grids for training and per-epoch test loss.\n",
        "    num_epochs, batch_size: passed to train().\n",
        "    signal_generator: zero-argument callable returning a signal function.\n",
        "    num_runs: total number of runs (CSV files) to reach.\n",
        "    lossf: loss function, called as lossf(prediction, target).\n",
        "    optimizer_type: optimizer class, built with lr and weight_decay.\n",
        "    lr, wd: learning rate and weight decay; when None they fall back to the\n",
        "      notebook-level learning_rate / weight_decay globals.\n",
        "\n",
        "  Returns:\n",
        "    None; results are only written to disk.\n",
        "\n",
        "  Raises:\n",
        "    RuntimeError: if saving a run did not increase the CSV count (e.g. the\n",
        "      filename overwrote an existing file), which would otherwise loop forever.\n",
        "  \"\"\"\n",
        "  # Fall back to the notebook config cell so existing callers keep working.\n",
        "  if lr is None:\n",
        "    lr = learning_rate\n",
        "  if wd is None:\n",
        "    wd = weight_decay\n",
        "\n",
        "  frequencies = list(frequency)\n",
        "  num_models_single_run = len(frequencies)\n",
        "\n",
        "  run_idx = count_csv_files()\n",
        "  while run_idx < num_runs:\n",
        "    print(f'Random signal [{run_idx + 1} / {num_runs}]')\n",
        "\n",
        "    # One target signal per run, shared by every frequency in the run.\n",
        "    random_signal = signal_generator()\n",
        "    y_train = random_signal(x_train)\n",
        "    y_test = random_signal(x_test)\n",
        "\n",
        "    records = []\n",
        "    for n, fr in enumerate(frequencies, start=1):\n",
        "      print(f'[{n} / {num_models_single_run}] Training MLP with SIREN FREQUENCY = {fr:.2f}')\n",
        "      mse = _fit_and_score(fr, x_train, y_train, x_test, y_test, num_epochs, batch_size, lossf, optimizer_type, lr, wd)\n",
        "      records.append({'frequency': fr, 'mse': mse})\n",
        "\n",
        "    loss_data = pd.DataFrame(records, columns=['frequency', 'mse'])\n",
        "    save_data_frame(loss_data, filename=f'loss_{run_idx}.csv')\n",
        "\n",
        "    # If the save did not add a CSV (e.g. it overwrote a file after an earlier\n",
        "    # one was deleted), the while-condition would never change.\n",
        "    new_count = count_csv_files()\n",
        "    if new_count <= run_idx:\n",
        "      raise RuntimeError(f'saving loss_{run_idx}.csv did not add a CSV file; stopping instead of looping forever')\n",
        "    run_idx = new_count\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ykjS6mi77E1P"
      },
      "outputs": [],
      "source": [
        "def get_results(plot=False):\n",
        "  dfs = read_data_frames()\n",
        "  aggregate = sum(dfs).div(num_runs)\n",
        "\n",
        "  aggregate = aggregate.set_index(\"frequency\")\n",
        "  # aggregate.index = list(map(str, list(hosc_frequency)))\n",
        "\n",
        "  if plot:\n",
        "    fig = px.imshow(aggregate, color_continuous_scale='Viridis', range_color=[0, 0.5])\n",
        "\n",
        "    fig.update_layout(\n",
        "      # xaxis_title=\"HOSC Amplitude\",\n",
        "      yaxis_title=\"SIREN Frequency\",\n",
        "      width=250,\n",
        "      height=800\n",
        "    )\n",
        "\n",
        "    fig.update_layout(width=800, height=600)\n",
        "\n",
        "    fig.show()\n",
        "\n",
        "  return aggregate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s9u9urMGnI3K"
      },
      "outputs": [],
      "source": [
        "# set parameters\n",
        "# optimizer settings\n",
        "learning_rate = 1e-3\n",
        "weight_decay = 0.0\n",
        "# 1024 >= the 1000 training points, so every epoch is one full batch\n",
        "batch_size = 1024\n",
        "\n",
        "# SIREN frequencies to compare: 1, 6, 11, 16, 21, 26\n",
        "siren_frequency = np.linspace(1, 26, 6)\n",
        "\n",
        "# train and test on the same 1000-point grid over [-1, 1]\n",
        "x_train = torch.linspace(-1, 1, 1000)\n",
        "x_test = torch.linspace(-1, 1, 1000)\n",
        "\n",
        "num_epochs = 1000\n",
        "\n",
        "# random targets: mixtures of 100 sinusoids with frequencies drawn from [0, 100]\n",
        "num_components = 100\n",
        "frequency_range = [0, 100]\n",
        "\n",
        "# number of runs (one CSV file each) to average over\n",
        "num_runs = 100\n",
        "\n",
        "random_gen = get_signal_generator(num_components, frequency_range)\n",
        "\n",
        "\n",
        "def run_experiment_3():\n",
        "  \"\"\"Run, or resume, Experiment #3 with the parameters above.\"\"\"\n",
        "  study_kwargs = dict(\n",
        "      frequency=siren_frequency,\n",
        "      x_train=x_train,\n",
        "      x_test=x_test,\n",
        "      num_epochs=num_epochs,\n",
        "      batch_size=batch_size,\n",
        "      signal_generator=random_gen,\n",
        "      num_runs=num_runs,\n",
        "  )\n",
        "  return random_signal_study(**study_kwargs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "asYqRj7S7C2t"
      },
      "outputs": [],
      "source": [
        "# random_signal_study has no return statement, so loss_dfs is None; the\n",
        "# per-run results are written to CSV files and aggregated by get_results().\n",
        "loss_dfs = run_experiment_3()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 866
        },
        "id": "eMFVgs89F3w7",
        "outputId": "8de9e3fc-2145-4197-9104-66379de4daf5"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<html>\n",
              "<head><meta charset=\"utf-8\" /></head>\n",
              "<body>\n",
              "    <div>            <script src=\"https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_SVG\"></script><script type=\"text/javascript\">if (window.MathJax && window.MathJax.Hub && window.MathJax.Hub.Config) {window.MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}</script>                <script type=\"text/javascript\">window.PlotlyConfig = {MathJaxConfig: 'local'};</script>\n",
              "        <script charset=\"utf-8\" src=\"https://cdn.plot.ly/plotly-2.24.1.min.js\"></script>                <div id=\"4bfe49ff-e24b-4d84-ae01-34a5d13889fc\" class=\"plotly-graph-div\" style=\"height:600px; width:800px;\"></div>            <script type=\"text/javascript\">                                    window.PLOTLYENV=window.PLOTLYENV || {};                                    if (document.getElementById(\"4bfe49ff-e24b-4d84-ae01-34a5d13889fc\")) {                    Plotly.newPlot(                        \"4bfe49ff-e24b-4d84-ae01-34a5d13889fc\",                        [{\"coloraxis\":\"coloraxis\",\"name\":\"0\",\"x\":[\"mse\"],\"y\":[1.0,6.0,11.0,16.0,21.0,26.0],\"z\":[[0.08302580919116735],[0.1130488111078739],[0.110578815639019],[0.20104930724948644],[1.736044105961919],[0.9178806383907795]],\"type\":\"heatmap\",\"xaxis\":\"x\",\"yaxis\":\"y\",\"hovertemplate\":\"x: %{x}\\u003cbr\\u003efrequency: %{y}\\u003cbr\\u003ecolor: %{z}\\u003cextra\\u003e\\u003c\\u002fextra\\u003e\"}],                        
{\"template\":{\"data\":{\"histogram2dcontour\":[{\"type\":\"histogram2dcontour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"choropleth\":[{\"type\":\"choropleth\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"histogram2d\":[{\"type\":\"histogram2d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmap\":[{\"type\":\"heatmap\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"heatmapgl\":[{\"type\":\"heatmapgl\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"contourcarpet\":[{\"type\":\"contourcarpet\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"contour\":[{\"type\":\"contour\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\
"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"surface\":[{\"type\":\"surface\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"},\"colorscale\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]]}],\"mesh3d\":[{\"type\":\"mesh3d\",\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}],\"scatter\":[{\"fillpattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2},\"type\":\"scatter\"}],\"parcoords\":[{\"type\":\"parcoords\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolargl\":[{\"type\":\"scatterpolargl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"bar\":[{\"error_x\":{\"color\":\"#2a3f5f\"},\"error_y\":{\"color\":\"#2a3f5f\"},\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"bar\"}],\"scattergeo\":[{\"type\":\"scattergeo\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatterpolar\":[{\"type\":\"scatterpolar\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"histogram\":[{\"marker\":{\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"histogram\"}],\"scattergl\":[{\"type\":\"scattergl\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scatter3d\":[{\"type\":\"scatter3d\",\"line\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattermapbox\":[{\"type\":\"scattermapbox\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"t
icks\":\"\"}}}],\"scatterternary\":[{\"type\":\"scatterternary\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"scattercarpet\":[{\"type\":\"scattercarpet\",\"marker\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}}}],\"carpet\":[{\"aaxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"baxis\":{\"endlinecolor\":\"#2a3f5f\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"minorgridcolor\":\"white\",\"startlinecolor\":\"#2a3f5f\"},\"type\":\"carpet\"}],\"table\":[{\"cells\":{\"fill\":{\"color\":\"#EBF0F8\"},\"line\":{\"color\":\"white\"}},\"header\":{\"fill\":{\"color\":\"#C8D4E3\"},\"line\":{\"color\":\"white\"}},\"type\":\"table\"}],\"barpolar\":[{\"marker\":{\"line\":{\"color\":\"#E5ECF6\",\"width\":0.5},\"pattern\":{\"fillmode\":\"overlay\",\"size\":10,\"solidity\":0.2}},\"type\":\"barpolar\"}],\"pie\":[{\"automargin\":true,\"type\":\"pie\"}]},\"layout\":{\"autotypenumbers\":\"strict\",\"colorway\":[\"#636efa\",\"#EF553B\",\"#00cc96\",\"#ab63fa\",\"#FFA15A\",\"#19d3f3\",\"#FF6692\",\"#B6E880\",\"#FF97FF\",\"#FECB52\"],\"font\":{\"color\":\"#2a3f5f\"},\"hovermode\":\"closest\",\"hoverlabel\":{\"align\":\"left\"},\"paper_bgcolor\":\"white\",\"plot_bgcolor\":\"#E5ECF6\",\"polar\":{\"bgcolor\":\"#E5ECF6\",\"angularaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"radialaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"ternary\":{\"bgcolor\":\"#E5ECF6\",\"aaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"baxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"},\"caxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\"}},\"coloraxis\":{\"colorbar\":{\"outlinewidth\":0,\"ticks\":\"\"}},\"colorscale\":{\"sequential\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,
\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"sequentialminus\":[[0.0,\"#0d0887\"],[0.1111111111111111,\"#46039f\"],[0.2222222222222222,\"#7201a8\"],[0.3333333333333333,\"#9c179e\"],[0.4444444444444444,\"#bd3786\"],[0.5555555555555556,\"#d8576b\"],[0.6666666666666666,\"#ed7953\"],[0.7777777777777778,\"#fb9f3a\"],[0.8888888888888888,\"#fdca26\"],[1.0,\"#f0f921\"]],\"diverging\":[[0,\"#8e0152\"],[0.1,\"#c51b7d\"],[0.2,\"#de77ae\"],[0.3,\"#f1b6da\"],[0.4,\"#fde0ef\"],[0.5,\"#f7f7f7\"],[0.6,\"#e6f5d0\"],[0.7,\"#b8e186\"],[0.8,\"#7fbc41\"],[0.9,\"#4d9221\"],[1,\"#276419\"]]},\"xaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"yaxis\":{\"gridcolor\":\"white\",\"linecolor\":\"white\",\"ticks\":\"\",\"title\":{\"standoff\":15},\"zerolinecolor\":\"white\",\"automargin\":true,\"zerolinewidth\":2},\"scene\":{\"xaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"yaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2},\"zaxis\":{\"backgroundcolor\":\"#E5ECF6\",\"gridcolor\":\"white\",\"linecolor\":\"white\",\"showbackground\":true,\"ticks\":\"\",\"zerolinecolor\":\"white\",\"gridwidth\":2}},\"shapedefaults\":{\"line\":{\"color\":\"#2a3f5f\"}},\"annotationdefaults\":{\"arrowcolor\":\"#2a3f5f\",\"arrowhead\":0,\"arrowwidth\":1},\"geo\":{\"bgcolor\":\"white\",\"landcolor\":\"#E5ECF6\",\"subunitcolor\":\"white\",\"showland\":true,\"showlakes\":true,\"lakecolor\":\"white\"},\"title\":{\"x\":0.05},\"mapbox\":{\"style\":\"light\"}}},\"xaxis\":{\"anchor\":\"y\",\"domain\":[0.0,1.0],\"scaleanchor\":\"y\",\"constrain
\":\"domain\"},\"yaxis\":{\"anchor\":\"x\",\"domain\":[0.0,1.0],\"autorange\":\"reversed\",\"constrain\":\"domain\",\"title\":{\"text\":\"SIREN Frequency\"}},\"coloraxis\":{\"colorscale\":[[0.0,\"#440154\"],[0.1111111111111111,\"#482878\"],[0.2222222222222222,\"#3e4989\"],[0.3333333333333333,\"#31688e\"],[0.4444444444444444,\"#26828e\"],[0.5555555555555556,\"#1f9e89\"],[0.6666666666666666,\"#35b779\"],[0.7777777777777778,\"#6ece58\"],[0.8888888888888888,\"#b5de2b\"],[1.0,\"#fde725\"]],\"cmin\":0,\"cmax\":0.5},\"margin\":{\"t\":60},\"width\":800,\"height\":600},                        {\"responsive\": true}                    ).then(function(){\n",
              "                            \n",
              "var gd = document.getElementById('4bfe49ff-e24b-4d84-ae01-34a5d13889fc');\n",
              "var x = new MutationObserver(function (mutations, observer) {{\n",
              "        var display = window.getComputedStyle(gd).display;\n",
              "        if (!display || display === 'none') {{\n",
              "            console.log([gd, 'removed!']);\n",
              "            Plotly.purge(gd);\n",
              "            observer.disconnect();\n",
              "        }}\n",
              "}});\n",
              "\n",
              "// Listen for the removal of the full notebook cells\n",
              "var notebookContainer = gd.closest('#notebook-container');\n",
              "if (notebookContainer) {{\n",
              "    x.observe(notebookContainer, {childList: true});\n",
              "}}\n",
              "\n",
              "// Listen for the clearing of the current output cell\n",
              "var outputEl = gd.closest('.output');\n",
              "if (outputEl) {{\n",
              "    x.observe(outputEl, {childList: true});\n",
              "}}\n",
              "\n",
              "                        })                };                            </script>        </div>\n",
              "</body>\n",
              "</html>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "                mse\n",
              "frequency          \n",
              "1.0        0.083026\n",
              "6.0        0.113049\n",
              "11.0       0.110579\n",
              "16.0       0.201049\n",
              "21.0       1.736044\n",
              "26.0       0.917881"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-899b4481-bc96-46a0-b803-e4305c1eb688\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>mse</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>frequency</th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>1.0</th>\n",
              "      <td>0.083026</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6.0</th>\n",
              "      <td>0.113049</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>11.0</th>\n",
              "      <td>0.110579</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16.0</th>\n",
              "      <td>0.201049</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>21.0</th>\n",
              "      <td>1.736044</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>26.0</th>\n",
              "      <td>0.917881</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-899b4481-bc96-46a0-b803-e4305c1eb688')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-899b4481-bc96-46a0-b803-e4305c1eb688 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-899b4481-bc96-46a0-b803-e4305c1eb688');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-b75060ea-acfa-45bf-b06b-99cd20a9f008\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-b75060ea-acfa-45bf-b06b-99cd20a9f008')\"\n",
              "            title=\"Suggest charts.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-b75060ea-acfa-45bf-b06b-99cd20a9f008 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "    </div>\n",
              "  </div>\n"
            ]
          },
          "metadata": {},
          "execution_count": 40
        }
      ],
      "source": [
        "get_results(plot=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "27sXwKEW_B0Q"
      },
      "source": [
        "# Test space: sample and plot a random signal"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eAxl8PEL8Oa6"
      },
      "outputs": [],
      "source": [
        "# Plotting test: build the signal generator that is sampled in the next cell.\n",
        "# NOTE(review): arguments presumably mirror random_signal_generator(num_components, fr_range) -- confirm.\n",
        "random_gen = get_signal_generator(100, [0, 100])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y13d04Mk6ZBE"
      },
      "outputs": [],
      "source": [
        "# Ask random_gen for a signal and evaluate it on 1000 evenly spaced points in [-1, 1].\n",
        "x = np.linspace(-1, 1, 1000)\n",
        "random_signal = random_gen()\n",
        "y = random_signal(x)\n",
        "\n",
        "plot_signal(x, y)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}