{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import glob \n",
    "from PIL import Image\n",
    "import os\n",
    "from shutil import copy2\n",
    "from glob import glob\n",
    "import argparse\n",
    "import json\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import scipy.io.wavfile as wav\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "import collections\n",
    "import random\n",
    "import pickle\n",
    "import tensorflow as tf\n",
    "from tqdm.notebook import tqdm\n",
    "from tensorflow.keras.layers import Layer, Input, InputLayer, ReLU, Flatten, Dense, Conv2D, BatchNormalization, LeakyReLU, Dropout, Activation, MaxPool2D, Concatenate,Lambda,Reshape,Conv2DTranspose\n",
    "from tensorflow.keras.models import Model, Sequential, clone_model, load_model\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import activations, optimizers\n",
    "from tensorflow.keras.losses import mean_squared_error\n",
    "from tensorflow.keras.datasets import mnist, fashion_mnist\n",
    "from tensorflow.keras.callbacks import Callback\n",
    "from tensorflow.python.framework.ops import disable_eager_execution\n",
    "from sklearn.decomposition import FastICA\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed=0):\n",
    "    '''Seed Python's `random`, NumPy's global RNG and TensorFlow's global RNG with `seed`.'''\n",
    "    for seeder in (random.seed, np.random.seed, tf.random.set_seed):\n",
    "        seeder(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UrbanSound8K metadata table. The first CSV column becomes the index;\n",
    "# later cells look rows up by wav file name with df.at[<file name>, \"class\"].\n",
    "metadata_csv = \"../datasets/UrbanSound8K/metadata/UrbanSound8K.csv\"\n",
    "df = pd.read_csv(metadata_csv, index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted distinct class labels; a label's position in this list is the class id used by later cells.\n",
    "classes = sorted(df[\"class\"].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every .wav clip in every sub-directory of audio/, sorted so the order is deterministic.\n",
    "audio_pattern = \"../datasets/UrbanSound8K/audio/*/*.wav\"\n",
    "all_paths = sorted(glob(audio_pattern))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the wav paths by class: paths_list[i] holds the paths whose metadata label is classes[i],\n",
    "# in all_paths order. One pass over the paths instead of one pass per class, and\n",
    "# os.path.basename instead of split(\"/\") so the lookup key is also correct with Windows separators.\n",
    "class_to_index = {name: idx for idx, name in enumerate(classes)}\n",
    "paths_list = [[] for _ in classes]\n",
    "for path in all_paths:\n",
    "    label = df.at[os.path.basename(path), \"class\"]\n",
    "    paths_list[class_to_index[label]].append(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Preprocess(object):\n",
    "    '''\n",
    "    Parameters\n",
    "    ----------\n",
    "    window_width_msec : float > 0\n",
    "        msecで窓幅を指定\n",
    "    window_slide_msec : float > 0\n",
    "        msecで窓のシフト長を指定\n",
    "        \n",
    "    以下はライブラリlibrosa内で使用するパラメタ\n",
    "    n_mel: int > 0 [scalar]\n",
    "        number of mel spectrograms to return\n",
    "    n_mfcc: int > 0 [scalar]\n",
    "        number of MFCCs to return\n",
    "    coef : 0 <= float <= 1\n",
    "        preemphasis filter with preemph as coefficient.\n",
    "        0 is no filter. Default is 0.97.\n",
    "    db_threshold : number > 0\n",
    "        The threshold (in decibels) below reference \n",
    "        to consider as silence\n",
    "    dct_type : None, or {1, 2, 3}\n",
    "        Discrete cosine transform (DCT) type.\n",
    "        By default, DCT type-2 is used.\n",
    "    norm_mfcc : None or 'ortho'\n",
    "        If `dct_type` is `2 or 3`, setting `norm='ortho'` \n",
    "        uses an ortho-normal DCT basis.\n",
    "        Normalization is not supported for `dct_type=1`.\n",
    "    power : float > 0 [scalar]\n",
    "        Exponent for the magnitude melspectrogram.\n",
    "        e.g., 1 for energy, 2 for power, etc.\n",
    "    norm_mel : {None, 1, np.inf} [scalar]\n",
    "        if 1, divide the triangular mel weights by the\n",
    "        width of the mel band (area normalization).\n",
    "        Otherwise, leave all the triangles aiming for\n",
    "        a peak value of 1.0\n",
    "    fmin      : float >= 0 [scalar]\n",
    "        lowest frequency (in Hz)\n",
    "    fmax      : float >= 0 [scalar]\n",
    "        highest frequency (in Hz).\n",
    "        If `None`, use `fmax = sr / 2.0`\n",
    "    htk       : bool [scalar]\n",
    "        use HTK formula instead of Slaney\n",
    "    '''\n",
    "    def __init__(self, \n",
    "                 window_width_msec=51.2, \n",
    "                 window_slide_msec=25.6,\n",
    "                 n_mels=128,\n",
    "                 n_mfcc=12,\n",
    "                 coef=0.97,\n",
    "                 db_threshold=30,\n",
    "                 dct_type=2,\n",
    "                 norm_mfcc='ortho',\n",
    "                 power=2.0,\n",
    "                 norm_mel=1,\n",
    "                 fmin=0.0,\n",
    "                 fmax=None,\n",
    "                 htk=False):\n",
    "        self.window_width_msec = window_width_msec\n",
    "        self.window_slide_msec = window_slide_msec\n",
    "        self.n_mels = n_mels\n",
    "        self.n_mfcc = n_mfcc\n",
    "        self.coef = coef\n",
    "        self.db_threshold = db_threshold\n",
    "        self.dct_type = dct_type\n",
    "        self.norm_mfcc = norm_mfcc\n",
    "        self.power = power\n",
    "        self.norm_mel = norm_mel\n",
    "        self.fmin = fmin\n",
    "        self.fmax = fmax\n",
    "        self.htk = htk\n",
    "    \n",
    "    def preemphasis(self, y, coef):\n",
    "        '''時間波形に対し、プリエンファシス処理をする。\n",
    "        Parameters\n",
    "        ----------\n",
    "        y : np.ndarray [shape=(?,)]\n",
    "            1次元の信号(時間波形)\n",
    "        coef : float\n",
    "            プリエンファシス係数\n",
    "        Returns\n",
    "        -------\n",
    "        preemph_y : np.ndarray [shape=(?,)]\n",
    "            プリエンファシスされた信号\n",
    "        '''\n",
    "        return np.append(y[0], y[1:] - coef * y[:-1])\n",
    "    \n",
    "    def calc_window(self, sr):\n",
    "        '''サンプリングレートから、指定時間に対応するウインドウ幅、シフトを計算する。\n",
    "        Parameters\n",
    "        ----------\n",
    "        sr : int > 0\n",
    "            サンプリングレート\n",
    "        Returns\n",
    "        -------\n",
    "        n_fft : int > 0\n",
    "            fftウインドウ幅のポイント数\n",
    "        hop_length : int > 0\n",
    "            ウインドウシフトのポイント数\n",
    "        '''\n",
    "        n_fft = int(self.window_width_msec * sr / 1000)\n",
    "        hop_length = int(self.window_slide_msec * sr / 1000)\n",
    "        return n_fft, hop_length\n",
    "\n",
    "    def energy(self, wavfile, trimmed=False):\n",
    "        '''ファイルを読み込み、設定したウインドウ幅にてroot-mean-square energyを計算する。\n",
    "        Parameters\n",
    "        ----------\n",
    "        wavfile : str\n",
    "            信号(時間波形)ファイルのファイルパス\n",
    "        trimmed : bool\n",
    "            True時にyの無音区間除去後にエネルギーを計算\n",
    "        Returns\n",
    "        -------\n",
    "        energy : np.ndarray [shape=(?,)]\n",
    "            エネルギー　\n",
    "        '''\n",
    "        rate, y = wav.read(wavfile)\n",
    "        energy = self.energy_from_signal(y, rate, trimmed=trimmed)\n",
    "\n",
    "        return energy\n",
    "\n",
    "    def energy_from_signal(self, y, rate, trimmed=False):\n",
    "        '''信号・サンプリングレートを読み込み、設定したウインドウ幅にてroot-mean-square energyを計算する。\n",
    "        Parameters\n",
    "        ----------\n",
    "        y : np.ndarray [shape=(?,)]\n",
    "            1次元の信号(時間波形)\n",
    "        rate : int > 0\n",
    "            サンプリングレート\n",
    "        trimmed: bool\n",
    "            True時にyの無音区間除去後にエネルギーを計算\n",
    "        Returns\n",
    "        -------\n",
    "        energy : np.ndarray [shape=(?,)]\n",
    "            エネルギー　\n",
    "        '''\n",
    "        n_fft, hop_length = self.calc_window(rate)\n",
    "        \n",
    "        if trimmed:\n",
    "            # librosaのライブラリを用いて各ファイルの前後の無音区間を除去\n",
    "            y, _ = librosa.effects.trim(y.astype(np.float),\n",
    "                                        top_db=self.db_threshold,\n",
    "                                        frame_length=n_fft,\n",
    "                                        hop_length=hop_length)\n",
    "        \n",
    "        energy = librosa.feature.rms(y=y.astype(np.float),\n",
    "                                     frame_length=n_fft,\n",
    "                                     hop_length=hop_length)\n",
    "        return energy.squeeze()\n",
    "    \n",
    "    def melspectrogram(self, sr, y, preemphasis=False, log=True):\n",
    "        '''信号・サンプリングレートを読み込み、設定値を用いて、メルスケールのスペクトルを算出。\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        sr : int > 0\n",
    "            サンプリングレート\n",
    "        y : np.ndarray [shape=(?,)]\n",
    "            1次元の信号(時間波形)\n",
    "        log : bool\n",
    "            True時に単位をデシベル(ログスケール)とする\n",
    "        Returns\n",
    "        -------\n",
    "        feat : np.ndarray [shape=(self.n_mels, n)]\n",
    "            指定したフィルタ数に対応するスペクトル値\n",
    "        '''\n",
    "        n_fft, hop_length = self.calc_window(sr)\n",
    "        # librosaのライブラリを用いて各ファイルの前後の無音区間を除去\n",
    "        trimmed_y, _ = librosa.effects.trim(y.astype(np.float),\n",
    "                                    top_db=self.db_threshold,\n",
    "                                    frame_length=n_fft,\n",
    "                                    hop_length=hop_length)\n",
    "#         trimmed_y=y # 無音区間は除去しない\n",
    "\n",
    "        # プリエンファシス処理\n",
    "        if preemphasis:\n",
    "            trimmed_y = self.preemphasis(trimmed_y, self.coef)\n",
    "        # メルスペクトルの計算\n",
    "        feat = librosa.feature.melspectrogram(y=trimmed_y,\n",
    "                                              sr=sr,\n",
    "                                              n_fft=n_fft,\n",
    "                                              hop_length=hop_length,\n",
    "                                              power=self.power,\n",
    "                                              n_mels=self.n_mels,\n",
    "                                              fmin=self.fmin,\n",
    "                                              fmax=self.fmax,\n",
    "                                              htk=self.htk,\n",
    "                                              norm=self.norm_mel)\n",
    "        if log:\n",
    "            # デシベルに変換して出力\n",
    "            return librosa.power_to_db(feat)\n",
    "        else:\n",
    "            return feat\n",
    "    \n",
    "    def mfcc(self, sr, y, deltas=[]):\n",
    "        '''信号・サンプリングレートを読み込み、設定値を用いて、メルフィルタケプストラム係数(MFCC)を算出。\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        sr : int > 0\n",
    "            サンプリングレート\n",
    "        y : np.ndarray [shape=(?,)]\n",
    "            1次元の時間波形(信号)\n",
    "        deltas : list\n",
    "            リストにて指定した次数のmfccのdeltaを計算\n",
    "        Returns\n",
    "        -------\n",
    "        feat : np.ndarray [shape=(self.n_mfcc, n)]\n",
    "            指定したフィルタ数に対応するmfcc\n",
    "        '''\n",
    "        S = self.melspectrogram(sr, y, log=True)\n",
    "        # メルスペクトルからmfccを計算\n",
    "        feat = sp.fftpack.dct(S,\n",
    "                              axis=0,\n",
    "                              type=self.dct_type,\n",
    "                              norm=self.norm_mfcc)[:self.n_mfcc]\n",
    "        # deltasが空でない場合に指定した次数のdeltaを計算し、特徴量に結合する\n",
    "        if len(deltas):\n",
    "            feat_zero = feat.copy()\n",
    "            for delta_order in deltas:\n",
    "                feat = np.append(feat,\n",
    "                                 librosa.feature.delta(feat_zero,\n",
    "                                                       order=delta_order),\n",
    "                                 axis=0)\n",
    "        return feat\n",
    "    \n",
    "    def transform(self, wavfile, output_type):\n",
    "        '''信号ファイルを読み込み、設定値を用いて、特徴量を算出。\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        wavfile : str\n",
    "            信号(時間波形)ファイルのファイルパス\n",
    "        output_type : str\n",
    "            出力形式をstrにて['mel', 'mfcc', 'mfcc_delta']から指定\n",
    "            \n",
    "        Returns\n",
    "        -------\n",
    "        feat : np.ndarray [shape=(?, n_feature)] \n",
    "            指定した出力形式の特徴量\n",
    "        '''\n",
    "        rate, y = wav.read(wavfile)\n",
    "        \n",
    "        if output_type == 'mel':\n",
    "            feat = self.melspectrogram(rate, y).T\n",
    "        elif output_type == 'mfcc':\n",
    "            feat = self.mfcc(rate, y).T\n",
    "        elif output_type == 'mfcc_delta':\n",
    "            feat = self.mfcc(rate, y, deltas=[1, 2]).T\n",
    "        return feat\n",
    "    \n",
    "    def transform_from_signal(self, y, rate, output_type, preemphasis=False):\n",
    "        '''信号・サンプリングレートを読み込み、設定値を用いて、特徴量を算出。\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        wavfile : str\n",
    "            信号(時間波形)ファイルのファイルパス\n",
    "        output_type : str\n",
    "            出力形式をstrにて ['mel', 'mfcc', 'mfcc_delta'] の中から指定\n",
    "            \n",
    "        Returns\n",
    "        -------\n",
    "        feat : np.ndarray [shape=(?, n_feature)] \n",
    "            指定した種類の特徴量\n",
    "        '''\n",
    "        if output_type == 'mel':\n",
    "            feat = self.melspectrogram(rate, y, preemphasis=preemphasis)\n",
    "        elif output_type == 'mfcc':\n",
    "            feat = self.mfcc(rate, y).T\n",
    "        elif output_type == 'mfcc_delta':\n",
    "            feat = self.mfcc(rate, y, deltas=[1, 2]).T\n",
    "        return feat\n",
    "    \n",
    "    def inversed_mel(self, spectrum, rate, n_fft=1024):\n",
    "        _, hop_length = self.calc_window(rate)\n",
    "        aud = librosa.feature.inverse.mel_to_audio(librosa.db_to_power(spectrum),\n",
    "                                                   n_fft=n_fft,\n",
    "                                                   hop_length=hop_length)\n",
    "        return aud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "904680e4a8b94799bb979ce3880d20c7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/itrn21_takatsuki/.local/lib/python3.6/site-packages/librosa/util/decorators.py:88: UserWarning: n_fft=2205 is too small for input signal of length=1103\n",
      "  return f(*args, **kwargs)\n",
      "/home/itrn21_takatsuki/.local/lib/python3.6/site-packages/librosa/util/decorators.py:88: UserWarning: n_fft=2205 is too small for input signal of length=1323\n",
      "  return f(*args, **kwargs)\n",
      "/home/itrn21_takatsuki/.local/lib/python3.6/site-packages/librosa/util/decorators.py:88: UserWarning: n_fft=2205 is too small for input signal of length=1523\n",
      "  return f(*args, **kwargs)\n",
      "/home/itrn21_takatsuki/.local/lib/python3.6/site-packages/librosa/util/decorators.py:88: UserWarning: n_fft=2205 is too small for input signal of length=2137\n",
      "  return f(*args, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "# Mel-spectrogram settings: 100 ms windows with a 5 ms hop and 16 mel bands.\n",
    "preprocess_params = {\n",
    "    \"window_width_msec\": 100,\n",
    "    \"window_slide_msec\": 5,\n",
    "    \"n_mels\": 16,\n",
    "    \"n_mfcc\": 12,\n",
    "    \"coef\": 0.97,\n",
    "    \"db_threshold\": 30,\n",
    "    \"dct_type\": 2,\n",
    "    \"norm_mfcc\": \"ortho\",\n",
    "    \"power\": 2.0,\n",
    "    \"norm_mel\": 1,\n",
    "    \"fmin\": 0.0,\n",
    "    \"fmax\": None,\n",
    "    \"htk\": False,\n",
    "}\n",
    "prep = Preprocess(**preprocess_params)\n",
    "\n",
    "\n",
    "def normalize_min_max(feat):\n",
    "    '''Linearly rescale `feat` to [0, 1].\n",
    "\n",
    "    A constant input has max == min; it is mapped to all zeros instead of\n",
    "    producing NaNs from a 0/0 division.\n",
    "    '''\n",
    "    lo, hi = np.min(feat), np.max(feat)\n",
    "    if hi == lo:\n",
    "        return np.zeros_like(feat)\n",
    "    return (feat - lo) / (hi - lo)\n",
    "\n",
    "\n",
    "# feats_list[i] holds the normalized (n_mels, frames) mel spectrograms of every clip of classes[i].\n",
    "feats_list = []\n",
    "for paths in tqdm(paths_list):\n",
    "    feats = []\n",
    "    for path in paths:\n",
    "        waveform, sample_rate = librosa.load(path)\n",
    "        feat = prep.transform_from_signal(waveform, sample_rate, output_type=\"mel\")\n",
    "        feats.append(normalize_min_max(feat))\n",
    "    feats_list.append(feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# size of trainig dataset\n",
    "dataset_size=1000000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# source_size=image_size\n",
    "target_size=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# patches[c, k] is the k-th target_size x target_size patch sampled for classes[c].\n",
    "# float32 roughly halves the memory of this very large array versus the float64 default;\n",
    "# the stored values are min-max normalized features in [0, 1] and the Keras model trains in\n",
    "# float32 (Keras default). NOTE(review): the saved .npz archives become float32 as well.\n",
    "patches = np.zeros((len(classes), dataset_size, target_size, target_size), dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1b19af131af4a929fef751d32e39862",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for i in tqdm(range(len(classes))):\n",
    "    # Re-seed per class so each class's sampling is reproducible on its own.\n",
    "    set_seed(0)\n",
    "    feats = feats_list[i]\n",
    "\n",
    "    # Sampling weight of each spectrogram: its area h*w, or 0 if it cannot hold a full patch.\n",
    "    sizes = np.array([h * w if h >= target_size and w >= target_size else 0\n",
    "                      for h, w in (feat.shape[:2] for feat in feats)],\n",
    "                     dtype=float)\n",
    "    if sizes.sum() == 0:\n",
    "        raise ValueError(f\"class {classes[i]!r}: no spectrogram is at least \"\n",
    "                         f\"{target_size}x{target_size}, so no patch can be sampled\")\n",
    "    indices = np.sort(np.random.choice(np.arange(len(feats)), size=dataset_size, p=sizes / np.sum(sizes)))\n",
    "\n",
    "    # Cut one patch per sampled index at a random time offset. The indices are sorted, so each\n",
    "    # spectrogram is fetched once per run of equal indices. image[:, ...] keeps every row, which\n",
    "    # assumes the spectrogram height equals target_size (n_mels == target_size here).\n",
    "    index = None\n",
    "    for k in range(indices.size):\n",
    "        if indices[k] != index:\n",
    "            index = indices[k]\n",
    "            image = feats[index]\n",
    "            source_size = image.shape[:2]\n",
    "        xloc = np.random.randint(0, source_size[1] - target_size + 1)\n",
    "        patches[i, k] = image[:, xloc:xloc + target_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Destination of the per-class patch archives; created if it does not exist yet.\n",
    "dataset_dir = \"../datasets/metanetwork/16x16_patches_from_urbansound8k\"\n",
    "\n",
    "os.makedirs(dataset_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "232121640f2c4fd2bf8780956f381d31",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One compressed archive per class, <dataset_dir>/<class>.npz, with its patches under the key \"data\".\n",
    "for class_name, class_patches in tqdm(zip(classes, patches), total=len(classes)):\n",
    "    np.savez_compressed(f\"{dataset_dir}/{class_name}\", data=class_patches)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
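    "## Training\n",
    "\n",
    "Fit `no_models` independently seeded one-epoch MLP autoencoders per class on that class's patches, saving each model's weights and training history."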
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir=\"../output/metanetwork/preprocess_and_fit_urbansound8k\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "MLP_z_size=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Template autoencoder: flattened patch -> Dense(MLP_z_size) + BatchNorm + ReLU -> Dense + sigmoid\n",
    "# back to target_size**2 values. Each trained model below is a clone_model() copy of this template.\n",
    "MLP = Sequential([\n",
    "    # input_shape must be a tuple; (target_size**2) without the comma is just an int\n",
    "    InputLayer(input_shape=(target_size ** 2,)),\n",
    "    # no bias: the following BatchNormalization adds its own learned offset\n",
    "    Dense(MLP_z_size, use_bias=False),\n",
    "    BatchNormalization(),\n",
    "    Activation(activations.relu),\n",
    "    Dense(target_size ** 2, use_bias=True),\n",
    "    # sigmoid output matches the [0, 1] range of the min-max normalized patches\n",
    "    Activation(activations.sigmoid),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "no_models=600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f93d3dfe8b149f89666d6b214f3955e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b72fd72c456d4cabba2ad00d5c9e4c4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0dac5f7727c4937a04e5679bfb70bd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5eae5afbe6144a69995fe8c9601b3c52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f72b7df43984264b450a7885f2f5273",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Train no_models independently seeded autoencoders per class; already-saved models are skipped,\n",
    "# so the cell can be re-run to resume an interrupted run.\n",
    "for i, class_name in enumerate(classes):\n",
    "    hist_dir = f\"{output_dir}/hists/{class_name}\"\n",
    "    model_dir = f\"{output_dir}/models/{class_name}\"\n",
    "    os.makedirs(hist_dir, exist_ok=True)\n",
    "    os.makedirs(model_dir, exist_ok=True)\n",
    "    # np.load on an .npz returns an archive that keeps the file open; close it once the array is read.\n",
    "    with np.load(f\"{dataset_dir}/{class_name}.npz\") as archive:\n",
    "        data = archive[\"data\"].reshape((-1, target_size ** 2))\n",
    "    for j in tqdm(range(no_models), desc=class_name):\n",
    "        stem = str(j).zfill(5)\n",
    "        model_path = f\"{model_dir}/{stem}.h5\"\n",
    "        if os.path.exists(model_path):\n",
    "            continue\n",
    "        # deterministic seed, distinct for every (class, model) pair while j < 100000\n",
    "        set_seed(100000 * i + j + 200000000)\n",
    "        mlp = clone_model(MLP)\n",
    "        mlp.compile(loss='mse', optimizer=optimizers.Adam())\n",
    "        # autoencoder: the patches are both input and target\n",
    "        hist = mlp.fit(data, data, epochs=1, batch_size=64, verbose=False)\n",
    "        with open(f\"{hist_dir}/{stem}\", 'wb') as f:\n",
    "            pickle.dump(hist.history, f)\n",
    "        # weights are written last, so a model counts as done only after its history is saved\n",
    "        mlp.save_weights(model_path)\n",
    "        # presumably to keep Keras global state from growing across hundreds of models\n",
    "        K.clear_session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
