{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9db76765-2e54-438f-a159-eb3e30dddb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import Tensor\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b41b1021-0dd7-4b52-b682-fe7a3d731bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "# Import pyplot (not the top-level matplotlib package) under the conventional `plt` alias.\n",
    "# `%pylab` is deprecated and star-imports numpy, which shadows builtins such as `sum`\n",
    "# (the cause of the np.sum(generator) DeprecationWarning in the training cell).\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44d25799-b9ae-4433-b4ab-6cd46932738c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split 1 (80% train): Train size = 8766, Test size = 2192\n",
      "Split 2 (60% train): Train size = 6574, Test size = 4384\n",
      "Split 3 (50% train): Train size = 5479, Test size = 5479\n",
      "Train X shape: torch.Size([6565, 10, 1]), y shape: torch.Size([6565, 10, 1])\n",
      "Test X shape: torch.Size([4375, 10, 1]), y shape: torch.Size([4375, 10, 1])\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Device and seed setup\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "# === Load dataset ===\n",
    "data = pd.read_csv(\"spatial_avg_data-ERAsinglelevel_2020-24_daily.csv\")\n",
    "\n",
    "# === Extract relevant features ===\n",
    "# Column order is relied on below: 0 = asn, 1 = rsn, 2 = siconc.\n",
    "s_data = data[[\"asn\", \"rsn\", \"siconc\"]].values\n",
    "\n",
    "# === Chronological train-test splits ===\n",
    "def chronological_split(arr, train_frac):\n",
    "    \"\"\"Split a time-ordered array into a leading train part and a trailing test part.\n",
    "\n",
    "    No shuffling is done: the test period always follows the training period,\n",
    "    which avoids look-ahead leakage for time-series data.\n",
    "    \"\"\"\n",
    "    cut = int(len(arr) * train_frac)\n",
    "    return arr[:cut], arr[cut:]\n",
    "\n",
    "# Split 1: 80% training\n",
    "train_data_1, test_data_1 = chronological_split(s_data, 0.80)\n",
    "print(f\"Split 1 (80% train): Train size = {len(train_data_1)}, Test size = {len(test_data_1)}\")\n",
    "\n",
    "# Split 2: 60% training\n",
    "train_data_2, test_data_2 = chronological_split(s_data, 0.60)\n",
    "print(f\"Split 2 (60% train): Train size = {len(train_data_2)}, Test size = {len(test_data_2)}\")\n",
    "\n",
    "# Split 3: 50% training\n",
    "train_data_3, test_data_3 = chronological_split(s_data, 0.50)\n",
    "print(f\"Split 3 (50% train): Train size = {len(train_data_3)}, Test size = {len(test_data_3)}\")\n",
    "\n",
    "# Split 2 (60% train) is the one used for the rest of this notebook.\n",
    "train_data = train_data_2\n",
    "test_data = test_data_2\n",
    "\n",
    "# === Scale features using only training stats (no test leakage) ===\n",
    "scaler = StandardScaler()\n",
    "train_scaled = scaler.fit_transform(train_data)\n",
    "test_scaled = scaler.transform(test_data)\n",
    "\n",
    "# === Define constants for computing snow depth proxy ===\n",
    "A, B = 600, 300\n",
    "\n",
    "def compute_snow_depth_proxy(data):\n",
    "    \"\"\"Compute the snow-depth proxy target from an (n_samples, 3) feature array.\n",
    "\n",
    "    proxy = (siconc * A + asn * rsn) / (A - B), returned with shape (n_samples, 1).\n",
    "    NOTE(review): this is applied to the *standardized* features, so the result is a\n",
    "    unitless proxy rather than a physical depth -- confirm this is intended.\n",
    "    \"\"\"\n",
    "    siconc_term = (data[:, 2] * A) / (A - B)\n",
    "    albedo_density_term = (data[:, 0] * data[:, 1]) / (A - B)\n",
    "    return (siconc_term + albedo_density_term).reshape(-1, 1)\n",
    "\n",
    "# === Compute and normalize snow depth proxy (training stats only) ===\n",
    "train_depth = compute_snow_depth_proxy(train_scaled)\n",
    "test_depth = compute_snow_depth_proxy(test_scaled)\n",
    "\n",
    "mean_depth, std_depth = train_depth.mean(), train_depth.std()\n",
    "train_depth_norm = (train_depth - mean_depth) / std_depth\n",
    "test_depth_norm = (test_depth - mean_depth) / std_depth\n",
    "\n",
    "# === Convert to sequences ===\n",
    "def create_sequences(inputs, targets, seq_len):\n",
    "    \"\"\"Build overlapping sliding windows (stride 1) of length `seq_len`.\n",
    "\n",
    "    Returns (X, Y) with X[i] = inputs[i:i+seq_len] and Y[i] = targets[i:i+seq_len]:\n",
    "    the target is the full window (one value per timestep), not only its last step.\n",
    "    Inputs shorter than `seq_len` yield no windows.\n",
    "    \"\"\"\n",
    "    n_windows = len(inputs) - seq_len + 1\n",
    "    X = [inputs[i:i + seq_len] for i in range(n_windows)]\n",
    "    Y = [targets[i:i + seq_len] for i in range(n_windows)]\n",
    "    return np.array(X), np.array(Y)\n",
    "\n",
    "seq_length = 10\n",
    "# Model input: column 1 (rsn, snow density) only.\n",
    "X_train_seq, y_train_seq = create_sequences(train_scaled[:, 1:2], train_depth_norm, seq_length)\n",
    "X_test_seq, y_test_seq = create_sequences(test_scaled[:, 1:2], test_depth_norm, seq_length)\n",
    "\n",
    "# === Convert to PyTorch tensors ===\n",
    "X_train_tensor = torch.tensor(X_train_seq, dtype=torch.float32).to(device)\n",
    "y_train_tensor = torch.tensor(y_train_seq, dtype=torch.float32).to(device)\n",
    "X_test_tensor = torch.tensor(X_test_seq, dtype=torch.float32).to(device)\n",
    "y_test_tensor = torch.tensor(y_test_seq, dtype=torch.float32).to(device)\n",
    "\n",
    "# === Confirm shapes ===\n",
    "print(f\"Train X shape: {X_train_tensor.shape}, y shape: {y_train_tensor.shape}\")\n",
    "print(f\"Test X shape: {X_test_tensor.shape}, y shape: {y_test_tensor.shape}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "42a03fdb-73ed-42bb-a0d0-8202b96c4b23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "# === Multi-Head Self-Attention block (residual + LayerNorm) ===\n",
    "class MultiHeadSelfAttention(nn.Module):\n",
    "    \"\"\"Self-attention over the sequence axis, then dropout, a residual add and LayerNorm.\n",
    "\n",
    "    Args:\n",
    "        hidden_dim: feature size of the input; must be divisible by `num_heads`.\n",
    "        num_heads: number of attention heads.\n",
    "        attn_dropout: dropout probability applied to the attention output\n",
    "            (defaults to 0.2, the value that was previously hard-coded).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_dim, num_heads=2, attn_dropout=0.2):\n",
    "        super().__init__()\n",
    "        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, batch_first=True)\n",
    "        self.norm = nn.LayerNorm(hidden_dim)\n",
    "        self.dropout = nn.Dropout(attn_dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map x of shape (batch, seq, hidden_dim) to a tensor of the same shape.\"\"\"\n",
    "        attn_output, _ = self.attn(x, x, x)\n",
    "        return self.norm(x + self.dropout(attn_output))  # residual + norm\n",
    "\n",
    "\n",
    "# === LSTM encoder -> self-attention -> LSTM decoder, with a parameterised refinement head ===\n",
    "# NOTE: this was previously labelled a \"Bayesian LSTM\"; nothing here samples weights, and\n",
    "# dropout is only active in train mode, so eval-mode predictions are deterministic.\n",
    "class LSTMAttention(nn.Module):\n",
    "    \"\"\"Predict a per-timestep depth and a refined estimate from an input sequence.\n",
    "\n",
    "    forward(x), with x of shape (batch, seq, input_dim), returns (for output_dim=1):\n",
    "        depth_pred           (batch, seq): direct per-timestep prediction.\n",
    "        params_pred          (batch, 3):   raw, unconstrained outputs of the parameter head.\n",
    "        estimated_snow_depth (batch, seq): w * mean_t(x[..., 0]) + b * depth_pred + c,\n",
    "                                           with w in (-1, 1), b > 0 and c in (-10, 10).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=1, hidden_dim=64, num_layers=2, output_dim=1, dropout_rate=0.4, num_heads=4, attn_dropout=0.2):\n",
    "        super().__init__()\n",
    "        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_rate)\n",
    "        self.attention = MultiHeadSelfAttention(hidden_dim, num_heads=num_heads, attn_dropout=attn_dropout)\n",
    "        self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_rate)\n",
    "\n",
    "        self.fc_depth = nn.Linear(hidden_dim, output_dim)\n",
    "        # Small MLP producing the three raw parameters (w, b, c) from the last decoder step.\n",
    "        self.fc_params = nn.Sequential(\n",
    "            nn.Linear(hidden_dim, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, 3)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Encode, attend, then decode starting from the encoder's final (h, c) state.\n",
    "        enc_x, (h_x, c_x) = self.encoder_lstm(x)\n",
    "        attn_x = self.attention(enc_x)\n",
    "        dec_x, _ = self.decoder_lstm(attn_x, (h_x, c_x))\n",
    "\n",
    "        # Direct depth prediction for every timestep.\n",
    "        depth_pred = self.fc_depth(dec_x).squeeze(-1)\n",
    "\n",
    "        # Parameter estimation from the last decoder output.\n",
    "        last_out = dec_x[:, -1, :]\n",
    "        params_pred = self.fc_params(last_out)\n",
    "        w_raw, b_raw, c_raw = params_pred[:, 0:1], params_pred[:, 1:2], params_pred[:, 2:3]\n",
    "\n",
    "        # Constrain the parameters: w in (-1, 1), b > 0, c in (-10, 10).\n",
    "        w = torch.sigmoid(w_raw) * 2.0 - 1.0\n",
    "        b = torch.exp(b_raw)\n",
    "        c = torch.tanh(c_raw) * 10.0\n",
    "\n",
    "        # Per-sample mean of input feature 0 over the sequence, shape (batch, 1).\n",
    "        mean_density = x[:, :, 0].mean(dim=1, keepdim=True)\n",
    "\n",
    "        # The (batch, 1) tensors are expanded across the time axis of depth_pred.\n",
    "        w_exp = w.expand_as(depth_pred)\n",
    "        b_exp = b.expand_as(depth_pred)\n",
    "        c_exp = c.expand_as(depth_pred)\n",
    "        mean_density_exp = mean_density.expand_as(depth_pred)\n",
    "\n",
    "        estimated_snow_depth = w_exp * mean_density_exp + b_exp * depth_pred + c_exp\n",
    "\n",
    "        return depth_pred, params_pred, estimated_snow_depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3164e58b-3e95-49a4-b144-e04cc49d76de",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# === Device Setup ===\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# === Training configuration ===\n",
    "epochs = 500\n",
    "batch_size = 16\n",
    "learning_rate = 0.0005\n",
    "l1_lambda = 0.001  # L1 penalty weight (adjust as needed)\n",
    "l2_lambda = 1e-5  # L2 penalty weight (adjust as needed)\n",
    "results_path = \"NoCL-Inv-training_results_sample1.pth\"\n",
    "\n",
    "# === Initialize Model & Optimizer ===\n",
    "model = LSTMAttention().to(device)\n",
    "criterion_depth = nn.MSELoss()  # loss for the direct depth prediction\n",
    "criterion_estimated_snow_depth = nn.MSELoss()  # loss for the refined estimate\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Multi-GPU: DataParallel wraps the same parameter objects, so the optimizer above stays valid.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    print(f\"Let's use {torch.cuda.device_count()} GPUs!\")\n",
    "    model = nn.DataParallel(model)\n",
    "\n",
    "# === Data loaders (X_*/y_* tensors come from the data-preparation cell) ===\n",
    "train_dataset = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True\n",
    ")\n",
    "\n",
    "test_dataset = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=False\n",
    ")\n",
    "\n",
    "\n",
    "def regularization_penalty(module, l1_weight, l2_weight):\n",
    "    \"\"\"Return l1_weight * sum(|p|) + l2_weight * sum(p**2) over all parameters of `module`.\n",
    "\n",
    "    Reduces with torch ops instead of the builtin `sum` over a generator, which a\n",
    "    numpy star-import (e.g. `%pylab`) shadows with the deprecated np.sum(generator).\n",
    "    \"\"\"\n",
    "    params = list(module.parameters())\n",
    "    l1_norm = torch.stack([p.abs().sum() for p in params]).sum()\n",
    "    l2_norm = torch.stack([p.pow(2.0).sum() for p in params]).sum()\n",
    "    return l1_weight * l1_norm + l2_weight * l2_norm\n",
    "\n",
    "\n",
    "# === Training Loop ===\n",
    "train_losses = []\n",
    "for epoch in range(epochs):\n",
    "    total_loss = 0.0\n",
    "    model.train()\n",
    "\n",
    "    for batch_x, batch_y in train_dataloader:\n",
    "        batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
    "        # batch_y holds full target windows (batch, seq, 1) and the model predicts (batch, seq),\n",
    "        # so only the trailing feature axis is dropped. The former batch_y[:, -1, :] had shape\n",
    "        # (batch, 1) and was silently broadcast across every timestep by MSELoss.\n",
    "        target = batch_y.squeeze(-1)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        depth_pred, params_pred, estimated_snow_depth_pred = model(batch_x)\n",
    "        loss_depth = criterion_depth(depth_pred, target)\n",
    "        loss_estimated_snow_depth = criterion_estimated_snow_depth(estimated_snow_depth_pred, target)\n",
    "        loss = loss_depth + loss_estimated_snow_depth\n",
    "        loss = loss + regularization_penalty(model, l1_lambda, l2_lambda)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    # Mean loss per batch (including the regularization terms).\n",
    "    train_losses.append(total_loss / len(train_dataloader))\n",
    "\n",
    "    if epoch % 50 == 0:\n",
    "        print(f\"Epoch {epoch}/{epochs}, Loss: {train_losses[-1]:.4f}\")\n",
    "\n",
    "# === After Training: Evaluate on Test Set ===\n",
    "model.eval()\n",
    "depth_predictions = []\n",
    "estimated_depth_predictions = []\n",
    "with torch.no_grad():\n",
    "    for batch_x, _ in test_dataloader:\n",
    "        batch_x = batch_x.to(device)\n",
    "        depth_pred, params_pred, estimated_snow_depth = model(batch_x)\n",
    "        depth_predictions.append(depth_pred.cpu())\n",
    "        estimated_depth_predictions.append(estimated_snow_depth.cpu())\n",
    "\n",
    "# Convert to NumPy arrays of shape (n_test_windows, seq_len). Predictions are already 2-D;\n",
    "# targets drop only their trailing feature axis (squeeze(-1), not squeeze()) so a\n",
    "# single-window test set keeps its 2-D shape.\n",
    "estimated_depth_predictions_np = torch.cat(estimated_depth_predictions).numpy()\n",
    "depth_predictions_np = torch.cat(depth_predictions).numpy()\n",
    "true_depths_np = y_test_tensor.cpu().squeeze(-1).numpy()\n",
    "print(depth_predictions_np.shape)\n",
    "print(true_depths_np.shape)\n",
    "avg_pred = depth_predictions_np.mean(axis=1)\n",
    "avg_est = estimated_depth_predictions_np.mean(axis=1)\n",
    "avg_true = true_depths_np.mean(axis=1)\n",
    "\n",
    "# === Save Training Results ===\n",
    "# NOTE: when the model is wrapped in nn.DataParallel, state_dict keys carry a \"module.\" prefix.\n",
    "torch.save(\n",
    "    {\n",
    "        \"epoch\": epochs,\n",
    "        \"model_state_dict\": model.state_dict(),\n",
    "        \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "        \"train_losses\": train_losses,\n",
    "        \"depth_predictions_np\": depth_predictions_np,\n",
    "        \"estimated_depth_predictions_np\": estimated_depth_predictions_np,\n",
    "        \"true_depths_np\": true_depths_np,\n",
    "    },\n",
    "    results_path,\n",
    ")\n",
    "print(f\"Training results saved to {results_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd7b3cfb-1188-4819-a748-882a768be28a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Prediction:\n",
      "  MSE  = 0.6588\n",
      "  RMSE = 0.8117\n",
      "  Mean = -0.0699\n",
      "\n",
      "Physics-Refined Prediction:\n",
      "  MSE  = 0.6303\n",
      "  RMSE = 0.7939\n",
      "  Mean = -0.0612\n",
      "\n",
      "Ground Truth Mean:\n",
      "  Mean = -0.2554\n"
     ]
    }
   ],
   "source": [
    "# Arrays compared below, each of shape (samples, timesteps):\n",
    "#   depth_predictions_np           -> direct LSTM prediction\n",
    "#   estimated_depth_predictions_np -> physics-refined prediction\n",
    "#   true_depths_np                 -> ground truth\n",
    "\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "\n",
    "def flat_error_summary(truth, pred):\n",
    "    \"\"\"Return (MSE, RMSE, mean of `pred`), comparing both arrays element-wise after flattening.\"\"\"\n",
    "    err = mean_squared_error(truth.flatten(), pred.flatten())\n",
    "    return err, np.sqrt(err), pred.mean()\n",
    "\n",
    "\n",
    "mse, rmse, mean = flat_error_summary(true_depths_np, depth_predictions_np)\n",
    "mse_phys, rmse_phys, mean_phys = flat_error_summary(true_depths_np, estimated_depth_predictions_np)\n",
    "mean_true = true_depths_np.mean()\n",
    "\n",
    "report_rows = [\n",
    "    (\" Prediction:\", mse, rmse, mean),\n",
    "    (\"Physics-Refined Prediction:\", mse_phys, rmse_phys, mean_phys),\n",
    "]\n",
    "for heading, row_mse, row_rmse, row_mean in report_rows:\n",
    "    print(heading)\n",
    "    print(f\"  MSE  = {row_mse:.4f}\")\n",
    "    print(f\"  RMSE = {row_rmse:.4f}\")\n",
    "    print(f\"  Mean = {row_mean:.4f}\\n\")\n",
    "\n",
    "print(\"Ground Truth Mean:\")\n",
    "print(f\"  Mean = {mean_true:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26d4f738-295e-4105-8570-ee12acd0718a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
