{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This notebook trains the baseline method for the structure identification task. It uses a spherical harmonic-based featurization of local particle environments available in [pythia](https://github.com/glotzerlab/pythia)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/paper_authors/env/default_20201216/lib/python3.8/site-packages/tensorflow_addons/utils/ensure_tf_install.py:54: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.2.0 and strictly below 2.4.0 (nightly versions are not supported). \n",
      " The versions of TensorFlow you are currently using is 2.4.1 and is not supported. \n",
      "Some things might work, some things might not.\n",
      "If you were to encounter a bug, do not file an issue.\n",
      "If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. \n",
      "You can find the compatibility matrix in TensorFlow Addon's readme:\n",
      "https://github.com/tensorflow/addons\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "from flowws_keras_experimental.InitializeTF import InitializeTF\n",
    "InitializeTF().run(None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import itertools\n",
    "\n",
    "import flowws\n",
    "from flowws import Argument as Arg\n",
    "import freud\n",
    "import numpy as np\n",
    "import pyriodic\n",
    "import pythia\n",
    "\n",
    "class LocalEnvironmentSphericalHarmonics(flowws.Stage):\n",
    "    ARGS = [\n",
    "        Arg('num_neighbors', '-n', int, 12,\n",
    "           help='Number of nearest neighbors to use'),\n",
    "        Arg('structures', '-s', [str],\n",
    "           help='Name of structures to take'),\n",
    "        Arg('size', None, int, 512,\n",
    "           help='Number of particles to replicate structures up to'),\n",
    "        Arg('noise', None, [float], [1e-2, 5e-2, .1],\n",
    "           help='Noise standard deviation to apply to structures'),\n",
    "        Arg('test_fraction', '-t', float, 0,\n",
    "           help='Fraction of data to hold back as test data'),\n",
    "        Arg('seed', None, int, 13,\n",
    "           help='Random seed to use for shuffling training data'),\n",
    "        Arg('sph_neighbor_max', None, int, 12,\n",
    "           help='Maximum number of neighbors for local environments of spherical harmonics'),\n",
    "        Arg('lmax', '-l', int, 12,\n",
    "           help='Maximum spherical harmonic degree l to consider'),\n",
    "    ]\n",
    "\n",
    "    def run(self, scope, storage):\n",
    "        np.random.seed(self.arguments['seed'])\n",
    "\n",
    "        xs = []\n",
    "        ys = []\n",
    "\n",
    "        name_map = collections.defaultdict(lambda: len(name_map))\n",
    "\n",
    "        structures = list(self.arguments['structures'])\n",
    "        max_types = 0\n",
    "        for name in structures:\n",
    "            for (structure,) in pyriodic.db.query('select structure from unit_cells where name = ?', (name,)):\n",
    "                pass\n",
    "            max_types = max(max_types, len(set(structure.types)))\n",
    "\n",
    "        for name in structures:\n",
    "            for (structure,) in pyriodic.db.query('select structure from unit_cells where name = ?', (name,)):\n",
    "                pass\n",
    "\n",
    "            if name in name_map:\n",
    "                continue\n",
    "\n",
    "            for noise in self.arguments['noise']:\n",
    "                structure = structure.rescale_shortest_distance(1)\n",
    "                structure = structure.replicate_upto(self.arguments['size'])\n",
    "                structure = structure.add_gaussian_noise(noise)\n",
    "\n",
    "                q = freud.locality.AABBQuery(structure.box, structure.positions)\n",
    "                qr = q.query(\n",
    "                    structure.positions, dict(\n",
    "                        num_neighbors=self.arguments['num_neighbors'], exclude_ii=True))\n",
    "                nl = qr.toNeighborList()\n",
    "                index_i = nl.query_point_indices\n",
    "                index_j = nl.point_indices\n",
    "\n",
    "                type_filtered_descriptors = []\n",
    "\n",
    "                for same_t in [True, False]:\n",
    "                    nl_filt = nl.copy()\n",
    "                    if same_t:\n",
    "                        nl_filt.filter(structure.types[index_i] == structure.types[index_j])\n",
    "                    else:\n",
    "                        nl_filt.filter(structure.types[index_i] != structure.types[index_j])\n",
    "\n",
    "                    if len(nl_filt):\n",
    "                        type_filtered_descriptors.append(pythia.spherical_harmonics.abs_neighbor_average(\n",
    "                            structure.box, structure.positions,\n",
    "                            neigh_max=self.arguments['sph_neighbor_max'], lmax=self.arguments['lmax'],\n",
    "                            nlist=nl_filt))\n",
    "                    else:\n",
    "                        type_filtered_descriptors.append(None)\n",
    "\n",
    "                placeholder = [v for v in type_filtered_descriptors if v is not None][0]*0\n",
    "                for i in range(len(type_filtered_descriptors)):\n",
    "                    if type_filtered_descriptors[i] is None:\n",
    "                        type_filtered_descriptors[i] = placeholder\n",
    "                descriptors = np.concatenate(type_filtered_descriptors, axis=-1)\n",
    "\n",
    "                shuf = np.arange(len(descriptors))\n",
    "                np.random.shuffle(shuf)\n",
    "                shuf = shuf[:self.arguments['size']]\n",
    "\n",
    "                xs.append(descriptors[shuf])\n",
    "                ys.append(name_map[name])\n",
    "\n",
    "        ys = np.repeat(ys, [len(v) for v in xs])\n",
    "        xs = np.concatenate(xs, axis=0)\n",
    "\n",
    "        shuf = np.arange(len(xs))\n",
    "        np.random.shuffle(shuf)\n",
    "        N_test = int(self.arguments['test_fraction']*len(shuf))\n",
    "        test_split, train_split = shuf[:N_test], shuf[N_test:]\n",
    "        xs_test = xs[test_split]\n",
    "        ys_test = ys[test_split]\n",
    "        xs = xs[train_split]\n",
    "        ys = ys[train_split]\n",
    "\n",
    "        scope['x_train'] = xs\n",
    "        scope['y_train'] = ys\n",
    "        scope['x_test'] = xs_test\n",
    "        scope['y_test'] = ys_test\n",
    "        scope['num_classes'] = len(name_map)\n",
    "        scope['type_map'] = dict(name_map)\n",
    "        scope['neighborhood_size'] = self.arguments['num_neighbors']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e09a36115b0340628d8c077cf9450b39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=0.0, layout=Layout(flex='2'), max=128.0), HTML(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00017: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.\n",
      "\n",
      "Epoch 00025: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.\n",
      "\n",
      "77/77 [==============================] - 0s 1ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92d63f948a0347f294bc6ae82f7b55b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=0.0, layout=Layout(flex='2'), max=128.0), HTML(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.\n",
      "\n",
      "Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.\n",
      "\n",
      "77/77 [==============================] - 0s 1ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5780d028904e408db76ea6552c55e1c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=0.0, layout=Layout(flex='2'), max=128.0), HTML(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00020: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.\n",
      "\n",
      "Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.\n",
      "\n",
      "77/77 [==============================] - 0s 1ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "492575327a374a7a893b0b4f041f15ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=0.0, layout=Layout(flex='2'), max=128.0), HTML(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.\n",
      "\n",
      "Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.\n",
      "\n",
      "77/77 [==============================] - 0s 1ms/step\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2352 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n",
      "2200 particles have too few neighbors\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee9452b1a4a54fbeafdd739c3c335bcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(HTML(value='Training'), FloatProgress(value=0.0, layout=Layout(flex='2'), max=128.0), HTML(valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 00017: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.\n",
      "\n",
      "Epoch 00025: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.\n",
      "\n",
      "77/77 [==============================] - 0s 1ms/step\n"
     ]
    }
   ],
   "source": [
    "import flowws\n",
    "from flowws_keras_experimental import *\n",
    "\n",
    "accs = []\n",
    "for seed in range(125, 130):\n",
    "    w = flowws.Workflow(\n",
    "        [\n",
    "            LocalEnvironmentSphericalHarmonics(\n",
    "                noise=[1e-3, 5e-2, 0.1],\n",
    "                structures=[\n",
    "                    \"hP2-Mg\",\n",
    "                    \"cI2-W\",\n",
    "                    \"cF4-Cu\",\n",
    "                    \"cF8-C\",\n",
    "                    \"cF8-SZn\",\n",
    "                    \"cP46-Si\",\n",
    "                    \"cF136-Si\",\n",
    "                    \"cP2-ClCs\",\n",
    "                ],\n",
    "                size=2048,\n",
    "                num_neighbors=12,\n",
    "                sph_neighbor_max=12,\n",
    "                test_fraction=0.2,\n",
    "                seed=seed,\n",
    "            ),\n",
    "            MLP(hidden_widths=[64, 64], activation=\"relu\", batch_norm=True, dropout=0.5),\n",
    "            Classifier(),\n",
    "            Train(epochs=128, validation_split=0.25, reduce_lr=8, early_stopping=20),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    scope = w.run()\n",
    "\n",
    "    prediction = scope['model'].predict(scope['x_test'], batch_size=128, verbose=1)\n",
    "    prediction = np.argmax(prediction, axis=-1)\n",
    "\n",
    "    acc = np.mean(prediction == scope['y_test'])\n",
    "\n",
    "    accs.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9614649033570701 0.0010415818844045491\n"
     ]
    }
   ],
   "source": [
    "print(np.mean(accs), np.std(accs)/np.sqrt(len(accs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
