{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io as sio\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# The data directory is assembled with os.path so the notebook also runs on\n",
    "# POSIX systems; a hard-coded backslash path only resolves on Windows.\n",
    "datapath = os.path.join(os.pardir, 'data')\n",
    "dataset = 'INSECT'\n",
    "\n",
    "# res101.mat provides the aligned nucleotide barcodes and the species labels\n",
    "x = sio.loadmat(os.path.join(datapath, dataset, 'res101.mat'))\n",
    "# att_splits.mat provides the index lists of the data splits\n",
    "x2 = sio.loadmat(os.path.join(datapath, dataset, 'att_splits.mat'))\n",
    "\n",
    "barcodes = x['nucleotides_aligned'].T\n",
    "species = x['labels']\n",
    "train_loc = x2['trainval_loc']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training samples and entire data\n",
    "N = barcodes.shape[0]  # one row per sample\n",
    "print(len(train_loc[0]), N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reading barcodes and labels into python lists.\n",
    "# Every record is kept, including an empty barcode, so that position i of\n",
    "# bcodes/labels always refers to sample i of the .mat files. The cells below\n",
    "# index these lists with range(N) and with the 1-based split indices, so\n",
    "# dropping a record here would shift every later sample.\n",
    "bcodes = [barcodes[i][0][0] for i in range(N)]\n",
    "labels = [species[i][0] for i in range(N)]\n",
    "\n",
    "# An empty barcode simply produces an all-zero one-hot matrix later on\n",
    "n_empty = sum(1 for code in bcodes if len(code) == 0)\n",
    "if n_empty:\n",
    "    print('Warning: %d empty barcodes found' % n_empty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the first barcode string\n",
    "first_barcode = bcodes[0]\n",
    "print('Sample barcode: %s' % first_barcode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.preprocessing.text import Tokenizer\n",
    "\n",
    "# Converting base pairs, ACTG, into numbers.\n",
    "# char_level=True treats every character of a barcode as its own token.\n",
    "tokenizer = Tokenizer(char_level=True)\n",
    "tokenizer.fit_on_texts(bcodes)\n",
    "\n",
    "# One list of integer token ids per barcode, in the same order as bcodes\n",
    "sequence_of_int = tokenizer.texts_to_sequences(bcodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For INSECT dataset, majority of barcodes have a length of 658\n",
    "# ll holds the token count of every sample as an (N, 1) float column\n",
    "ll = np.array([len(sequence_of_int[i]) for i in range(N)], dtype=float).reshape(N, 1)\n",
    "# cnt is the number of samples whose length is exactly 658\n",
    "cnt = int(np.sum(ll == 658))\n",
    "print(cnt, np.median(ll))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Calculating sequence count for each species\n",
    "cl = len(np.unique(species))\n",
    "seq_len = np.zeros((cl))  # allocated here but not filled in by this cell\n",
    "seq_cnt = np.zeros((cl))\n",
    "for i in range(N):\n",
    "    # Accounting for Matlab labeling: labels start at 1, array indices at 0\n",
    "    seq_cnt[labels[i] - 1] += 1\n",
    "\n",
    "print('# classes: %d' % cl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Converting all the data into one-hot encoding. Note that this data is not used during training.\n",
    "# We are getting it ready for the prediction time to get the final DNA embeddings after training is done\n",
    "sl = 658  # fixed sequence length: longer barcodes are cut off, shorter ones keep zero rows at the end\n",
    "allX = np.zeros((N, sl, 5))\n",
    "for i in range(N):\n",
    "    # Slicing to sl replaces the per-position length check of the sequence\n",
    "    for j, token in enumerate(sequence_of_int[i][:sl]):\n",
    "        # channel = token id - 1; every channel index above 4 is folded into channel 4\n",
    "        allX[i, j, min(token - 1, 4)] = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the training matrix and labels\n",
    "# (allocated with N rows, the upper bound on the number of training samples)\n",
    "trainX = np.zeros(shape=(N, sl, 5))  # one-hot encoded barcodes\n",
    "trainY = np.zeros(shape=(N, cl))     # one-hot encoded class targets\n",
    "labelY = np.zeros(shape=N)           # class index of each sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here is where we cap the class sample size to 50 for a balanced training set\n",
    "\n",
    "# Set of training-split sample ids: constant-time lookup instead of scanning\n",
    "# the whole index array once per sample.\n",
    "train_ids = set(train_loc[0].tolist())\n",
    "\n",
    "# Nc is the number of samples selected so far, which is also the next free row.\n",
    "# Counting from 0 means trainX[0:Nc] keeps every selected sample; starting at -1\n",
    "# left Nc at the last written index, so slicing with it dropped the final sample.\n",
    "Nc = 0\n",
    "clas_cnt = np.zeros((cl))\n",
    "for i in range(N):\n",
    "    cls = labels[i] - 1  # 0-based class index (labels start at 1)\n",
    "    # NOTE(review): this counter advances for every sample of the class, not only\n",
    "    # for training-split samples, so a class can end up with fewer than 50 rows.\n",
    "    clas_cnt[cls] += 1\n",
    "    # Note that samples from training set are only used; split ids are 1-based, hence i + 1\n",
    "    if seq_cnt[cls] >= 10 and clas_cnt[cls] <= 50 and (i + 1) in train_ids:\n",
    "        for j, token in enumerate(sequence_of_int[i][:sl]):\n",
    "            # channel = token id - 1; every channel index above 4 is folded into channel 4\n",
    "            trainX[Nc][j][min(token - 1, 4)] = 1.0\n",
    "        trainY[Nc][cls] = 1.0\n",
    "        labelY[Nc] = cls\n",
    "        Nc += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nc comes from the selection loop in the previous cell\n",
    "print('Final training size: %d' % (Nc,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# selecting the balanced training set\n",
    "# (keeps the first Nc rows of the pre-allocated arrays)\n",
    "trainX, trainY, labelY = trainX[:Nc], trainY[:Nc], labelY[:Nc]\n",
    "print(trainX.shape, trainY.shape, labelY.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To make sure the training data does not include any unseen class nucleotides\n",
    "# (an empty array as output means there is no overlap)\n",
    "unseen_rows = x2['test_unseen_loc'][0] - 1  # shift the split indices by one to index species\n",
    "label_us = species[unseen_rows]\n",
    "np.intersect1d(np.unique(labelY), np.unique(label_us) - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cleaning empty classes\n",
    "# A class is empty when its target column contains only zeros\n",
    "is_empty_class = np.all(trainY == 0, axis=0)\n",
    "idx = np.argwhere(is_empty_class)\n",
    "trainY = trainY[:, ~is_empty_class]\n",
    "\n",
    "# number of removed classes, smallest remaining class size\n",
    "print(len(idx), min(np.sum(trainY, axis=0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the target matrix and of the class-index vector\n",
    "for arr in (trainY, labelY):\n",
    "    print(arr.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expanding the training set shape for CNN\n",
    "# (inserts a new axis of length 1 at position 3)\n",
    "trainX = trainX[:, :, :, np.newaxis]\n",
    "allX = allX[:, :, :, np.newaxis]\n",
    "print(trainX.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Splitting the training set into train and validation\n",
    "# (20% is held out; random_state fixes the shuffle)\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    trainX,\n",
    "    trainY,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the training targets: (samples, classes)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classic CNN Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import datasets, layers, models, optimizers, callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN model architecture:\n",
    "# three Conv2D -> BatchNormalization -> MaxPooling2D stages (64, 32, 16 filters),\n",
    "# then Flatten -> BatchNormalization -> Dense(500, tanh) -> Dense(number of classes).\n",
    "# The final Dense layer has no activation.\n",
    "model = models.Sequential([\n",
    "    layers.Conv2D(64, (3, 3), activation='relu', input_shape=(sl, 5, 1), padding='SAME'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.MaxPooling2D((3, 1)),\n",
    "    layers.Conv2D(32, (3, 3), activation='relu', padding='SAME'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.MaxPooling2D((3, 1)),\n",
    "    layers.Conv2D(16, (3, 3), activation='relu', padding='SAME'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.MaxPooling2D((3, 1)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(500, activation='tanh'),\n",
    "    layers.Dense(y_train.shape[1]),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the layer-by-layer overview of the model\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step-decay learning rate scheduler\n",
    "def step_decay(epoch):\n",
    "    \"\"\"Return the learning rate for the given epoch index.\n",
    "\n",
    "    The rate starts at 0.001 and is halved once for every 2 epochs, counted\n",
    "    from epoch + 1: epoch 0 -> 0.001, epochs 1-2 -> 0.0005, epochs 3-4 -> 0.00025.\n",
    "    \"\"\"\n",
    "    initial_lrate = 0.001\n",
    "    drop = 0.5\n",
    "    epochs_drop = 2.0\n",
    "    n_drops = np.floor((1 + epoch) / epochs_drop)\n",
    "    return initial_lrate * np.power(drop, n_drops)\n",
    "\n",
    "\n",
    "class LossHistory(callbacks.Callback):\n",
    "    \"\"\"Keeps the training loss and the scheduled learning rate of every epoch.\"\"\"\n",
    "\n",
    "    def on_train_begin(self, logs=None):\n",
    "        # Start with empty histories on every call to fit\n",
    "        self.losses = []\n",
    "        self.lr = []\n",
    "\n",
    "    def on_epoch_end(self, batch, logs=None):\n",
    "        # Keras passes the index of the epoch that just finished as the first\n",
    "        # positional argument; the parameter keeps its original name `batch`.\n",
    "        logs = logs or {}\n",
    "        self.losses.append(logs.get('loss'))\n",
    "        # Record the rate of this epoch. Using len(self.losses) here would give\n",
    "        # the index of the following epoch and shift the history by one.\n",
    "        self.lr.append(step_decay(batch))\n",
    "\n",
    "\n",
    "loss_history = LossHistory()\n",
    "lrate = callbacks.LearningRateScheduler(step_decay)\n",
    "callbacks_list = [loss_history, lrate]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.metrics import top_k_categorical_accuracy\n",
    "#opt = tf.keras.optimizers.SGD(momentum=0.9, nesterov=True)\n",
    "opt = tf.keras.optimizers.Adam()\n",
    "# from_logits=True: the loss receives raw scores and not probabilities\n",
    "loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
    "model.compile(\n",
    "    optimizer=opt,\n",
    "    loss=loss_fn,\n",
    "    metrics=['accuracy', 'top_k_categorical_accuracy'],\n",
    ")\n",
    "\n",
    "# Validation time\n",
    "history = model.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    epochs=5,\n",
    "    batch_size=32,\n",
    "    validation_data=(X_test, y_test),\n",
    "    callbacks=callbacks_list,\n",
    "    verbose=1,\n",
    ")\n",
    "\n",
    "# Final Test time\n",
    "#history = model.fit(trainX, trainY, epochs=5, callbacks=callbacks_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning-rate values collected by the LossHistory callback during fit\n",
    "print('Learning rates through epochs:', loss_history.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Model\n",
    "\n",
    "# Getting the DNA embeddings of all data from the 500-unit dense layer,\n",
    "# i.e. the layer right before the output layer.\n",
    "# The layer is picked by position instead of the auto-generated name 'dense':\n",
    "# Keras numbers auto-generated names (dense_2, dense_3, ...) when the model cell\n",
    "# is run again in the same kernel, and a lookup by name would then fail.\n",
    "embedding_layer = model.layers[-2]\n",
    "model2 = Model(inputs=model.input, outputs=embedding_layer.output)\n",
    "dna_embeddings = model2.predict(allX)\n",
    "print(dna_embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving the results into csv file to be later read into matlab\n",
    "# (one row per sample, values separated by commas)\n",
    "out_file = 'nips_cnn_embeddings_500_5e_adam_aligned.csv'\n",
    "np.savetxt(out_file, dna_embeddings, delimiter=',')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
