{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9db76765-2e54-438f-a159-eb3e30dddb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import Tensor\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b41b1021-0dd7-4b52-b682-fe7a3d731bf7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/ebuild/installs/software/Anaconda3/2024.02-1/lib/python3.11/site-packages/IPython/core/magics/pylab.py:162: UserWarning: pylab import has clobbered these variables: ['plt', 'mean']\n",
      "`%matplotlib` prevents importing * from pylab and numpy\n",
      "  warn(\"pylab import has clobbered these variables: %s\"  % clobbered +\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as plt\n",
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "44d25799-b9ae-4433-b4ab-6cd46932738c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 2 (75% train): Train size = 8766, Test size = 2192\n",
      "Split 2 (75% train): Train size = 6574, Test size = 4384\n",
      "Split 3 (50% train): Train size = 5479, Test size = 5479\n",
      "Train X shape: torch.Size([8757, 10, 1]), y shape: torch.Size([8757, 10, 1])\n",
      "Test X shape: torch.Size([2183, 10, 1]), y shape: torch.Size([2183, 10, 1])\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Device and seed setup\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "# === Load dataset ===\n",
    "data = pd.read_csv(\"spatial_avg_data-ERAsinglelevel_2020-24_daily.csv\")\n",
    "\n",
    "# === Extract relevant features ===\n",
    "s_data = data[[\"asn\", \"rsn\", \"siconc\"]].values\n",
    "\n",
    "# === Train-test split ===\n",
    "#train_data = s_data[:5461]\n",
    "#test_data = s_data[5461:6050]\n",
    "\n",
    "# Split 2: ~75% training\n",
    "split_point_1 = int(len(s_data) * 0.80)\n",
    "train_data_1 = s_data[:split_point_1]\n",
    "test_data_1 = s_data[split_point_1:]\n",
    "print(f\"Split 2 (75% train): Train size = {len(train_data_1)}, Test size = {len(test_data_1)}\")\n",
    "\n",
    "# Split 2: ~75% training\n",
    "split_point_2 = int(len(s_data) * 0.60)\n",
    "train_data_2 = s_data[:split_point_2]\n",
    "test_data_2 = s_data[split_point_2:]\n",
    "print(f\"Split 2 (75% train): Train size = {len(train_data_2)}, Test size = {len(test_data_2)}\")\n",
    "\n",
    "# Split 3: ~50% training\n",
    "split_point_3 = int(len(s_data) * 0.50)\n",
    "train_data_3 = s_data[:split_point_3]\n",
    "test_data_3 = s_data[split_point_3:]\n",
    "print(f\"Split 3 (50% train): Train size = {len(train_data_3)}, Test size = {len(test_data_3)}\")\n",
    "train_data = train_data_1\n",
    "test_data = test_data_1 \n",
    "# === Scale features using only training stats ===\n",
    "scaler = StandardScaler()\n",
    "train_scaled = scaler.fit_transform(train_data)\n",
    "test_scaled = scaler.transform(test_data)\n",
    "\n",
    "# === Define constants for computing snow depth proxy ===\n",
    "A, B = 600, 300\n",
    "\n",
    "def compute_snow_depth_proxy(data):\n",
    "    siconc_term = (data[:, 2] * A) / (A - B)\n",
    "    albedo_density_term = (data[:, 0] * data[:, 1]) / (A - B)\n",
    "    return (siconc_term + albedo_density_term).reshape(-1, 1)\n",
    "\n",
    "# === Compute and normalize snow depth proxy ===\n",
    "train_depth = compute_snow_depth_proxy(train_scaled)\n",
    "test_depth = compute_snow_depth_proxy(test_scaled)\n",
    "\n",
    "mean_depth, std_depth = train_depth.mean(), train_depth.std()\n",
    "train_depth_norm = (train_depth - mean_depth) / std_depth\n",
    "test_depth_norm = (test_depth - mean_depth) / std_depth\n",
    "\n",
    "# === Convert to sequences ===\n",
    "def create_sequences(inputs, targets, seq_len):\n",
    "    X, Y = [], []\n",
    "    for i in range(len(inputs) - seq_len + 1):\n",
    "        X.append(inputs[i:i+seq_len])\n",
    "        Y.append(targets[i:i+seq_len])  # supervision at end of sequence\n",
    "    return np.array(X), np.array(Y)\n",
    "\n",
    "seq_length = 10\n",
    "X_train_seq, y_train_seq = create_sequences(train_scaled[:, 1:2], train_depth_norm, seq_length)  # only snow density as input\n",
    "X_test_seq, y_test_seq = create_sequences(test_scaled[:, 1:2], test_depth_norm, seq_length)\n",
    "\n",
    "# === Convert to PyTorch tensors ===\n",
    "X_train_tensor = torch.tensor(X_train_seq, dtype=torch.float32).to(device)\n",
    "y_train_tensor = torch.tensor(y_train_seq, dtype=torch.float32).to(device)\n",
    "X_test_tensor = torch.tensor(X_test_seq, dtype=torch.float32).to(device)\n",
    "y_test_tensor = torch.tensor(y_test_seq, dtype=torch.float32).to(device)\n",
    "\n",
    "# === Confirm shapes ===\n",
    "print(f\"Train X shape: {X_train_tensor.shape}, y shape: {y_train_tensor.shape}\")\n",
    "print(f\"Test X shape: {X_test_tensor.shape}, y shape: {y_test_tensor.shape}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "42a03fdb-73ed-42bb-a0d0-8202b96c4b23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "# Data Augmentation Function for Contrastive Learning\n",
    "def augment_data(x, noise_std=0.01):\n",
    "\n",
    "    noise = torch.randn_like(x) * noise_std  # Add Gaussian noise\n",
    "    x_aug = x + noise\n",
    "    return x_aug\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def contrastive_loss(z_i, z_j, scale=0.05):\n",
    "\n",
    "    batch_size = z_i.size(0)\n",
    "\n",
    "    # Normalize embeddings\n",
    "    z_i = F.normalize(z_i, dim=1)\n",
    "    z_j = F.normalize(z_j, dim=1)\n",
    "\n",
    "    # Create a large similarity matrix\n",
    "    similarity_matrix = torch.matmul(z_i, z_j.T) / scale\n",
    "\n",
    "    #Create a similarity matrix that combines z_i and z_j\n",
    "    combined_z = torch.cat([z_i,z_j], dim=0)\n",
    "    combined_similarity = torch.matmul(combined_z, combined_z.T) / scale\n",
    "\n",
    "    #create labels.\n",
    "    labels = torch.arange(batch_size, device=z_i.device)\n",
    "    combined_labels = torch.cat([labels, labels], dim=0) #correct label concatination.\n",
    "\n",
    "    #remove diagonal.\n",
    "    mask = ~torch.eye(combined_labels.shape[0], device=combined_labels.device).bool()\n",
    "    combined_similarity = combined_similarity.masked_select(mask).view(combined_labels.shape[0], -1)\n",
    "    #Do not repeat labels.\n",
    "    #combined_labels = combined_labels.repeat(2)\n",
    "\n",
    "    loss = F.cross_entropy(combined_similarity, combined_labels)\n",
    "    return loss\n",
    "# === Attention Block ===\n",
    "# === Multihead Self-Attention Wrapper ===\n",
    "class MultiHeadSelfAttention(nn.Module):\n",
    "    def __init__(self, hidden_dim, num_heads=2):\n",
    "        super(MultiHeadSelfAttention, self).__init__()\n",
    "        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)\n",
    "        self.norm = nn.LayerNorm(hidden_dim)\n",
    "        self.dropout = nn.Dropout(0.2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        attn_output, _ = self.attn(x, x, x)\n",
    "        out = self.norm(x + self.dropout(attn_output))  # Residual + Norm\n",
    "        return out\n",
    "\n",
    "# === Bayesian LSTM + Multi-Head Attention ===\n",
    "class LSTMContrastiveWithAttention(nn.Module):\n",
    "    def __init__(self, input_dim=1, hidden_dim=64, num_layers=2, output_dim=1, dropout_rate=0.4, num_heads=4):\n",
    "        super().__init__()\n",
    "        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_rate)\n",
    "        self.attention = MultiHeadSelfAttention(hidden_dim, num_heads=num_heads)\n",
    "        self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_rate)\n",
    "\n",
    "        self.fc_depth = nn.Linear(hidden_dim, output_dim)\n",
    "        self.fc_params = nn.Sequential(\n",
    "            nn.Linear(hidden_dim, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, 3)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, x_aug):\n",
    "        # === Original sequence\n",
    "        enc_x, (h_x, c_x) = self.encoder_lstm(x)\n",
    "        attn_x = self.attention(enc_x)\n",
    "        dec_x, _ = self.decoder_lstm(attn_x, (h_x, c_x))\n",
    "\n",
    "        # === Augmented sequence\n",
    "        enc_aug, (h_aug, c_aug) = self.encoder_lstm(x_aug)\n",
    "        attn_aug = self.attention(enc_aug)\n",
    "        dec_aug, _ = self.decoder_lstm(attn_aug, (h_aug, c_aug))\n",
    "\n",
    "        # === Depth Prediction\n",
    "        depth_pred = self.fc_depth(dec_x).squeeze(-1)\n",
    "\n",
    "        # === Parameter Estimation\n",
    "        last_out = dec_x[:, -1, :]\n",
    "        params_pred = self.fc_params(last_out)\n",
    "        w_raw, b_raw, c_raw = params_pred[:, 0:1], params_pred[:, 1:2], params_pred[:, 2:3]\n",
    "\n",
    "        w = torch.sigmoid(w_raw) * 2.0 - 1.0\n",
    "        b = torch.exp(b_raw)\n",
    "        c = torch.tanh(c_raw) * 10.0\n",
    "\n",
    "        mean_density = x[:, :, 0].mean(dim=1, keepdim=True)\n",
    "        w_exp = w.expand_as(depth_pred)\n",
    "        b_exp = b.expand_as(depth_pred)\n",
    "        c_exp = c.expand_as(depth_pred)\n",
    "        mean_density_exp = mean_density.expand_as(depth_pred)\n",
    "\n",
    "        estimated_snow_depth = w_exp * mean_density_exp + b_exp * depth_pred + c_exp\n",
    "\n",
    "        contrastive_loss_val = contrastive_loss(dec_x[:, -1, :], dec_aug[:, -1, :]).mean()\n",
    "\n",
    "        return depth_pred, params_pred, estimated_snow_depth, contrastive_loss_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3164e58b-3e95-49a4-b144-e04cc49d76de",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use 2 GPUs!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/asampath/.local/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:70: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn(\n",
      "/home/asampath/.local/lib/python3.11/site-packages/torch/nn/modules/loss.py:610: UserWarning: Using a target size (torch.Size([16, 1])) that is different to the input size (torch.Size([16, 10])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.mse_loss(input, target, reduction=self.reduction)\n",
      "/tmp/ipykernel_564048/1371857881.py:72: DeprecationWarning: Calling np.sum(generator) is deprecated, and in the future will give a different result. Use np.sum(np.fromiter(generator)) or the python sum builtin instead.\n",
      "  l1_norm = sum(p.abs().sum() for p in model.parameters())\n",
      "/tmp/ipykernel_564048/1371857881.py:77: DeprecationWarning: Calling np.sum(generator) is deprecated, and in the future will give a different result. Use np.sum(np.fromiter(generator)) or the python sum builtin instead.\n",
      "  l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())\n",
      "/home/asampath/.local/lib/python3.11/site-packages/torch/nn/modules/loss.py:610: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5, 10])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
      "  return F.mse_loss(input, target, reduction=self.reduction)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/500, Loss: 7.5525\n",
      "Epoch 50/500, Loss: 3.2492\n",
      "Epoch 100/500, Loss: 3.1997\n",
      "Epoch 150/500, Loss: 3.1527\n",
      "Epoch 200/500, Loss: 3.1362\n",
      "Epoch 250/500, Loss: 3.1249\n",
      "Epoch 300/500, Loss: 3.1168\n",
      "Epoch 350/500, Loss: 3.1173\n",
      "Epoch 400/500, Loss: 3.1128\n",
      "Epoch 450/500, Loss: 3.1176\n",
      "(2183, 10)\n",
      "(2183, 10)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# === Device Setup ===\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# === Initialize Model & Optimizer ===\n",
    "model = LSTMContrastiveWithAttention()  # Move model to device\n",
    "criterion_depth = nn.MSELoss()  # Loss for depth prediction\n",
    "criterion_estimated_snow_depth = nn.MSELoss()  # Loss for predicted\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.0005)  # removed weight_decay from here.\n",
    "\n",
    "# Check for multiple GPUs and wrap with DataParallel if available\n",
    "if torch.cuda.device_count() > 1:\n",
    "    print(f\"Let's use {torch.cuda.device_count()} GPUs!\")\n",
    "    model = nn.DataParallel(model)\n",
    "\n",
    "# Move model to device\n",
    "model = model.to(device)\n",
    "\n",
    "# === Training Setup ===\n",
    "epochs = 500\n",
    "batch_size = 16\n",
    "\n",
    "# Convert data to PyTorch datasets\n",
    "train_dataset = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True\n",
    ")\n",
    "\n",
    "test_dataset = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=False\n",
    ")\n",
    "\n",
    "# === Store losses & predictions ===\n",
    "train_losses = []\n",
    "\n",
    "# === Training Loop ===\n",
    "for epoch in range(epochs):\n",
    "    total_loss = 0\n",
    "    model.train()  # Set to training mode\n",
    "\n",
    "    for batch_x, batch_y in train_dataloader:\n",
    "        batch_x, batch_y = batch_x.to(device), batch_y.to(\n",
    "            device\n",
    "        )  # move batch to device.\n",
    "        optimizer.zero_grad()\n",
    "        batch_x_aug = augment_data(batch_x)\n",
    "        # Correct Forward Pass\n",
    "        (\n",
    "            depth_pred,\n",
    "            params_pred,\n",
    "            estimated_snow_depth_pred,\n",
    "            contrastive_loss_value,\n",
    "        ) = model(batch_x, batch_x_aug)\n",
    "        loss_depth = criterion_depth(depth_pred, batch_y[:, -1, :])\n",
    "        loss_estimated_snow_depth = criterion_depth(\n",
    "            estimated_snow_depth_pred, batch_y[:, -1, :]\n",
    "        )\n",
    "        contrastive_loss_value = torch.mean(\n",
    "            contrastive_loss_value\n",
    "        )  # Ensure scalar\n",
    "        loss = loss_depth + loss_estimated_snow_depth + contrastive_loss_value\n",
    "\n",
    "        # L1 Regularization\n",
    "        l1_lambda = 0.001  # Adjust lambda as needed\n",
    "        l1_norm = sum(p.abs().sum() for p in model.parameters())\n",
    "        loss += l1_lambda * l1_norm\n",
    "\n",
    "        # L2 Regularization\n",
    "        l2_lambda = 1e-5  # adjust lambda as needed.\n",
    "        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())\n",
    "        loss += l2_lambda * l2_norm\n",
    "\n",
    "        # Backpropagation\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    # Store loss for visualization\n",
    "    train_losses.append(total_loss / len(train_dataloader))\n",
    "\n",
    "    # Print Progress Every 50 Epochs\n",
    "    if epoch % 50 == 0:\n",
    "        print(f\"Epoch {epoch}/{epochs}, Loss: {train_losses[-1]:.4f}\")\n",
    "\n",
    "# === After Training: Evaluate on Test Set ===\n",
    "model.eval()\n",
    "depth_predictions = []\n",
    "estimated_depth_predictions = []\n",
    "with torch.no_grad():\n",
    "    for batch_x, _ in test_dataloader:\n",
    "        batch_x = batch_x.to(device)\n",
    "        batch_x_aug = augment_data(\n",
    "            batch_x\n",
    "        )  # create augmented data.\n",
    "        (\n",
    "            depth_pred,\n",
    "            params_pred,\n",
    "            estimated_snow_depth,\n",
    "            contrastive_loss_value,\n",
    "        ) = model(\n",
    "            batch_x, batch_x_aug\n",
    "        )  # unpack 4 values.\n",
    "        depth_predictions.append(depth_pred.cpu())\n",
    "        estimated_depth_predictions.append(estimated_snow_depth.cpu())\n",
    "\n",
    "# Convert Predictions to NumPy\n",
    "estimated_depth_predictions_np = (\n",
    "    torch.cat(estimated_depth_predictions).squeeze().numpy()\n",
    ")\n",
    "depth_predictions_np = torch.cat(depth_predictions).squeeze().numpy()\n",
    "true_depths_np = y_test_tensor.cpu().squeeze().numpy()\n",
    "print(depth_predictions_np.shape)\n",
    "print(true_depths_np.shape)\n",
    "avg_pred = depth_predictions_np.mean(axis=1)\n",
    "avg_est = estimated_depth_predictions_np.mean(axis=1)\n",
    "avg_true = true_depths_np.mean(axis=1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "bd7b3cfb-1188-4819-a748-882a768be28a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  MSE  = 0.5926\n",
      "  RMSE = 0.7698\n"
     ]
    }
   ],
   "source": [
    "# Arrays: shape (samples, timesteps)\n",
    "# depth_predictions_np         → LSTM prediction\n",
    "# estimated_depth_predictions_np → Physics-refined prediction\n",
    "# true_depths_np               → Ground truth\n",
    "\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# === Overall metrics (flattened)\n",
    "mse = mean_squared_error(true_depths_np.flatten(), depth_predictions_np.flatten())\n",
    "rmse = np.sqrt(mse)\n",
    "mean = depth_predictions_np.mean()\n",
    "\n",
    "mse_phys = mean_squared_error(true_depths_np.flatten(), estimated_depth_predictions_np.flatten())\n",
    "rmse_phys = np.sqrt(mse_phys)\n",
    "mean_phys = estimated_depth_predictions_np.mean()\n",
    "\n",
    "mean_true = true_depths_np.mean()\n",
    "\n",
    "print(f\"  MSE  = {mse_phys:.4f}\")\n",
    "print(f\"  RMSE = {rmse_phys:.4f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
