{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This Notebook is a demonstration of FTP with RNNs. \n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Import Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "Eaxpqyzin0Aw"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using device: cuda:0\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import torchvision\n",
        "from torchvision import datasets\n",
        "from torchvision.transforms import ToTensor\n",
        "\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.datasets import load_digits\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "from sklearn.model_selection import train_test_split\n",
        "import os\n",
        "\n",
        "gpu_id = 0  # set this to the index of the GPU you want to use\n",
        "device = torch.device(f\"cuda:{gpu_id}\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Data loader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [],
      "source": [
        "\"\"\"Choose dataset: solar_AL for Solar, traffic for METR-LA, electricity for Electricity\"\"\"\n",
        "\n",
        "dataset = 'solar_AL'  # 'solar_AL', 'traffic', 'electricity'\n",
        "\n",
        "\n",
        "# Load the dataset\n",
        "data_path = os.path.join(\"data/\", dataset+\".txt\")\n",
        "df = pd.read_csv(data_path, header=None)  \n",
        "data = df.values  \n",
        "data = data.astype(np.float32)  \n",
        "\n",
        "max = np.max(data, axis=0)\n",
        "data_normalized = data / max\n",
        "\n",
        "normalization_params = {\n",
        "    'max': max,\n",
        "}\n",
        "\n",
        "# Define window size\n",
        "X_size = 24  # Number of timesteps (rows) in the feature window\n",
        "Y_size = 1   # Predict the next value (single row)\n",
        "\n",
        "# Create lists for X and Y\n",
        "X, Y = [], []\n",
        "\n",
        "# Apply the sliding window transformation on the column-normalized dataset\n",
        "for i in range(len(data_normalized) - X_size):\n",
        "    x_window = data_normalized[i:i+X_size] \n",
        "    y_value = data_normalized[i+X_size]     \n",
        "\n",
        "    X.append(x_window)\n",
        "    Y.append(y_value)\n",
        "\n",
        "# Convert to NumPy arrays\n",
        "X = np.array(X)  \n",
        "Y = np.array(Y) \n",
        "\n",
        "\n",
        "\n",
        "# # Create DataLoader objects\n",
        "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42, shuffle=False)\n",
        "\n",
        "# Convert to PyTorch tensors\n",
        "X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
        "X_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
        "Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32).unsqueeze(1)  # Reshape Y to (batch_size, 1)\n",
        "Y_test_tensor = torch.tensor(Y_test, dtype=torch.float32).unsqueeze(1)  # Reshape Y to (batch_size, 1)\n",
        "\n",
        "# Custom PyTorch Dataset class for Time-Series\n",
        "class TimeSeriesDataset(Dataset):\n",
        "    def __init__(self, data, targets):\n",
        "        self.data = data\n",
        "        self.targets = targets\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.data[idx], self.targets[idx]\n",
        "\n",
        "# Create PyTorch Dataset objects\n",
        "train_dataset = TimeSeriesDataset(X_train_tensor, Y_train_tensor)\n",
        "test_dataset = TimeSeriesDataset(X_test_tensor, Y_test_tensor)\n",
        "\n",
        "# Create DataLoaders for training and testing\n",
        "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)\n",
        "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)\n",
        "\n",
        "\n",
        "_, _, input_size = X_train_tensor.shape\n",
        "sequence_length = X_size\n",
        "numclasses = input_size\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sPnN9_Xob8ku"
      },
      "outputs": [],
      "source": [
        "\"\"\"Evaluation function\"\"\"\n",
        "\n",
        "\n",
        "def eval_model(model: torch.nn.Module,\n",
        "               data_loader: torch.utils.data.DataLoader,\n",
        "               loss_fn: torch.nn.Module):\n",
        "\n",
        "    total_loss = 0.0\n",
        "    total_samples = 0\n",
        "\n",
        "    model.eval()  \n",
        "\n",
        "    with torch.inference_mode():\n",
        "        for X, y in data_loader:\n",
        "\n",
        "            X, y = X.to(device), y.to(device)  \n",
        "            y_pred = model(X, True)\n",
        "\n",
        "            # Compute batch loss\n",
        "            batch_loss = loss_fn(y_pred, y)\n",
        "            total_loss += batch_loss.item() * X.size(0)  # Weighted sum of loss\n",
        "            total_samples += X.size(0)\n",
        "\n",
        "            # Compute absolute error for MAE\n",
        "            total_absolute_error += torch.sum(torch.abs(y_pred - y)).item()\n",
        "\n",
        "        # Compute average loss\n",
        "        avg_loss = total_loss / total_samples\n",
        "\n",
        "    return {\n",
        "        \"model_name\": model.__class__.__name__,  # Only works when model is defined as a class\n",
        "        \"model_loss\": avg_loss,\n",
        "    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Network Architecture"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ymCXE9ZB0RBV",
        "outputId": "5b09c848-5f07-48b6-a3b7-1467e06082cf"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "class simpleRNN(nn.Module):\n",
        "    def __init__(self, INPUT_SIZE, HIDDEN_SIZE, num_classes, sequence_length, learning_rate):\n",
        "        super().__init__()\n",
        "        \n",
        "        self.hidden_size = HIDDEN_SIZE\n",
        "        self.input_size = INPUT_SIZE\n",
        "        self.sequence_length = sequence_length\n",
        "        self.learning_rate = learning_rate\n",
        "        self.ln = torch.nn.LayerNorm(self.hidden_size)\n",
        "\n",
        "        # Xavier Initialization and move to device\n",
        "        self.Whh = self.xavier_initialization((HIDDEN_SIZE, HIDDEN_SIZE)).to(device)  \n",
        "        self.Wxh = self.xavier_initialization((HIDDEN_SIZE, INPUT_SIZE)).to(device)   \n",
        "        self.bh = torch.zeros((HIDDEN_SIZE, 1), device=device)  \n",
        "\n",
        "        self.Why = self.xavier_initialization((num_classes, HIDDEN_SIZE)).to(device) \n",
        "        self.by = torch.zeros((num_classes, 1), device=device)  \n",
        "\n",
        "        self.last_hs_standard = {}\n",
        "        \n",
        "        self.last_hs_modulated = {}\n",
        "\n",
        "        self.v_Wxh = torch.zeros_like(self.Wxh, device=device)\n",
        "        self.v_Whh = torch.zeros_like(self.Whh, device=device)\n",
        "        self.v_Why = torch.zeros_like(self.Why, device=device)\n",
        "\n",
        "\n",
        "\n",
        "    def forward(self, x, standard_pass):\n",
        "       \n",
        "        x = x.to(device)\n",
        "\n",
        "        hidden_values = {}\n",
        "        before = {}\n",
        "\n",
        "        # Initialize hidden state\n",
        "        h = torch.zeros((self.hidden_size, x.shape[0]), device=device)\n",
        "        hidden_values[0] = h\n",
        "       \n",
        "\n",
        "        for i in range(self.sequence_length):\n",
        "            z = self.Wxh @ x[:,i,:].T + self.Whh @ h + self.bh\n",
        "            h = torch.sigmoid(z)        \n",
        "            hidden_values[i + 1] = h\n",
        "        y_pred = self.Why @ h + self.by\n",
        "\n",
        "        if standard_pass:\n",
        "            self.last_hs_standard = hidden_values\n",
        "        else:\n",
        "            self.last_hs_modulated = hidden_values\n",
        "        return y_pred\n",
        "\n",
        "\n",
        "    def forward_target_prop(self, error, x):        \n",
        "        error = error.to(device)\n",
        "        x = x.to(device)\n",
        "\n",
        "        d_Why = -2* self.last_hs_standard[self.sequence_length] @ error\n",
        "        d_Why = d_Why.T\n",
        "        \n",
        "        d_Wxh = torch.zeros_like(self.Wxh, device=device)  \n",
        "        d_Whh = torch.zeros_like(self.Whh, device=device)  \n",
        "\n",
        "        for t in range(self.sequence_length): \n",
        "            d_Wxh += -(self.last_hs_standard[t + 1] - self.last_hs_modulated[t + 1]) *self.last_hs_standard[t + 1] *(1-self.last_hs_standard[t + 1]) @ x[:, t, :]\n",
        "            \n",
        "            d_Whh += -(self.last_hs_standard[t + 1] - self.last_hs_modulated[t + 1]) *self.last_hs_standard[t + 1]*(1-self.last_hs_standard[t + 1])  @ self.last_hs_standard[t].T\n",
        "        \n",
        "\n",
        "\n",
        "        # Update weights with momentum\n",
        "        mom=.9\n",
        "        self.v_Wxh = self.v_Wxh * mom+ self.learning_rate*d_Wxh\n",
        "        self.v_Whh = self.v_Whh * mom + self.learning_rate*d_Whh\n",
        "        self.v_Why = self.v_Why * mom + self.learning_rate*d_Why\n",
        "        \n",
        "\n",
        "        # Update weights\n",
        "        self.Wxh += self.v_Wxh\n",
        "        self.Whh += self.v_Whh\n",
        "        self.Why += self.v_Why\n",
        "\n",
        "    def xavier_initialization(self, shape):\n",
        "        fan_in, fan_out = shape[1], shape[0]\n",
        "        std = np.sqrt(2.0 / (fan_in + fan_out))\n",
        "        return torch.randn(*shape) * std\n",
        "\n",
        "    def reset_hidden_states(self):\n",
        "        self.last_hs_standard = {}\n",
        "        \n",
        "        self.last_hs_modulated = {}\n",
        "\n",
        "\n",
        "    def change_learning_rate(self):\n",
        "        self.learning_rate *= 0.5\n",
        "\n",
        "    def get_learning_rate(self):\n",
        "        return self.learning_rate\n",
        "\n",
        "    def save_model(self, file_path):\n",
        "        torch.save({\n",
        "            'Wxh': self.Wxh,\n",
        "            'Whh': self.Whh,\n",
        "            'Why': self.Why,\n",
        "            'bh': self.bh,\n",
        "            'by': self.by,\n",
        "            'input_size': self.input_size,\n",
        "            'hidden_size': self.hidden_size,\n",
        "            'sequence_length': self.sequence_length,\n",
        "        }, file_path)\n",
        "\n",
        "    def load_model(self, file_path):\n",
        "        checkpoint = torch.load(file_path)\n",
        "        self.Wxh = checkpoint['Wxh'].to(self.device)\n",
        "        self.Whh = checkpoint['Whh'].to(self.device)\n",
        "        self.Why = checkpoint['Why'].to(self.device)\n",
        "        self.bh = checkpoint['bh'].to(self.device)\n",
        "        self.by = checkpoint['by'].to(self.device)\n",
        "        self.input_size = checkpoint['input_size']\n",
        "        self.hidden_size = checkpoint['hidden_size']\n",
        "        self.sequence_length = checkpoint['sequence_length']\n",
        "        \n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Model hyperparameters\n",
        "num_layers = 1\n",
        "hidden_size = 512\n",
        "learning_rate = 0.00001\n",
        "\n",
        "\n",
        "# Initialize model and move it to the device (CUDA if available)\n",
        "model = simpleRNN(input_size, hidden_size, numclasses, sequence_length, learning_rate).to(device)\n",
        "\n",
        "loss_fn = nn.MSELoss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training the Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Initialize G matrix\n",
        "nin = input_size * sequence_length\n",
        "sd = np.sqrt(6 / nin)\n",
        "G = (torch.rand(nin, numclasses).to(device) * 2 * sd - sd) * 0.005  # Move G matrix to CUDA\n",
        "\n",
        "# Set random seed\n",
        "# torch.manual_seed(42)\n",
        "\n",
        "# Initialize lists to store accuracy and losses\n",
        "losses = []\n",
        "epochs = 500\n",
        "\n",
        "#  Training Loop\n",
        "for epoch in range(epochs):\n",
        "    print(f\"Epoch: {epoch}\\n-------\")\n",
        "    \n",
        "    # Change learning rate at specific epochs\n",
        "    if epoch in [300, 450]:\n",
        "        model.change_learning_rate()\n",
        "        print('lr decreased to', model.get_learning_rate())\n",
        "    \n",
        "    train_loss = 0\n",
        "\n",
        "    # Iterate over batches in the train dataloader\n",
        "    for batch, (X, y) in enumerate(train_loader):  \n",
        "        model.train()\n",
        "        if dataset == 'load_digits':\n",
        "            X = X.unsqueeze(1)\n",
        "        \n",
        "        x_data = X.clone()\n",
        "        x_mod = X.clone()\n",
        "\n",
        "\n",
        "        X = X.to(device) \n",
        "        y = y.to(device)\n",
        "        x_data = x_data.to(device) \n",
        "        x_mod = x_mod.to(device)  \n",
        "  \n",
        "        # Predict the output using the model and move predictions to the correct device\n",
        "        y_pred = model(X, True)\n",
        "        y_pred = y_pred.T  # Transpose the predictions\n",
        "\n",
        "        y = y.squeeze(1)\n",
        "        \n",
        "        error = y_pred - y\n",
        "\n",
        "        error_tau = -torch.tanh(y_pred@ G.T) + torch.tanh(y@ G.T)\n",
        "        \n",
        "        # Estimate target using the error signal\n",
        "        x_mod += (error_tau).reshape(-1,sequence_length,input_size) \n",
        "\n",
        "        y_mod = model(x_mod, False)\n",
        "        model.forward_target_prop(error, X)\n",
        "        model.reset_hidden_states()\n",
        "\n",
        "        # Calculate loss and add it to the training loss for this batch\n",
        "        loss = nn.functional.mse_loss(y_pred, y)\n",
        "        train_loss += loss\n",
        "        \n",
        "    # Divide total train loss by the number of batches\n",
        "    train_loss /= len(train_loader)\n",
        "    losses.append(train_loss)\n",
        "\n",
        "    # Print loss and accuracy for the current epoch\n",
        "    print(f'loss: {train_loss}')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Model Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch.nn.functional as F\n",
        "scale = normalization_params['max']\n",
        "scale = torch.tensor(scale, dtype=torch.float32).to(device)\n",
        "\n",
        "model.eval()\n",
        "rae_num = 0\n",
        "rse_num = 0\n",
        "n_samples = 0\n",
        "\n",
        "\n",
        "# Compute mean of ground truth for RAE and RSE denominator\n",
        "all_targets = []\n",
        "all_predictions = []\n",
        "\n",
        "with torch.no_grad():\n",
        "    for X_batch, y_batch in test_loader:\n",
        "        X_batch = X_batch.reshape(-1, 24, input_size).to(device)\n",
        "        y_batch = y_batch.to(device)\n",
        "        y_pred = model(X_batch, True)  # Shape: (batch_size, 321)\n",
        "        y_batch = y_batch.squeeze(1).T\n",
        "\n",
        "        # De-normalize\n",
        "        y_pred = y_pred * scale.unsqueeze(1)\n",
        "        y_batch = y_batch * scale.unsqueeze(1)\n",
        "\n",
        "        all_targets.append(y_batch)\n",
        "        all_predictions.append(y_pred)\n",
        "\n",
        "        # Compute MAE and MSE\n",
        "        batch_mae = F.l1_loss(y_pred, y_batch, reduction='sum')\n",
        "        batch_mse = F.mse_loss(y_pred, y_batch, reduction='sum')\n",
        "\n",
        "        mae_total += batch_mae.item()\n",
        "        mse_total += batch_mse.item()\n",
        "        n_samples += y_batch.size(1)\n",
        "\n",
        "# Stack all targets to compute mean(y) over the dataset\n",
        "all_targets = torch.cat(all_targets, dim=0)\n",
        "target_mean = torch.mean(all_targets)\n",
        "all_predictions = torch.cat(all_predictions, dim=0)\n",
        "\n",
        "\n",
        "# Re-run loop to compute RAE and RSE numerators\n",
        "with torch.no_grad():\n",
        "    for X_batch, y_batch in test_loader:\n",
        "        X_batch = X_batch.reshape(-1, sequence_length, input_size).to(device)\n",
        "        y_batch = y_batch.to(device)\n",
        "        y_pred = model(X_batch, True)\n",
        "        y_batch = y_batch.squeeze(1).T\n",
        "\n",
        "        # De-normalize\n",
        "        y_pred = y_pred * scale.unsqueeze(1)\n",
        "        y_batch = y_batch * scale.unsqueeze(1)\n",
        "\n",
        "        rae_num += torch.sum(torch.abs(y_pred - y_batch)).item()\n",
        "        rse_num += torch.sum((y_pred - y_batch) ** 2).item()\n",
        "\n",
        "# Compute RAE and RSE denominators\n",
        "rae_den = torch.sum(torch.abs(all_targets - target_mean)).item()\n",
        "rse_den = torch.sum((all_targets - target_mean) ** 2).item()\n",
        "\n",
        "# Final Metrics\n",
        "mae = mae_total / (n_samples * input_size)\n",
        "mse = mse_total / (n_samples * input_size)\n",
        "rae = np.sqrt(rae_num / rae_den)\n",
        "rse = np.sqrt(rse_num / rse_den)\n",
        "\n",
        "\n",
        "# Compute Empirical CORR (feature-wise)\n",
        "y_mean = torch.mean(all_targets, dim=0, keepdim=True)       \n",
        "y_pred_mean = torch.mean(all_predictions, dim=0, keepdim=True)\n",
        "\n",
        "corr_num = torch.sum((all_targets - y_mean) * (all_predictions - y_pred_mean), dim=0)\n",
        "corr_den_y = torch.sum((all_targets - y_mean) ** 2, dim=0)\n",
        "corr_den_pred = torch.sum((all_predictions - y_pred_mean) ** 2, dim=0)\n",
        "\n",
        "corr = torch.mean(corr_num / (torch.sqrt(corr_den_y) * torch.sqrt(corr_den_pred) + 1e-8))\n",
        "\n",
        "\n",
        "print(f\"Test MAE: {mae:.6f}\")\n",
        "print(f\"Test MSE: {mse:.6f}\")\n",
        "print(f\"Test RAE: {rae:.6f}\")\n",
        "print(f\"Test RSE: {rse:.6f}\")\n",
        "print(f\"Test Empirical CORR: {corr:.6f}\")\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
