{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from omegaconf import OmegaConf\n",
    "from torchfly.flyconfig import GlobalFlyConfig\n",
    "\n",
    "from model.modules.cross_attention import CrossAttention\n",
    "from model.decoder import DecoderLayer, Decoder\n",
    "from model.encoder_decoder import EncoderDecoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Cross Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CrossAttention(8, 512, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "cross_hidden_states = torch.randn(2, 10, 512)\n",
    "hidden_states = torch.randn(2, 64, 512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0397, grad_fn=<StdBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(hidden_states, cross_hidden_states)[1].std()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Decoder Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DecoderLayer(8, 512, 64, 2048, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "cross_hidden_states = torch.randn(2, 10, 512)\n",
    "hidden_states = torch.randn(2, 64, 512)\n",
    "decoder_cache = {\"past_hidden_states\": None}\n",
    "rel_pos_embedding = torch.randn(64, 64, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model(hidden_states, cross_hidden_states, rel_pos_embedding, decoder_cache, None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = OmegaConf.load(\"config/model/base.yml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "config.model.decoder_cache_len = 0\n",
    "model = Decoder(config.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 2\n",
    "source_len = 10\n",
    "target_len = 63\n",
    "cross_hidden_states = torch.randn(batch_size, source_len, 512)\n",
    "hidden_states = torch.randn(batch_size, target_len, 512)\n",
    "decoder_cache = [{\"past_hidden_states\": None} for i in range(config.model.num_decoder_layers)]\n",
    "encoder_attn_mask = torch.ones(batch_size, source_len).bool()\n",
    "decoder_attn_mask = torch.ones(batch_size, target_len).bool()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model(hidden_states, cross_hidden_states, decoder_cache, decoder_attn_mask, encoder_attn_mask)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test EncoderDecoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = OmegaConf.load(\"config/model/base.yml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = EncoderDecoder(config.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder_input_ids = torch.randint(10, (batch_size, source_len))\n",
    "decoder_input_ids = torch.randint(10, (batch_size, target_len))\n",
    "decoder_caches = [{\"past_hidden_states\": torch.randn(batch_size, config.model.decoder_cache_len, config.model.dim_model)} \n",
    "                  for i in range(config.model.num_decoder_layers)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "76.210688"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(p.numel() for p in model.parameters()) / 1000000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs, decoder_cache = model(encoder_input_ids, decoder_input_ids, decoder_caches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 32, 512])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoder_cache[0][\"past_hidden_states\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
