{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import glob \n",
    "from PIL import Image\n",
    "import os\n",
    "from imageio import imread, imsave\n",
    "import shutil\n",
    "from lxml import etree\n",
    "# NOTE: this rebinds the name `glob` from the module imported above to the function glob.glob;\n",
    "# every later cell calls glob(...) as a function.\n",
    "from glob import glob\n",
    "import argparse\n",
    "import json\n",
    "import pandas as pd\n",
    "import scipy as sp\n",
    "import scipy.io.wavfile as wav\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "import collections\n",
    "import random\n",
    "import pickle\n",
    "import tensorflow as tf\n",
    "from tqdm.notebook import tqdm\n",
    "from tensorflow.keras.layers import Layer, Input, InputLayer, ReLU, Flatten, Dense, Conv2D, BatchNormalization, LeakyReLU, Dropout, Activation, MaxPool2D, Concatenate,Lambda,Reshape,Conv2DTranspose\n",
    "from tensorflow.keras.models import Model, Sequential, clone_model, load_model\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import activations, optimizers\n",
    "from tensorflow.keras.losses import mean_squared_error\n",
    "from tensorflow.keras.datasets import mnist, fashion_mnist\n",
    "from tensorflow.keras.callbacks import Callback\n",
    "from tensorflow.python.framework.ops import disable_eager_execution\n",
    "from sklearn.decomposition import FastICA\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed=0):\n",
    "    \"\"\"Seed TensorFlow's, NumPy's and Python's global random generators with `seed`.\"\"\"\n",
    "    # Same call order as before: TensorFlow first, then NumPy, then the stdlib `random` module.\n",
    "    for seed_fn in (tf.random.set_seed, np.random.seed, random.seed):\n",
    "        seed_fn(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Crop every annotated object out of the Pascal VOC 2012 images and save each crop\n",
    "# as its own JPEG under out_dir/<class name>/.\n",
    "in_dir = '../datasets/VOCdevkit/VOC2012/'\n",
    "out_dir = '../datasets/VOC_class_bbox/'\n",
    "\n",
    "filepaths = sorted(glob(os.path.join(in_dir, 'Annotations', '*'), recursive=True))\n",
    "img_dir = os.path.join(in_dir, 'JPEGImages')\n",
    "\n",
    "columns = ['filepath', 'filename', 'height', 'width', 'size',\n",
    "           'cnt', 'cls', 'xmin', 'xmax', 'ymin', 'ymax', 'bbox_size']\n",
    "\n",
    "# One record per annotated object, read from the XML annotation files.\n",
    "data = []\n",
    "for fp in filepaths:\n",
    "    root = etree.parse(fp).getroot()\n",
    "    fn = root.find('filename').text\n",
    "    size = root.find('size')\n",
    "    height = int(size.find('height').text)\n",
    "    width = int(size.find('width').text)\n",
    "\n",
    "    for cnt, obj in enumerate(root.findall('.//object')):\n",
    "        cls = obj.find('name').text\n",
    "        bndbox = obj.find('bndbox')\n",
    "        # Coordinates are parsed through float() first, so values written as '12.0' are accepted too.\n",
    "        xmin, ymin, xmax, ymax = (int(float(bndbox.find(tag).text))\n",
    "                                  for tag in ('xmin', 'ymin', 'xmax', 'ymax'))\n",
    "        data.append([fp, fn, height, width, height*width,\n",
    "                     cnt, cls, xmin, xmax, ymin, ymax, (xmax - xmin) * (ymax - ymin)])\n",
    "\n",
    "# Passing the column names to the constructor keeps the frame well-formed even when\n",
    "# no annotation was found (assigning df.columns afterwards raises on an empty frame).\n",
    "df = pd.DataFrame(data, columns=columns)\n",
    "\n",
    "classes = set(df.cls)\n",
    "for cls in classes:\n",
    "    os.makedirs(os.path.join(out_dir, cls), exist_ok=True)\n",
    "\n",
    "# Group the objects by source image so that each JPEG is decoded once, not once per object.\n",
    "for filename, rows in df.groupby('filename', sort=False):\n",
    "    img = imread(os.path.join(img_dir, filename))\n",
    "    # splitext strips the extension whatever its length (the old [:-4] assumed exactly '.jpg').\n",
    "    stem = os.path.splitext(filename)[0]\n",
    "    for row in rows.itertuples(index=False):\n",
    "        desp = os.path.join(out_dir, row.cls, stem + '_' + str(row.cnt).zfill(3) + '.jpg')\n",
    "        imsave(desp, img[row.ymin:row.ymax, row.xmin:row.xmax])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted list of every entry directly under ../datasets/VOC_class_bbox\n",
    "# (the bbox-extraction cell creates one sub-directory per class there).\n",
    "dirs=sorted(glob(\"../datasets/VOC_class_bbox/*\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class name = last path component of each entry in `dirs`.\n",
    "# os.path.basename copes with whichever path separator glob returns on the current platform,\n",
    "# whereas splitting on \"/\" only works for POSIX-style paths.\n",
    "classes=[os.path.basename(d) for d in dirs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# paths_list[i] = sorted JPEG paths of classes[i].\n",
    "# The files are listed from the very directories in `dirs` (which `classes` is derived from),\n",
    "# so the image lists can never point at a different dataset root than the class names.\n",
    "paths_list=[sorted(glob(os.path.join(d, \"*.jpg\"))) for d in dirs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# size of the training dataset: number of patches sampled for each class\n",
    "dataset_size=1000000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# source_size=image_size\n",
    "# side length, in pixels, of the square patches cropped from the images\n",
    "target_size=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Buffer for the sampled patches, indexed as [class, sample, row, column].\n",
    "# NOTE(review): np.zeros defaults to float64, so with 20 classes, dataset_size=1000000 and\n",
    "# target_size=16 this allocates 20 * 1000000 * 16 * 16 * 8 bytes (about 41 GB) -- confirm that\n",
    "# the machine has that much memory before running.\n",
    "patches=np.zeros((len(classes), dataset_size, target_size, target_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f395acf406f643d58ef704a96c453fc6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fill patches[i] with dataset_size random grayscale crops taken from the images of class i.\n",
    "for i in tqdm(range(len(classes))):\n",
    "    # Re-seeding inside the loop gives every class the same random stream.\n",
    "    set_seed(0)\n",
    "    paths=paths_list[i]\n",
    "    sizes=np.zeros(len(paths))\n",
    "\n",
    "    # Sampling weight of each image = its pixel count; images smaller than one patch get weight 0.\n",
    "    for j in range(len(paths)):\n",
    "        # Image.size comes from the file header, so the pixels are not decoded just to learn the\n",
    "        # dimensions, and the with-block closes the file handle straight away.\n",
    "        with Image.open(paths[j]) as im:\n",
    "            w,h=im.size\n",
    "        if h < target_size or w < target_size:\n",
    "            sizes[j]=0\n",
    "        else:\n",
    "            sizes[j]=h*w\n",
    "    total=np.sum(sizes)\n",
    "    # Without this guard an empty or all-too-small class would divide by zero and hand NaN\n",
    "    # probabilities to np.random.choice.\n",
    "    if total==0:\n",
    "        raise ValueError(f\"no image of at least {target_size}x{target_size} pixels found for class {classes[i]!r}\")\n",
    "    indices=np.sort(np.random.choice(np.arange(len(paths)), size=dataset_size, p=sizes/total))\n",
    "\n",
    "    # The drawn indices are sorted, so each selected image is decoded once and reused for all\n",
    "    # of the patches drawn from it.\n",
    "    index=None\n",
    "    for k in range(indices.size):\n",
    "        if indices[k]!=index:\n",
    "            index=indices[k]\n",
    "            # convert(\"RGB\") guarantees a channel axis, so averaging over the last axis always\n",
    "            # yields a 2-D grayscale image (a single-channel file would otherwise lose its width axis).\n",
    "            with Image.open(paths[index]) as im:\n",
    "                image=np.mean(np.array(im.convert(\"RGB\")), axis=-1)/255.0\n",
    "            source_size=image.shape[:2]\n",
    "        # Top-left corner of the crop, uniform over all positions where the patch fits.\n",
    "        xloc=np.random.randint(0, source_size[1]-target_size+1)\n",
    "        yloc=np.random.randint(0, source_size[0]-target_size+1)\n",
    "        patches[i, k]=image[yloc:yloc+target_size, xloc:xloc+target_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives one compressed patch archive per class; created if missing.\n",
    "dataset_dir=\"../datasets/metanetwork/16x16_patches_from_pascal_voc\"\n",
    "os.makedirs(dataset_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23dac29185154787a9cdff8f04562ddb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/20 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Write one compressed archive per class, storing that class's patches under the key \"data\".\n",
    "for i, class_name in enumerate(tqdm(classes)):\n",
    "    np.savez_compressed(f\"{dataset_dir}/{class_name}\", data=patches[i])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory for the training artifacts (loss histories under hists/, weights under models/).\n",
    "output_dir=\"../output/metanetwork/preprocess_and_fit_pascalvoc_230920\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of units in the hidden Dense layer of the MLP.\n",
    "MLP_z_size=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define model architecture\n",
    "# Template network: flattened patch -> MLP_z_size hidden units -> reconstruction of the patch.\n",
    "# The training cell clones this template for every model instead of training it directly.\n",
    "MLP=Sequential()\n",
    "# input_shape must be a tuple: (target_size**2) without the trailing comma is a plain int.\n",
    "MLP.add(InputLayer(input_shape=(target_size**2,)))\n",
    "# No bias here: the BatchNormalization that follows supplies its own learned offset.\n",
    "MLP.add(Dense(MLP_z_size, use_bias=False))\n",
    "MLP.add(BatchNormalization())\n",
    "MLP.add(Activation(activations.relu))\n",
    "MLP.add(Dense(target_size**2, use_bias=True))\n",
    "# Sigmoid keeps the outputs in [0, 1], the range of the patches (pixel values divided by 255).\n",
    "MLP.add(Activation(activations.sigmoid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of independently seeded models trained for each class.\n",
    "no_models=600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df4b35fc47254281aeb2cbe6ceb43697",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/600 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# train models\n",
    "# For every class, fit no_models copies of the MLP template on that class's patches and store\n",
    "# each model's loss history and weights. Models whose weights file already exists are skipped.\n",
    "for i in range(len(classes)):\n",
    "    os.makedirs(f\"{output_dir}/hists/{classes[i]}\", exist_ok=True)\n",
    "    os.makedirs(f\"{output_dir}/models/{classes[i]}\", exist_ok=True)\n",
    "    # Loaded lazily: when a class is already fully trained its patch archive is never read.\n",
    "    data=None\n",
    "    for j in tqdm(range(no_models)):\n",
    "        weights_path=f\"{output_dir}/models/{classes[i]}/{str(j).zfill(5)}.h5\"\n",
    "        # A saved weights file marks this model as done, so an interrupted run can be resumed.\n",
    "        if os.path.exists(weights_path):\n",
    "            continue\n",
    "        if data is None:\n",
    "            # The with-block closes the .npz archive as soon as the array has been read.\n",
    "            with np.load(f\"{dataset_dir}/{classes[i]}.npz\") as archive:\n",
    "                data=archive[\"data\"].reshape((-1, target_size**2))\n",
    "        # j < no_models < 100000, so every (class, model) pair gets its own reproducible seed.\n",
    "        set_seed(100000*i+j)\n",
    "        mlp=clone_model(MLP)\n",
    "        mlp.compile(loss='mse', optimizer=optimizers.Adam())\n",
    "        # Autoencoder objective: the flattened patch is both the input and the target.\n",
    "        hist=mlp.fit(data, data, epochs=1, batch_size=256, verbose=False)\n",
    "        with open(f\"{output_dir}/hists/{classes[i]}/{str(j).zfill(5)}\", 'wb') as f:\n",
    "            pickle.dump(hist.history, f)\n",
    "        mlp.save_weights(weights_path)\n",
    "        # Release the Keras state accumulated by this model before building the next one.\n",
    "        K.clear_session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** the cells below have not been executed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract weights of encoders\n",
    "# vision_models[i, j] holds the first-layer weights of model j of class i, reshaped to one\n",
    "# target_size x target_size map per hidden unit. The class axis is required: indexing by the\n",
    "# class alone made every model of a class overwrite the previous one.\n",
    "vision_models=np.empty((len(classes), no_models, MLP_z_size, target_size, target_size))\n",
    "\n",
    "# Same architecture as the trained models (get_mlp/params are not defined in this notebook);\n",
    "# the freshly initialised weights are replaced by load_weights below.\n",
    "model=clone_model(MLP)\n",
    "for i in range(len(classes)):\n",
    "    for j in range(no_models):\n",
    "        model.load_weights(f\"{output_dir}/models/{classes[i]}/{str(j).zfill(5)}.h5\")\n",
    "        # NOTE(review): assumes get_layer(index=0) returns the first Dense layer, whose kernel has\n",
    "        # shape (target_size**2, MLP_z_size) -- confirm for the installed Keras version.\n",
    "        # Transposing puts one hidden unit per row before reshaping each row to a 2-D map.\n",
    "        vision_models[i, j]=(model.get_layer(index=0).get_weights()[0].T).reshape(MLP_z_size, target_size, target_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store the extracted weights in a compressed archive under the key \"data\".\n",
    "# NOTE(review): per the NumPy docs, savez_compressed appends \".npz\" when the file name does not\n",
    "# already end with it, so the file written here is \"vision_models.npy.npz\" -- load it by that name.\n",
    "np.savez_compressed(output_dir+\"/vision_models.npy\", data=vision_models)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
