{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "72ac06a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_sequence, unpack_sequence, pad_packed_sequence\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import ast\n",
    "\n",
    "# Seed PyTorch's global random number generator.\n",
    "torch.manual_seed(42)\n",
    "\n",
    "\n",
    "# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
    "DEVICE = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d48cb91d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMModel(nn.Module):\n",
    "    \"\"\"LSTM sequence classifier producing sigmoid scores in (0, 1).\n",
    "\n",
    "    The final hidden state of the last LSTM layer is passed through a sigmoid,\n",
    "    a linear head of width ``output_dim`` and a second sigmoid.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of features per time step.\n",
    "        hidden_dim: Size of the LSTM hidden state (per direction).\n",
    "        number_of_layers: Number of stacked LSTM layers.\n",
    "        output_dim: Number of output units of the linear head.\n",
    "        bidirectional: If True, the LSTM runs in both directions and the head\n",
    "            receives the concatenated final states of both directions.\n",
    "        device: Device on which the parameters are created.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim, number_of_layers, output_dim, bidirectional=False, device=\"cpu\"):\n",
    "        super(LSTMModel, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.number_of_layers = number_of_layers\n",
    "        self.lstm = nn.LSTM(input_dim, hidden_dim, number_of_layers, batch_first=True, device=device, bidirectional=bidirectional)\n",
    "        # A bidirectional LSTM yields one final state per direction, so the head must accept both.\n",
    "        num_directions = 2 if bidirectional else 1\n",
    "        self.fc = nn.Linear(hidden_dim * num_directions, output_dim, device=device)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Score every sequence in ``x``.\n",
    "\n",
    "        Args:\n",
    "            x: Any input accepted by ``nn.LSTM`` (this notebook passes a PackedSequence).\n",
    "\n",
    "        Returns:\n",
    "            Tensor whose last dimension is ``output_dim``, with values in (0, 1).\n",
    "        \"\"\"\n",
    "        _, (hn, _) = self.lstm(x)\n",
    "        # For a bidirectional LSTM hn[-2] / hn[-1] are the last layer's forward / backward final\n",
    "        # states. The in_features check keeps models that were pickled with the old\n",
    "        # hidden_dim-wide head working: they keep reading only hn[-1].\n",
    "        if self.lstm.bidirectional and self.fc.in_features == 2 * self.lstm.hidden_size:\n",
    "            final_hidden = torch.cat((hn[-2], hn[-1]), dim=-1)\n",
    "        else:\n",
    "            final_hidden = hn[-1]\n",
    "        out = self.sigmoid(final_hidden)\n",
    "        out = self.fc(out)\n",
    "        out = self.sigmoid(out)\n",
    "        return out\n",
    "    \n",
    "class LogisticRegressionModel(nn.Module):\n",
    "    \"\"\"Logistic regression: one lazily-sized linear unit followed by a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        device: Device on which the layer is created. When omitted, the\n",
    "            notebook-wide ``DEVICE`` is used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, device=None):\n",
    "        super(LogisticRegressionModel, self).__init__()\n",
    "        # Accept an explicit device (as LSTMModel does) and only fall back to the global.\n",
    "        if device is None:\n",
    "            device = DEVICE\n",
    "        self.fully_connected1 = nn.LazyLinear(1, device=device)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``sigmoid(linear(x))``; the input width is inferred on the first call.\"\"\"\n",
    "        out = self.fully_connected1(x)\n",
    "        out = self.sigmoid(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "75b32649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize Datasets\n",
    "\n",
    "class TelescopeSequenceModelingDataset(Dataset):\n",
    "    \"\"\"Per-token telescope perplexity sequences paired with their labels.\n",
    "\n",
    "    Every ``telescope_perplexity_per_token`` string of the CSV is parsed into a\n",
    "    float32 tensor reshaped to (num_values, 1) on ``DEVICE``; the ``labels``\n",
    "    column becomes one float32 tensor on ``DEVICE``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset_file):\n",
    "        frame = pd.read_csv(dataset_file)\n",
    "\n",
    "        parsed_sequences = []\n",
    "        for raw_sequence in frame[\"telescope_perplexity_per_token\"]:\n",
    "            values = ast.literal_eval(raw_sequence)\n",
    "            column = torch.tensor(values, dtype=torch.float32, device=DEVICE).view(-1, 1)\n",
    "            parsed_sequences.append(column)\n",
    "        self.sequences_list = parsed_sequences\n",
    "\n",
    "        label_array = frame[\"labels\"].to_numpy()\n",
    "        self.labels = torch.tensor(label_array, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of sequences in the dataset.\"\"\"\n",
    "        return len(self.sequences_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the ``(sequence, label)`` pair stored at ``idx``.\"\"\"\n",
    "        sequence, label = self.sequences_list[idx], self.labels[idx]\n",
    "        return sequence, label\n",
    "\n",
    "\n",
    "class FullSequenceModelingDataset(Dataset):\n",
    "    \"\"\"Three-feature per-token sequences with their labels.\n",
    "\n",
    "    Each time step holds the telescope, cross and plain perplexity values of a\n",
    "    row, giving a float32 tensor with three columns on ``DEVICE`` that is then\n",
    "    divided by 14.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset_file):\n",
    "        frame = pd.read_csv(dataset_file)\n",
    "\n",
    "        self.sequences_list = []\n",
    "        raw_columns = zip(\n",
    "            frame[\"telescope_perplexity_per_token\"],\n",
    "            frame[\"cross_perplexity_per_token\"],\n",
    "            frame[\"perplexity_per_token\"],\n",
    "        )\n",
    "        for telescope_raw, cross_raw, perplexity_raw in raw_columns:\n",
    "            telescope_values = ast.literal_eval(telescope_raw)[0]\n",
    "            # the last cross-perplexity entry is dropped before the three lists are zipped\n",
    "            cross_values = ast.literal_eval(cross_raw)[0][:-1]\n",
    "            perplexity_values = ast.literal_eval(perplexity_raw)[0]\n",
    "\n",
    "            rows = list(zip(telescope_values, cross_values, perplexity_values))\n",
    "            features = torch.tensor(rows, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "            # normalize for the neural network\n",
    "            self.sequences_list.append(features / 14)\n",
    "\n",
    "        label_array = frame[\"labels\"].to_numpy()\n",
    "        self.labels = torch.tensor(label_array, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of sequences in the dataset.\"\"\"\n",
    "        return len(self.sequences_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the ``(sequence, label)`` pair stored at ``idx``.\"\"\"\n",
    "        sequence, label = self.sequences_list[idx], self.labels[idx]\n",
    "        return sequence, label\n",
    "\n",
    "\n",
    "\n",
    "class TelescopeAverageDataset(Dataset):\n",
    "    \"\"\"One averaged telescope perplexity value per row, paired with its label.\n",
    "\n",
    "    The features are stored as a single float32 tensor reshaped to\n",
    "    (num_rows, 1) on ``DEVICE``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset_file):\n",
    "        frame = pd.read_csv(dataset_file)\n",
    "\n",
    "        mean_perplexities = []\n",
    "        for raw_sequence in frame[\"telescope_perplexity_per_token\"]:\n",
    "            mean_perplexities.append(np.average(ast.literal_eval(raw_sequence)))\n",
    "        feature_tensor = torch.tensor(mean_perplexities, dtype=torch.float32, device=DEVICE)\n",
    "        self.sequences_list = feature_tensor.view(-1, 1)\n",
    "\n",
    "        label_array = frame[\"labels\"].to_numpy()\n",
    "        self.labels = torch.tensor(label_array, dtype=torch.float32, device=DEVICE)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of rows in the dataset.\"\"\"\n",
    "        return len(self.sequences_list)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the ``(feature, label)`` pair stored at ``idx``.\"\"\"\n",
    "        feature, label = self.sequences_list[idx], self.labels[idx]\n",
    "        return feature, label\n",
    "\n",
    "\n",
    "# pack batches properly so that the LSTM module can properly use it\n",
    "def collate_fn(batch: list[tuple[torch.Tensor, float]]) -> tuple[torch.nn.utils.rnn.PackedSequence, torch.Tensor]:\n",
    "    \"\"\"Collate ``(sequence, label)`` samples into a packed batch plus a label vector.\n",
    "\n",
    "    Args:\n",
    "        batch: Samples as produced by a dataset's ``__getitem__``; the sequences\n",
    "            may have different lengths.\n",
    "\n",
    "    Returns:\n",
    "        A ``(packed_sequences, labels)`` tuple: the sequences packed with\n",
    "        ``enforce_sorted=False`` and a 1-D tensor of labels on ``DEVICE``.\n",
    "    \"\"\"\n",
    "    packed_sequences = pack_sequence([sequence for sequence, _ in batch], enforce_sorted=False)\n",
    "    # Stack rather than torch.tensor([...]): rebuilding a tensor from 0-dim device tensors reads\n",
    "    # them back one element at a time. as_tensor keeps plain Python numbers working too.\n",
    "    labels = torch.stack([torch.as_tensor(label, device=DEVICE) for _, label in batch])\n",
    "    return packed_sequences, labels\n",
    "        \n",
    "        \n",
    "        \n",
    "# Build one dataset per source corpus and concatenate them.\n",
    "dataset_list = [\n",
    "    TelescopeSequenceModelingDataset(f\"sequence_modeling_datasets/{dataset_name}/full.csv\")\n",
    "    for dataset_name in [\"hc3_plus_smollm_360M_dataset\", \"hc3_smollm_360M_dataset\", \"ai_human_smollm_360M_dataset\", \"detectllmtext_smollm_360M_dataset\", \"esl_gpt4o_smollm_360M_dataset\"]\n",
    "]\n",
    "full_dataset = torch.utils.data.ConcatDataset(dataset_list)\n",
    "\n",
    "# Split with a dedicated fixed-seed generator so that re-running this cell gives the same\n",
    "# train/test membership instead of drawing a new split from the global RNG.\n",
    "split_generator = torch.Generator().manual_seed(42)\n",
    "train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [0.8, 0.2], generator=split_generator)\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)\n",
    "# Accuracy does not depend on sample order, so the test loader is not shuffled.\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "\n",
    "     \n",
    "# dataset_list = []\n",
    "# for dataset_name in [\"hc3_plus_smollm_360M_dataset\", \"hc3_smollm_360M_dataset\", \"ai_human_smollm_360M_dataset\", \"detectllmtext_smollm_360M_dataset\", \"esl_gpt4o_smollm_360M_dataset\"]:\n",
    "#     dataset_list.append(TelescopeAverageDataset(f\"sequence_modeling_datasets/{dataset_name}/full.csv\"))\n",
    "# full_dataset = torch.utils.data.ConcatDataset(dataset_list)\n",
    "\n",
    "# train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [0.8, 0.2])\n",
    "# train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "# test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# full_dataset = FullSequenceModelingDataset(\"sequence_modeling_datasets/hc3_plus_smollm_360M_dataset/full.csv\") \n",
    "# train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [0.8, 0.2])\n",
    "# train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)\n",
    "# test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6cd6ef1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/30], Loss: 0.6912\n",
      "Epoch [1/30], Loss: 0.6815\n",
      "Epoch [1/30], Loss: 0.7041\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 3.74 GiB. GPU 0 has a total capacity of 7.75 GiB of which 1.97 GiB is free. Including non-PyTorch memory, this process has 5.51 GiB memory in use. Of the allocated memory 1.78 GiB is allocated by PyTorch, and 3.57 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 18\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_index, (batch_data, batch_labels) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(train_dataloader):\n\u001b[1;32m     16\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 18\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_data\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     20\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     22\u001b[0m     loss \u001b[38;5;241m=\u001b[39m loss_function(outputs, batch_labels)\n",
      "File \u001b[0;32m~/miniconda3/envs/telescope/lib/python3.9/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/telescope/lib/python3.9/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "Cell \u001b[0;32mIn[2], line 11\u001b[0m, in \u001b[0;36mLSTMModel.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 11\u001b[0m     out, (hn, cn) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m     final_hidden \u001b[38;5;241m=\u001b[39m hn[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m     13\u001b[0m     \u001b[38;5;66;03m# out, lens = pad_packed_sequence(out)\u001b[39;00m\n\u001b[1;32m     14\u001b[0m     \u001b[38;5;66;03m# print(lens-1)\u001b[39;00m\n\u001b[1;32m     15\u001b[0m     \u001b[38;5;66;03m# out = out[lens-1, ...]  # only take the output of the last element in the series which is the output of the lstm\u001b[39;00m\n\u001b[1;32m     16\u001b[0m     \u001b[38;5;66;03m# out = out[-1, ...]\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/telescope/lib/python3.9/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/telescope/lib/python3.9/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[0;32m~/miniconda3/envs/telescope/lib/python3.9/site-packages/torch/nn/modules/rnn.py:1135\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m   1123\u001b[0m     result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\n\u001b[1;32m   1124\u001b[0m         \u001b[38;5;28minput\u001b[39m,\n\u001b[1;32m   1125\u001b[0m         hx,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1132\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_first,\n\u001b[1;32m   1133\u001b[0m     )\n\u001b[1;32m   1134\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1135\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1136\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1137\u001b[0m \u001b[43m        \u001b[49m\u001b[43mbatch_sizes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1138\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1139\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1140\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1141\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1142\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1143\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1144\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1145\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1146\u001b[0m output \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1147\u001b[0m hidden \u001b[38;5;241m=\u001b[39m result[\u001b[38;5;241m1\u001b[39m:]\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 3.74 GiB. GPU 0 has a total capacity of 7.75 GiB of which 1.97 GiB is free. Including non-PyTorch memory, this process has 5.51 GiB memory in use. Of the allocated memory 1.78 GiB is allocated by PyTorch, and 3.57 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
     ]
    }
   ],
   "source": [
    "num_epochs = 30\n",
    "\n",
    "model = LSTMModel(input_dim=1, hidden_dim=300, number_of_layers=3, output_dim=1, device=DEVICE, bidirectional=True)\n",
    "# model = LogisticRegressionModel()\n",
    "\n",
    "model.train()\n",
    "loss_function = nn.BCELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)\n",
    "\n",
    "scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.98)\n",
    "# scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.01, total_iters=num_epochs)\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for batch_data, batch_labels in train_dataloader:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # flatten the model output so it lines up with the 1-D label vector\n",
    "        outputs = model(batch_data).view(-1)\n",
    "        loss = loss_function(outputs, batch_labels)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')\n",
    "\n",
    "    # Decay the learning rate once per epoch; without this call the scheduler never takes effect.\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "fa4424b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9609375\n",
      "0.95703125\n",
      "0.9583333333333334\n",
      "0.962890625\n",
      "0.96484375\n",
      "0.9674479166666666\n",
      "0.9659598214285714\n",
      "0.96826171875\n",
      "0.9670138888888888\n",
      "0.965234375\n",
      "0.9659090909090909\n",
      "0.9674479166666666\n",
      "0.9678485576923077\n",
      "0.9679129464285714\n",
      "0.9671875\n",
      "0.968505859375\n",
      "0.96875\n",
      "0.9678819444444444\n",
      "0.9681332236842105\n",
      "0.968359375\n",
      "0.9681919642857143\n",
      "0.9676846590909091\n",
      "0.9684103260869565\n",
      "0.96826171875\n",
      "0.9684375\n",
      "0.96875\n",
      "0.9691840277777778\n",
      "0.9698660714285714\n",
      "0.9705010775862069\n",
      "0.9705729166666667\n",
      "0.9706401209677419\n",
      "0.9708251953125\n",
      "0.9707623106060606\n",
      "0.9705882352941176\n",
      "0.9703125\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Put the model in inference mode before measuring accuracy.\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch_data, batch_labels in test_dataloader:\n",
    "        batch_output = model(batch_data).view(-1).cpu().numpy()\n",
    "        batch_labels = batch_labels.cpu().numpy()\n",
    "\n",
    "        # threshold both scores and labels at 0.5 and count the agreements\n",
    "        preds = batch_output > 0.5\n",
    "        labels = batch_labels > 0.5\n",
    "        correct += np.count_nonzero(preds == labels)\n",
    "\n",
    "        # Use the real batch size so the final, smaller batch is scored too and the\n",
    "        # result does not depend on the DataLoader's batch_size setting.\n",
    "        total += batch_labels.shape[0]\n",
    "\n",
    "        # running accuracy over the batches seen so far\n",
    "        print(correct/total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d836ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the whole model object (not only a state_dict) to the models/ directory.\n",
    "torch.save(model, 'models/smollm_360M_lstm_extra_features_all_datasets.pt')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26193faf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_5982/2113162760.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model = torch.load(\"detectllmtext_hc3_plus_smollm_360M_lstm.pt\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128\n",
      "0.984375\n",
      "128\n",
      "0.98828125\n",
      "128\n",
      "0.9869791666666666\n",
      "128\n",
      "0.98828125\n",
      "128\n",
      "0.9859375\n",
      "128\n",
      "0.98828125\n",
      "128\n",
      "0.9899553571428571\n",
      "128\n",
      "0.98828125\n",
      "128\n",
      "0.9869791666666666\n",
      "128\n",
      "0.98515625\n",
      "128\n",
      "0.9865056818181818\n",
      "128\n",
      "0.9869791666666666\n",
      "128\n",
      "0.9879807692307693\n",
      "128\n",
      "0.9888392857142857\n",
      "128\n",
      "0.9895833333333334\n",
      "128\n",
      "0.98974609375\n",
      "128\n",
      "0.9898897058823529\n",
      "128\n",
      "0.9900173611111112\n",
      "128\n",
      "0.9905427631578947\n",
      "128\n",
      "0.990234375\n",
      "128\n",
      "0.9899553571428571\n",
      "128\n",
      "0.9904119318181818\n",
      "128\n",
      "0.9898097826086957\n",
      "128\n",
      "0.9889322916666666\n",
      "128\n",
      "0.9890625\n",
      "128\n",
      "0.9894831730769231\n",
      "128\n",
      "0.9898726851851852\n",
      "128\n",
      "0.9899553571428571\n",
      "128\n",
      "0.9900323275862069\n",
      "128\n",
      "0.98984375\n",
      "128\n",
      "0.9896673387096774\n",
      "32\n"
     ]
    }
   ],
   "source": [
    "# model = torch.load(\"models/hc3_plus_smollm_360M_model.pt\")\n",
    "# The checkpoint is a fully pickled model object, so weights_only=False is required.\n",
    "# NOTE: unpickling can execute arbitrary code; only load checkpoint files you trust.\n",
    "# map_location moves the weights onto the device the batches live on.\n",
    "model = torch.load(\"models/detectllmtext_hc3_plus_smollm_360M_lstm.pt\", weights_only=False, map_location=DEVICE)\n",
    "model.eval()\n",
    "\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch_data, batch_labels in test_dataloader:\n",
    "        batch_output = model(batch_data).view(-1).cpu().numpy()\n",
    "        batch_labels = batch_labels.cpu().numpy()\n",
    "\n",
    "        # threshold both scores and labels at 0.5 and count the agreements\n",
    "        preds = batch_output > 0.5\n",
    "        labels = batch_labels > 0.5\n",
    "        correct += np.count_nonzero(preds == labels)\n",
    "\n",
    "        # Use the real batch size: the former hard-coded 128 did not match the loader's\n",
    "        # batch_size, which made the loop skip every full batch.\n",
    "        total += batch_labels.shape[0]\n",
    "\n",
    "        # running accuracy over the batches seen so far\n",
    "        print(correct/total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "telescope",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
