{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U-b3x4I4u97S"
      },
      "outputs": [],
      "source": [
        "import pandas as pd"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VKsy6S0iwXu4"
      },
      "outputs": [],
      "source": [
        "# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.\n",
        "train_records = pd.read_pickle('train_data.pkl')\n",
        "df = pd.DataFrame(train_records)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5e0WX1OGvocv"
      },
      "outputs": [],
      "source": [
        "df.shape[0]  # number of training rows"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yksDQB_BxQdG"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "# Features: the stacked `confidence` entries of each row.\n",
        "# Labels: each row's `is_correct` elements cast with int(), stored as float32 for the BCE loss.\n",
        "X = torch.tensor(np.stack(df['confidence'].to_numpy()), dtype=torch.float32)\n",
        "y = torch.tensor(\n",
        "    np.stack([[int(flag) for flag in row] for row in df['is_correct']]),\n",
        "    dtype=torch.float32,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Validation split, encoded exactly like the training data.\n",
        "val_records = pd.read_pickle('val_data.pkl')  # NOTE: unpickle trusted files only\n",
        "df_val = pd.DataFrame(val_records)\n",
        "X_val = torch.tensor(np.stack(df_val['confidence'].to_numpy()), dtype=torch.float32)\n",
        "y_val = torch.tensor(\n",
        "    np.stack([[int(flag) for flag in row] for row in df_val['is_correct']]),\n",
        "    dtype=torch.float32,\n",
        ")"
      ],
      "metadata": {
        "id": "fmBirSW36OGy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JpnjynsUxSxL"
      },
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "class LogisticRegression(nn.Module):\n",
        "    \"\"\"Two-layer MLP (Linear -> ReLU -> Linear) that returns raw logits.\n",
        "\n",
        "    Despite its name this is not a plain logistic regression: it has a hidden\n",
        "    ReLU layer. No sigmoid is applied in ``forward``; this notebook applies it via\n",
        "    ``nn.BCEWithLogitsLoss`` during training and ``torch.sigmoid`` at test time.\n",
        "\n",
        "    Args:\n",
        "        input_dim: number of input features per example.\n",
        "        output_dim: number of output logits; defaults to ``input_dim``.\n",
        "        hidden_dim: width of the hidden layer; defaults to ``input_dim``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, input_dim, output_dim=None, hidden_dim=None):\n",
        "        super().__init__()\n",
        "        if output_dim is None:\n",
        "            output_dim = input_dim\n",
        "        if hidden_dim is None:\n",
        "            hidden_dim = input_dim\n",
        "        # Attribute names are unchanged so saved state_dict keys stay compatible.\n",
        "        self.linear1 = nn.Linear(input_dim, hidden_dim)\n",
        "        self.relu = nn.ReLU()\n",
        "        self.linear2 = nn.Linear(hidden_dim, output_dim)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Map a batch of feature vectors to one logit per output.\"\"\"\n",
        "        x = self.linear1(x)\n",
        "        x = self.relu(x)\n",
        "        return self.linear2(x)\n",
        "\n",
        "# Fixed seed so the weight initialisation is reproducible across Restart & Run All.\n",
        "torch.manual_seed(0)\n",
        "# Size the network from the data instead of a hard-coded 32.\n",
        "model = LogisticRegression(input_dim=X.shape[1], output_dim=y.shape[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KjUG5DghxU14"
      },
      "outputs": [],
      "source": [
        "# Per-entry binary targets; BCEWithLogitsLoss consumes raw logits (it applies the sigmoid itself).\n",
        "loss_fn = nn.BCEWithLogitsLoss()\n",
        "LEARNING_RATE = 1e-3\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mYeiY2BJxXnX"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import numpy as np\n",
        "\n",
        "NUM_EPOCHS = 5000  # full-batch optimisation steps\n",
        "LOG_EVERY = 500    # print progress every N epochs\n",
        "\n",
        "losses = []\n",
        "val_losses = []\n",
        "\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "    # The validation pass below switches the model to eval mode, so switch back\n",
        "    # before every update; otherwise all steps after the first train in eval mode\n",
        "    # (harmless for Linear/ReLU, wrong once dropout or batch-norm is added).\n",
        "    model.train()\n",
        "    optimizer.zero_grad()\n",
        "    logits = model(X)\n",
        "    loss = loss_fn(logits, y)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "    current_loss = loss.item()\n",
        "    losses.append(current_loss)\n",
        "\n",
        "    # Full-batch validation loss without gradient tracking.\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        val_logits = model(X_val)\n",
        "        val_loss = loss_fn(val_logits, y_val).item()\n",
        "    val_losses.append(val_loss)\n",
        "\n",
        "    if epoch % LOG_EVERY == 0:\n",
        "        print(f\"Epoch {epoch}: loss={current_loss:.4f}, Val_loss={val_loss:.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Training vs. validation loss per epoch; also written to disk at 300 dpi.\n",
        "fig, ax = plt.subplots(figsize=(6, 6))\n",
        "ax.plot(losses, label='Training Loss', color='#0066CC')\n",
        "ax.plot(val_losses, label='Validation Loss', color='red')\n",
        "ax.set_xlabel('Epoch')\n",
        "ax.set_ylabel('Loss')\n",
        "ax.legend()\n",
        "ax.grid(True)\n",
        "ax.set_facecolor('#f8f8f8')\n",
        "\n",
        "fig.savefig('learning_curve.png', dpi=300, bbox_inches='tight')"
      ],
      "metadata": {
        "id": "IYe4pibb7fjz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8MOfCgFDxtBw"
      },
      "outputs": [],
      "source": [
        "# To evaluate a previously saved checkpoint instead of the model trained above, load it first, e.g.\n",
        "# model.load_state_dict(torch.load('latest_filter.pth'))\n",
        "test_records = pd.read_pickle('test_data.pkl')  # NOTE: unpickle trusted files only\n",
        "df_test = pd.DataFrame(test_records)\n",
        "X_test = torch.tensor(np.stack(df_test['confidence'].to_numpy()), dtype=torch.float32)\n",
        "y_test = torch.tensor(\n",
        "    np.stack([[int(flag) for flag in row] for row in df_test['is_correct']]),\n",
        "    dtype=torch.float32,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m1PvDvGWxgpt"
      },
      "outputs": [],
      "source": [
        "# Predict 1 (\"is_correct\") only when the sigmoid probability exceeds this cut-off.\n",
        "THRESHOLD = 0.96\n",
        "\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    logits = model(X_test)\n",
        "    probs = torch.sigmoid(logits)\n",
        "    preds = (probs > THRESHOLD).float()\n",
        "    # Element-wise accuracy over every (example, output) entry. Note that\n",
        "    # `preds > y_test` would be true only for pred=1/label=0, i.e. it measures\n",
        "    # false positives, not accuracy.\n",
        "    acc = (preds == y_test).float().mean().item()\n",
        "    # Share of ALL entries that are false positives (pred=1, label=0).\n",
        "    fp_share = ((preds == 1) & (y_test == 0)).float().mean().item()\n",
        "    print(\"Accuracy:\", acc)\n",
        "    print(\"False-positive share (of all entries):\", fp_share)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8m36UkYxxsLv"
      },
      "outputs": [],
      "source": [
        "# Persist only the learned weights (state_dict), not the module object itself.\n",
        "checkpoint_path = 'latest_filter.pth'\n",
        "torch.save(model.state_dict(), checkpoint_path)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}