{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "cdff9570",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d69a7b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the pre-computed EPIC-Kitchens pickles; edit this one line to relocate the data.\n",
    "DATA_DIR = Path('/vision/u/eatang/leaky_video/datasets/epic_kitchens')\n",
    "\n",
    "\n",
    "def load_pickle(path):\n",
    "    \"\"\"Return the object stored in the pickle file at `path`.\n",
    "\n",
    "    NOTE(security): unpickling can execute arbitrary code, so only point this\n",
    "    at files that come from a trusted source.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "# Both objects are consumed below via `.items()`, with each value indexed as a\n",
    "# sequence of action labels (presumably keyed by narration id, judging by the\n",
    "# file names -- confirm against the script that produced them).\n",
    "narration_id_to_prev_actions = load_pickle(DATA_DIR / 'narration_id_to_prev_actions.pkl')\n",
    "narration_id_to_prev_actions_val = load_pickle(DATA_DIR / 'narration_id_to_prev_actions_val.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "2a8e843c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyperparameters\n",
    "\n",
    "# Number of preceding actions encoded as model input for each sample\n",
    "# (the middle axis of x_train / x_val built in the next cell).\n",
    "num_prev_actions = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "9a92cb67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prev_action_dataset(prev_actions_by_id, num_prev, num_classes=3806):\n",
    "    \"\"\"Turn per-key action sequences into one-hot inputs and class-id targets.\n",
    "\n",
    "    For every value in `prev_actions_by_id`, the last entry is the target and the\n",
    "    `num_prev` entries immediately before it are the inputs, kept in their\n",
    "    original order.\n",
    "\n",
    "    Args:\n",
    "        prev_actions_by_id: mapping whose values are indexable sequences of action\n",
    "            class ids, each holding at least `num_prev + 1` entries.\n",
    "        num_prev: number of preceding actions to encode per sample.\n",
    "        num_classes: width of the one-hot encoding (3806 throughout this notebook).\n",
    "\n",
    "    Returns:\n",
    "        (x, y): `x` is a float array of shape (N, num_prev, num_classes) and `y` is\n",
    "        a float array of shape (N,) holding each sample's target class id.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an input label lies outside [0, num_classes).\n",
    "        IndexError: from the indexing itself when a plain sequence is too short.\n",
    "    \"\"\"\n",
    "    num_samples = len(prev_actions_by_id)\n",
    "    x = np.zeros((num_samples, num_prev, num_classes))\n",
    "    y = np.zeros(num_samples)\n",
    "    for row, (key, actions) in enumerate(prev_actions_by_id.items()):\n",
    "        for slot in range(num_prev):\n",
    "            # Slot 0 is the action num_prev + 1 positions from the end; the last\n",
    "            # slot is the action right before the target (index -2).\n",
    "            label = int(actions[slot - num_prev - 1])\n",
    "            if not 0 <= label < num_classes:\n",
    "                raise ValueError(f'label {label} for {key!r} is outside [0, {num_classes})')\n",
    "            # Set the single hot entry directly instead of building a torch\n",
    "            # one-hot tensor and converting it back to numpy for every element.\n",
    "            x[row, slot, label] = 1.0\n",
    "        y[row] = actions[-1]\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x_train, y_train = build_prev_action_dataset(narration_id_to_prev_actions, num_prev_actions)\n",
    "x_val, y_val = build_prev_action_dataset(narration_id_to_prev_actions_val, num_prev_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "da6202ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# more hyperparameters\n",
    "embedding_dim = 512  # width of each action embedding; only used when use_embeddings is True\n",
    "use_embeddings = False  # True: embed each previous action id; False: feed the summed one-hot vectors\n",
    "num_epochs = 20  # number of full passes over the training split\n",
    "batch_size = 256  # samples per optimizer step\n",
    "lr = 0.001  # AdamW learning rate\n",
    "hidden_dim = 512  # width of both hidden layers of the MLP\n",
    "weight_decay = 0  # AdamW weight decay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "d701069f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Train Loss: 6.2414, Test Loss: 6.3295, Top-1 Accuracy: 0.0263\n",
      "Epoch [2/20], Train Loss: 6.0214, Test Loss: 6.1570, Top-1 Accuracy: 0.0609\n",
      "Epoch [3/20], Train Loss: 5.6416, Test Loss: 5.8927, Top-1 Accuracy: 0.0896\n",
      "Epoch [4/20], Train Loss: 5.3211, Test Loss: 5.6954, Top-1 Accuracy: 0.1065\n",
      "Epoch [5/20], Train Loss: 5.0266, Test Loss: 5.5366, Top-1 Accuracy: 0.1209\n",
      "Epoch [6/20], Train Loss: 4.8795, Test Loss: 5.4036, Top-1 Accuracy: 0.1310\n",
      "Epoch [7/20], Train Loss: 4.6885, Test Loss: 5.2988, Top-1 Accuracy: 0.1391\n",
      "Epoch [8/20], Train Loss: 4.4868, Test Loss: 5.2133, Top-1 Accuracy: 0.1436\n",
      "Epoch [9/20], Train Loss: 4.3790, Test Loss: 5.1434, Top-1 Accuracy: 0.1496\n",
      "Epoch [10/20], Train Loss: 4.1840, Test Loss: 5.0942, Top-1 Accuracy: 0.1510\n",
      "Epoch [11/20], Train Loss: 4.1271, Test Loss: 5.0635, Top-1 Accuracy: 0.1557\n",
      "Epoch [12/20], Train Loss: 4.0599, Test Loss: 5.0459, Top-1 Accuracy: 0.1589\n",
      "Epoch [13/20], Train Loss: 3.9583, Test Loss: 5.0358, Top-1 Accuracy: 0.1602\n",
      "Epoch [14/20], Train Loss: 3.9302, Test Loss: 5.0227, Top-1 Accuracy: 0.1612\n",
      "Epoch [15/20], Train Loss: 3.8825, Test Loss: 5.0161, Top-1 Accuracy: 0.1615\n",
      "Epoch [16/20], Train Loss: 3.8073, Test Loss: 5.0078, Top-1 Accuracy: 0.1645\n",
      "Epoch [17/20], Train Loss: 3.6319, Test Loss: 5.0039, Top-1 Accuracy: 0.1644\n",
      "Epoch [18/20], Train Loss: 3.6759, Test Loss: 5.0012, Top-1 Accuracy: 0.1653\n",
      "Epoch [19/20], Train Loss: 3.5381, Test Loss: 5.0014, Top-1 Accuracy: 0.1692\n",
      "Epoch [20/20], Train Loss: 3.5174, Test Loss: 5.0045, Top-1 Accuracy: 0.1707\n"
     ]
    }
   ],
   "source": [
    "# Train an MLP that predicts the next action class from the previous actions.\n",
    "#\n",
    "# NOTE: the metrics labelled 'Test' come from a 20% hold-out of the training data;\n",
    "# x_val / y_val built earlier are not used in this cell.\n",
    "\n",
    "# Seed torch so weight initialisation and dropout are repeatable across runs.\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Width of the one-hot axis == number of action classes (3806 for this data).\n",
    "num_classes = x_train.shape[-1]\n",
    "\n",
    "# Split row indices rather than the arrays themselves. The partition is the same\n",
    "# one train_test_split(x_train, y_train, ...) would produce for this random_state,\n",
    "# but x_train / y_train are never overwritten (so re-running the cell does not keep\n",
    "# shrinking the training set) and the large 3-D one-hot array is not copied.\n",
    "train_idx, test_idx = train_test_split(np.arange(len(x_train)), test_size=0.2, random_state=seed)\n",
    "\n",
    "if use_embeddings:\n",
    "    # Embedding lookups need integer class ids: (N, num_prev_actions).\n",
    "    features = x_train.argmax(-1)\n",
    "    x_train_torch = torch.LongTensor(features[train_idx])\n",
    "    x_test_torch = torch.LongTensor(features[test_idx])\n",
    "else:\n",
    "    # Sum the one-hot vectors over the previous-action axis: (N, num_classes) counts.\n",
    "    features = x_train.sum(1)\n",
    "    x_train_torch = torch.FloatTensor(features[train_idx])\n",
    "    x_test_torch = torch.FloatTensor(features[test_idx])\n",
    "\n",
    "# CrossEntropyLoss expects integer class indices as targets.\n",
    "y_train_torch = torch.LongTensor(y_train[train_idx])\n",
    "y_test_torch = torch.LongTensor(y_train[test_idx])\n",
    "\n",
    "# Fall back to the CPU when no GPU is visible instead of failing on .to('cuda:0').\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# The two variants differ only in how the input reaches the first Linear layer.\n",
    "if use_embeddings:\n",
    "    input_layers = [nn.Embedding(num_embeddings=num_classes, embedding_dim=embedding_dim), nn.Flatten()]\n",
    "    in_features = embedding_dim * num_prev_actions\n",
    "else:\n",
    "    input_layers = []\n",
    "    in_features = num_classes\n",
    "\n",
    "mlp = nn.Sequential(\n",
    "    *input_layers,\n",
    "    nn.Linear(in_features, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(hidden_dim, hidden_dim),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5),\n",
    "    nn.Linear(hidden_dim, num_classes),\n",
    ")\n",
    "mlp.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.AdamW(mlp.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "\n",
    "k = 1  # report top-k accuracy; raise k for a more lenient metric\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Training phase\n",
    "    mlp.train()\n",
    "    running_loss = 0.0\n",
    "    for start in range(0, len(x_train_torch), batch_size):\n",
    "        x_batch = x_train_torch[start:start + batch_size].to(device)\n",
    "        y_batch = y_train_torch[start:start + batch_size].to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(mlp(x_batch), y_batch)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Weight by batch size so a short final batch does not skew the epoch mean.\n",
    "        running_loss += loss.item() * len(x_batch)\n",
    "    # Mean loss over the whole epoch, not just the last mini-batch.\n",
    "    train_loss = running_loss / max(len(x_train_torch), 1)\n",
    "\n",
    "    # Evaluation on the held-out split (dropout disabled, no gradients)\n",
    "    mlp.eval()\n",
    "    with torch.no_grad():\n",
    "        test_outputs = mlp(x_test_torch.to(device))\n",
    "        test_loss = criterion(test_outputs, y_test_torch.to(device))\n",
    "    test_outputs = test_outputs.cpu()\n",
    "\n",
    "    # A sample counts as correct when its true class is among the k highest logits.\n",
    "    top_k_indices = test_outputs.topk(k, dim=1).indices\n",
    "    correct_predictions = top_k_indices.eq(y_test_torch.view(-1, 1).expand_as(top_k_indices))\n",
    "    top_k_accuracy = correct_predictions.sum().item() / len(y_test_torch)\n",
    "\n",
    "    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss.item():.4f}, Top-{k} Accuracy: {top_k_accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "42111572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: dtype of the weights of the model's first layer.\n",
    "mlp[0].weight.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "8a390abf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: dtype of the last training mini-batch left over from the training loop.\n",
    "x_batch.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65a66989",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:root] *",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
