{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "36b8f9bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import TrainDataset, TestDataset\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from models import SAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6337c1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helper function to calculate accuracy\n",
    "def calculate_accuracy(outputs, targets):\n",
    "    _, predicted = torch.max(outputs, 1)\n",
    "    correct = (predicted == targets).sum().item()\n",
    "    total = targets.size(0)\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "\n",
    "# define train and validation function \n",
    "def train(model, train_dataloader, \n",
    "                IC_middle_first_dataloader,\n",
    "                IC_first_middle_dataloader,\n",
    "                IC_first_last_dataloader,\n",
    "                IC_last_first_dataloader,\n",
    "                IC_middle_last_dataloader,\n",
    "                IC_last_middle_dataloader,\n",
    "                criterion, optimizer, \n",
    "                device, num_iter=150000, rec_freq=1):\n",
    "    \n",
    "    model.train()\n",
    "    for i, (inputs, targets) in enumerate(train_dataloader):\n",
    "        if i >= num_iter:\n",
    "            break\n",
    "\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if i % rec_freq == 0:\n",
    "            IC_middle_first_loss, IC_middle_first_acc = validate(model, IC_middle_first_dataloader, criterion, device)\n",
    "            IC_first_middle_loss, IC_first_middle_acc = validate(model, IC_first_middle_dataloader, criterion, device)\n",
    "            IC_first_last_loss, IC_first_last_acc = validate(model, IC_first_last_dataloader, criterion, device)\n",
    "            IC_last_first_loss, IC_last_first_acc = validate(model, IC_last_first_dataloader, criterion, device)\n",
    "            IC_middle_last_loss, IC_middle_last_acc = validate(model, IC_middle_last_dataloader, criterion, device)\n",
    "            IC_last_middle_loss, IC_last_middle_acc = validate(model, IC_last_middle_dataloader, criterion, device)\n",
    "        \n",
    "            print(f\"Iter {i}:\"\n",
    "                  f\"IC First Middle Acc - Middle First Acc: {IC_first_middle_acc - IC_middle_first_acc:.4f}, \"\n",
    "                  f\"IC First Last Acc - IC Last First Acc: {IC_first_last_acc - IC_last_first_acc:.4f}, \"\n",
    "                  f\"IC Middle Last Acc - IC Last Middle Acc: {IC_middle_last_acc - IC_last_middle_acc:.4f}, \")\n",
    "    \n",
    "\n",
    "def validate(model, dataloader, criterion, device):\n",
    "    model.eval()\n",
    "    running_loss = 0.0\n",
    "    running_accuracy = 0.0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in dataloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, targets)\n",
    "            running_loss += loss.item()\n",
    "            running_accuracy += calculate_accuracy(outputs, targets)\n",
    "    return running_loss / len(dataloader), running_accuracy / len(dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "70411f2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify the training dataset arguments: no position bias\n",
    "N = 8\n",
    "num_classes = 2048\n",
    "dim_features = 64\n",
    "B = 4\n",
    "eps = 0.75\n",
    "pos_bias = False\n",
    "index = [0]\n",
    "test_size = 10000\n",
    "\n",
    "\n",
    "\n",
    "# specify the training arguments\n",
    "bs = 128\n",
    "lr = 1e-3\n",
    "wd = 1e-6 \n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c0632f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the datasets; note that this takes quite a while to run\n",
    "train_dataset = TrainDataset(N=N, K=num_classes, D=dim_features, B=B, eps=eps, pos_bias=pos_bias, index=index)\n",
    "IC_first_middle_dataset = TestDataset(num_seqs=test_size, test_type='IC_first_middle', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)\n",
    "IC_middle_first_dataset = TestDataset(num_seqs=test_size, test_type='IC_middle_first', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)\n",
    "IC_first_last_dataset = TestDataset(num_seqs=test_size, test_type='IC_first_last', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)\n",
    "IC_last_first_dataset = TestDataset(num_seqs=test_size, test_type='IC_last_first', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)\n",
    "IC_middle_last_dataset = TestDataset(num_seqs=test_size, test_type='IC_middle_last', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)\n",
    "IC_last_middle_dataset = TestDataset(num_seqs=test_size, test_type='IC_last_middle', train_dataset=train_dataset, N=N, K=num_classes, D=dim_features, B=B, eps=eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e01c1178",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)\n",
    "IC_first_middle_dataloader = DataLoader(IC_first_middle_dataset, batch_size=bs, shuffle=False)\n",
    "IC_middle_first_dataloader = DataLoader(IC_middle_first_dataset, batch_size=bs, shuffle=False)\n",
    "IC_first_last_dataloader = DataLoader(IC_first_last_dataset, batch_size=bs, shuffle=False)\n",
    "IC_last_first_dataloader = DataLoader(IC_last_first_dataset, batch_size=bs, shuffle=False)\n",
    "IC_middle_last_dataloader = DataLoader(IC_middle_last_dataset, batch_size=bs, shuffle=False)\n",
    "IC_last_middle_dataloader = DataLoader(IC_last_middle_dataset, batch_size=bs, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "073ab728",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Initialize model, criterion, and optimizer\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "criterion = nn.CrossEntropyLoss()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "3f41b9c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "mask_type = \"causal\"    # Options: \"causal\", \"decay\"\n",
    "gamma = 1   # Options: 1 for causal, 0.8 for decay\n",
    "num_attn_layers = 10  # Options: 2 or 10\n",
    "\n",
    "\n",
    "model = SAN(in_channels=dim_features, hidden_channels=dim_features, out_channels=32, mask_type=mask_type, gamma=gamma, num_attn_layers=num_attn_layers).to(device)\n",
    "optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "ead67569",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 0:IC First Middle Acc - Middle First Acc: 0.0025, IC First Last Acc - IC Last First Acc: 0.0031, IC Middle Last Acc - IC Last Middle Acc: 0.0046, \n",
      "Iter 100:IC First Middle Acc - Middle First Acc: 0.0949, IC First Last Acc - IC Last First Acc: 0.0650, IC Middle Last Acc - IC Last Middle Acc: 0.0037, \n",
      "Iter 200:IC First Middle Acc - Middle First Acc: 0.0874, IC First Last Acc - IC Last First Acc: 0.1411, IC Middle Last Acc - IC Last Middle Acc: 0.0160, \n",
      "Iter 300:IC First Middle Acc - Middle First Acc: 0.1311, IC First Last Acc - IC Last First Acc: 0.1279, IC Middle Last Acc - IC Last Middle Acc: 0.0111, \n",
      "Iter 400:IC First Middle Acc - Middle First Acc: 0.1420, IC First Last Acc - IC Last First Acc: 0.1346, IC Middle Last Acc - IC Last Middle Acc: 0.0148, \n",
      "Iter 500:IC First Middle Acc - Middle First Acc: 0.1224, IC First Last Acc - IC Last First Acc: 0.1071, IC Middle Last Acc - IC Last Middle Acc: 0.0122, \n",
      "Iter 600:IC First Middle Acc - Middle First Acc: 0.1361, IC First Last Acc - IC Last First Acc: 0.1079, IC Middle Last Acc - IC Last Middle Acc: 0.0144, \n",
      "Iter 700:IC First Middle Acc - Middle First Acc: 0.1483, IC First Last Acc - IC Last First Acc: 0.1112, IC Middle Last Acc - IC Last Middle Acc: 0.0218, \n",
      "Iter 800:IC First Middle Acc - Middle First Acc: 0.1496, IC First Last Acc - IC Last First Acc: 0.1419, IC Middle Last Acc - IC Last Middle Acc: 0.0210, \n",
      "Iter 900:IC First Middle Acc - Middle First Acc: 0.1572, IC First Last Acc - IC Last First Acc: 0.1509, IC Middle Last Acc - IC Last Middle Acc: 0.0302, \n"
     ]
    }
   ],
   "source": [
    "# train! \n",
    "\n",
    "train(model, \n",
    "    train_dataloader, \n",
    "    IC_middle_first_dataloader, \n",
    "    IC_first_middle_dataloader,\n",
    "    IC_first_last_dataloader,\n",
    "    IC_last_first_dataloader,\n",
    "    IC_middle_last_dataloader,\n",
    "    IC_last_middle_dataloader,\n",
    "    criterion, \n",
    "    optimizer, \n",
    "    device, \n",
    "    num_iter=1000, rec_freq=100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
