{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Compression of Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "import random\n",
    "import torch\n",
    "from scipy.special import erf\n",
    "import matplotlib.pyplot as plt\n",
    "from math import*\n",
    "from scipy.integrate import quad as itg\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import pandas as pd\n",
    "from src.config import RESULT_DIR, FIGURE_DIR\n",
    "from utils.experiments import dump_pickle, load_pickle, concat_dataframes\n",
    "from src.config import RESULT_DIR as empirics_dir\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def group(df):\n",
    "    \"\"\"Group ``df`` by the experiment configuration columns.\n",
    "\n",
    "    Returns the pandas GroupBy object in which each group holds the rows\n",
    "    that share the same value in every one of the columns listed below.\n",
    "    \"\"\"\n",
    "    config_columns = [\n",
    "        'd', 'L', 'r', 'delta', 'sigma', 'omega',\n",
    "        'attention_lmbda', 'linear_lmbda', 'alpha',\n",
    "        'informed', 'optim',\n",
    "    ]\n",
    "    return df.groupby(config_columns)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d     L  r  delta  sigma  omega  attention_lmbda  linear_lmbda  alpha     informed  optim\n",
      "1000  2  1  0.4    0.3    0.3    0.01             0.0001        0.010000  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                0.092917  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                0.175833  False     GD       24\n",
      "                                                                                             ..\n",
      "                   1.0    0.3    0.01             0.0001        1.834167  True      GD       24\n",
      "                                                                1.917083  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                2.000000  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "Length: 1600, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "def _load_phase_diagram_runs(subdir):\n",
    "    \"\"\"Load every ``*.pkl`` in ``empirics_dir / subdir`` and add per-row mean columns.\n",
    "\n",
    "    The loaded frames are merged with ``concat_dataframes``; each of the five\n",
    "    source columns below is then reduced row by row with ``np.mean`` into a\n",
    "    new ``*_mean`` column. Returns the merged frame.\n",
    "    \"\"\"\n",
    "    runs = concat_dataframes([load_pickle(f) for f in (empirics_dir / subdir).glob('*.pkl')])\n",
    "    runs['linear_gen_error_mean'] = runs['linear_gen_error'].apply(np.mean)\n",
    "    runs['attention_gen_error_mean'] = runs['attention_gen_error'].apply(np.mean)\n",
    "    runs['attention_train_error_mean'] = runs['attention_train_error'].apply(np.mean)\n",
    "    runs['attention_mag_mean'] = runs['attention_magnetization'].apply(np.mean)\n",
    "    runs['attention_theta_mean'] = runs['attention_theta'].apply(np.mean)\n",
    "    return runs\n",
    "\n",
    "\n",
    "# Runs from the '_pos_init' directory: every row is relabelled informed=False,\n",
    "# so these rows stand in for the uninformed runs after the concat below.\n",
    "df_pos_init = _load_phase_diagram_runs('01_paper_toy_example_phase_diagram_pos_init')\n",
    "# Item assignment (not attribute assignment) so this always targets the column.\n",
    "df_pos_init['informed'] = False\n",
    "\n",
    "# From the standard directory only the rows flagged as informed are kept.\n",
    "df = _load_phase_diagram_runs('01_paper_toy_example_phase_diagram')\n",
    "df = df[df['informed']]\n",
    "\n",
    "\n",
    "df = pd.concat([df, df_pos_init])\n",
    "df = group(df).head(24) # retain the first 24 rows for each group\n",
    "\n",
    "# assert that every group actually has 24 rows\n",
    "group_sizes = group(df).size()\n",
    "print(group_sizes)\n",
    "assert (group_sizes == 24).all(), 'every configuration must have exactly 24 rows'\n",
    "\n",
    "# df_orig keeps the per-run rows (read by the sample-export cell); df / df_std\n",
    "# hold the per-configuration mean and standard deviation of the numeric columns.\n",
    "df_orig = df.copy()\n",
    "df = group(df_orig).mean(numeric_only=True).reset_index()\n",
    "df_std = group(df_orig).std(numeric_only=True).reset_index()\n",
    "\n",
    "\n",
    "\n",
    "# save both df and df_std\n",
    "df.to_csv('empirics/results_mean_standardA.csv', index=False)\n",
    "df_std.to_csv('empirics/results_std_standardA.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the per-run rows of one single configuration out of df_orig.\n",
    "# The grouping here deliberately omits 'informed' and 'optim', so the selected\n",
    "# rows include every value of those two columns for this configuration.\n",
    "sample_keys = ['d','L','r','delta','sigma','omega','attention_lmbda','linear_lmbda','alpha']\n",
    "sample_config = (1000, 2, 1, 0.4, 0.5, 0.3, 0.01, 0.0001, 2.0)\n",
    "df_samples = df_orig.groupby(sample_keys).get_group(sample_config)\n",
    "\n",
    "df_samples.to_csv('empirics/results_samples_standardA.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d     L  r  delta  sigma  omega  attention_lmbda  linear_lmbda  alpha     informed  optim\n",
      "1000  2  1  0.4    0.5    0.02   0.01             0.0001        0.010000  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                0.092917  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                0.175833  False     GD       24\n",
      "                                                                                             ..\n",
      "                          0.70   0.01             0.0001        1.834167  True      GD       24\n",
      "                                                                1.917083  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "                                                                2.000000  False     GD       24\n",
      "                                                                          True      GD       24\n",
      "Length: 1350, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# Load every pickled run of the '02_paper_toy_example_phase_diagram_otherA' experiment.\n",
    "df = concat_dataframes([load_pickle(f) for f in (empirics_dir / '02_paper_toy_example_phase_diagram_otherA').glob('*.pkl')])\n",
    "# Reduce each source column row by row with np.mean into its '*_mean' column.\n",
    "for source_col, mean_col in [\n",
    "    ('linear_gen_error', 'linear_gen_error_mean'),\n",
    "    ('attention_gen_error', 'attention_gen_error_mean'),\n",
    "    ('attention_train_error', 'attention_train_error_mean'),\n",
    "    ('attention_magnetization', 'attention_mag_mean'),\n",
    "    ('attention_theta', 'attention_theta_mean'),\n",
    "]:\n",
    "    df[mean_col] = df[source_col].apply(np.mean)\n",
    "# NOTE(review): nothing in this cell reads df_pos_init; the rebinding looks like a\n",
    "# copy-paste leftover. It is kept so that any later cell relying on the name still\n",
    "# sees the same object -- confirm and remove if unused.\n",
    "df_pos_init = df\n",
    "\n",
    "df = group(df).head(24) # retain the first 24 rows for each group\n",
    "\n",
    "# assert that every group actually has 24 rows\n",
    "group_sizes = group(df).size()\n",
    "print(group_sizes)\n",
    "assert (group_sizes == 24).all(), 'every configuration must have exactly 24 rows'\n",
    "\n",
    "# df_orig keeps the per-run rows; df / df_std hold the per-configuration mean and\n",
    "# standard deviation of the numeric columns.\n",
    "df_orig = df.copy()\n",
    "df = group(df_orig).mean(numeric_only=True).reset_index()\n",
    "df_std = group(df_orig).std(numeric_only=True).reset_index()\n",
    "\n",
    "# save both df and df_std\n",
    "df.to_csv('empirics/results_mean_otherA.csv', index=False)\n",
    "df_std.to_csv('empirics/results_std_otherA.csv', index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the runs of the large-alpha r=2 experiment with those of the base r=2\n",
    "# experiment (large-alpha files first, matching the original load order).\n",
    "r2_subdirs = ['05_paper_toy_example_r=2_large_alpha', '03_paper_toy_example_r=2']\n",
    "df = concat_dataframes([load_pickle(f) for subdir in r2_subdirs for f in (empirics_dir / subdir).glob('*.pkl')])\n",
    "# Per-row extremes of x[-1][0], i.e. the first entry of the last element of the\n",
    "# magli / thetali records (presumably the final recorded step -- TODO confirm).\n",
    "df['attention_mag_max'] = df.magli.apply(lambda x: x[-1][0].max())\n",
    "df['attention_thet_max'] = df.thetali.apply(lambda x: x[-1][0].max())\n",
    "df['attention_mag_min'] = df.magli.apply(lambda x: x[-1][0].min())\n",
    "df['attention_thet_min'] = df.thetali.apply(lambda x: x[-1][0].min())\n",
    "\n",
    "\n",
    "# Reduce each error column row by row with np.mean into its '*_mean' column.\n",
    "df['linear_gen_error_mean'] = df.linear_gen_error.apply(np.mean)\n",
    "df['attention_gen_error_mean'] = df.attention_gen_error.apply(np.mean)\n",
    "df['attention_train_error_mean'] = df.attention_train_error.apply(np.mean)\n",
    "\n",
    "\n",
    "# NOTE(review): unlike the other aggregation cells, the group sizes are not\n",
    "# asserted here, so configurations with fewer than 10 rows pass through silently.\n",
    "df = group(df).head(10) # retain the first 10 rows for each group\n",
    "\n",
    "# df_orig keeps the per-run rows; df / df_std hold the per-configuration mean and\n",
    "# standard deviation of the numeric columns.\n",
    "df_orig = df.copy()\n",
    "df = group(df_orig).mean(numeric_only=True).reset_index()\n",
    "df_std = group(df_orig).std(numeric_only=True).reset_index()\n",
    "\n",
    "\n",
    "# save both df and df_std\n",
    "df.to_csv('empirics/results_mean_r=2.csv', index=False)\n",
    "df_std.to_csv('empirics/results_std_r=2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100, 2, 1, 0.4, 0.5, 0.3, 0.01, 0.0001, 2.0, False, 'adam')\n"
     ]
    }
   ],
   "source": [
    "# Load every pickled run stored under the relative directory\n",
    "# 'raw/mixed_teacher_softmax_adam_d=100' and merge them into one frame.\n",
    "adam_dir = Path('raw/mixed_teacher_softmax_adam_d=100')\n",
    "df = concat_dataframes([load_pickle(pkl) for pkl in adam_dir.glob('*.pkl')])\n",
    "# Reduce each source column row by row with np.mean into its '*_mean' column.\n",
    "for source_col, mean_col in [\n",
    "    ('linear_gen_error', 'linear_gen_error_mean'),\n",
    "    ('attention_gen_error', 'attention_gen_error_mean'),\n",
    "    ('attention_train_error', 'attention_train_error_mean'),\n",
    "    ('attention_magnetization', 'attention_mag_mean'),\n",
    "    ('attention_theta', 'attention_theta_mean'),\n",
    "]:\n",
    "    df[mean_col] = df[source_col].apply(np.mean)\n",
    "# Keep only the rows that are not flagged as informed.\n",
    "df = df[~df.informed]\n",
    "\n",
    "# List the configurations that remain; the per-run rows are exported unaggregated.\n",
    "for config, _config_rows in group(df):\n",
    "    print(config)\n",
    "    \n",
    "df.to_csv('empirics/results_adam.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                             N_iter  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim           \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         80   \n",
      "                                                             True     GD         80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD         79   \n",
      "                                                             True     GD         79   \n",
      "\n",
      "                                                                             attention_gen_error  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                        \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      80   \n",
      "                                                             True     GD                      80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                      79   \n",
      "                                                             True     GD                      79   \n",
      "\n",
      "                                                                             attention_theta  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                    \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  80   \n",
      "                                                             True     GD                  80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                  79   \n",
      "                                                             True     GD                  79   \n",
      "\n",
      "                                                                             attention_magnetization  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                            \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          80   \n",
      "                                                             True     GD                          80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                          79   \n",
      "                                                             True     GD                          79   \n",
      "\n",
      "                                                                             linear_gen_error  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                     \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   80   \n",
      "                                                             True     GD                   80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                   79   \n",
      "                                                             True     GD                   79   \n",
      "\n",
      "                                                                             attention_train_error  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                          \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        79   \n",
      "                                                             True     GD                        79   \n",
      "\n",
      "                                                                             attention_train_error_curves  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                                 \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               80   \n",
      "                                                             True     GD                               80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                               79   \n",
      "                                                             True     GD                               79   \n",
      "\n",
      "                                                                             informed_position  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                      \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    80   \n",
      "                                                             True     GD                    80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                    79   \n",
      "                                                             True     GD                    79   \n",
      "\n",
      "                                                                             which_A  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim            \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          80   \n",
      "                                                             True     GD          80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD          79   \n",
      "                                                             True     GD          79   \n",
      "\n",
      "                                                                             linear_gen_error_mean  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                          \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        80   \n",
      "                                                             True     GD                        80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                        79   \n",
      "                                                             True     GD                        79   \n",
      "\n",
      "                                                                             attention_gen_error_mean  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                             \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           80   \n",
      "                                                             True     GD                           80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                           79   \n",
      "                                                             True     GD                           79   \n",
      "\n",
      "                                                                             attention_train_error_mean  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                               \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             80   \n",
      "                                                             True     GD                             80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                             79   \n",
      "                                                             True     GD                             79   \n",
      "\n",
      "                                                                             attention_mag_mean  \\\n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                       \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     80   \n",
      "                                                             True     GD                     80   \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                     79   \n",
      "                                                             True     GD                     79   \n",
      "\n",
      "                                                                             attention_theta_mean  \n",
      "d   L r delta sigma omega attention_lmbda linear_lmbda alpha informed optim                        \n",
      "10  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "15  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "23  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "36  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "56  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "87  2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "135 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "209 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "323 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       80  \n",
      "                                                             True     GD                       80  \n",
      "500 2 1 0.4   0.5   0.3   0.01            0.0001       1.5   False    GD                       79  \n",
      "                                                             True     GD                       79  \n"
     ]
    }
   ],
   "source": [
    "df = concat_dataframes([load_pickle(f) for f in ( Path('raw/empirics/06_paper_toy_example_scaling')).glob('*.pkl')]) # replace with 01_paper_toy_example_phase_diagram\n",
    "df['linear_gen_error_mean'] = df.linear_gen_error.apply(np.mean)\n",
    "df['attention_gen_error_mean'] = df.attention_gen_error.apply(np.mean)\n",
    "df['attention_train_error_mean'] = df.attention_train_error.apply(np.mean)\n",
    "df['attention_mag_mean'] = df.attention_magnetization.apply(np.mean)\n",
    "df['attention_theta_mean'] = df.attention_theta.apply(np.mean)\n",
    "print(group(df).count())\n",
    "\n",
    "max_samples = 70\n",
    "df = group(df).head(max_samples)\n",
    "\n",
    "df_orig = df.copy()\n",
    "df = df_orig.groupby(['d','L','r','delta','sigma','omega','attention_lmbda','linear_lmbda','alpha','informed','optim']).mean(numeric_only=True).reset_index()\n",
    "df_std = df_orig.groupby(['d','L','r','delta','sigma','omega','attention_lmbda','linear_lmbda','alpha','informed','optim']).std(numeric_only=True).reset_index() \n",
    "\n",
    "df_orig.to_csv('empirics/results_scaling_orig.csv', index=False)\n",
    "df.to_csv('empirics/results_scaling.csv', index=False)\n",
    "df_std.to_csv('empirics/results_std_scaling.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
