{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:50.918337Z",
     "start_time": "2025-09-07T12:15:48.173341Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import os\n",
    "\n",
    "import cupy as cp\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from slayer_model.kernels import Hat, MorletWavelet, DecayingExponentialKernel, Bump\n",
    "from slayer_model.layer import SRMLayer\n",
    "from slayer_model.network import SRMNetwork\n",
    "from datasets import DataLoader\n",
    "from slayer_model.utils import torch_to_cupy\n",
    "# `np` and `cp` are used directly in later cells, so they are imported explicitly above\n",
    "# instead of being relied on implicitly; the wildcard is kept for the metric\n",
    "# accumulators (R2Accumulator, RSEAccumulator) used in the evaluation cells.\n",
    "from utils.metrics import *\n",
    "from slayer_model.utils.losses import mse\n",
    "\n",
    "torch.set_default_dtype(torch.float32)"
   ],
   "id": "35a4b44ac1d50686",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No ROCm runtime is found, using ROCM_HOME='/opt/rocm'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Prophesee Dataset Toolbox could not be found!\n",
      "         Only Prophesee DVS demo will not run properly.\n",
      "         Please install it from https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:57.062237Z",
     "start_time": "2025-09-07T12:15:50.922740Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# --- Experiment configuration ---\n",
    "DATASET = 'electricity'\n",
    "H = 24  # prediction horizon passed to the DataLoader\n",
    "device = 'cpu'\n",
    "\n",
    "# --- Reproducibility: seed the CuPy and torch generators before any data is loaded ---\n",
    "seed = 42\n",
    "rng = cp.random.default_rng(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "data = DataLoader(dataset=DATASET, prediction_horizon=H, shuffle=True, seed=seed)"
   ],
   "id": "3c08f28c25e9c68e",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Configuration matches the last used one, and data files exist. Skipping data recreation.\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:57.115580Z",
     "start_time": "2025-09-07T12:15:57.113617Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def transfer_func(arr):\n",
    "    \"\"\"Wrap a batch array as a torch tensor placed on the global `device`.\n",
    "\n",
    "    `device` is looked up at call time, so rebinding it later (e.g. in the\n",
    "    training cell) changes where subsequent batches land.\n",
    "    \"\"\"\n",
    "    return torch.tensor(arr).to(device=device)"
   ],
   "id": "68bfdbc6c35ae235",
   "outputs": [],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:57.466715Z",
     "start_time": "2025-09-07T12:15:57.159694Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Fetch the first training batch (batch_size=1000) as (x, y).\n",
    "x, y = data.get_first_batch(\n",
    "    batch_size=1000,\n",
    "    target=\"train\",\n",
    "    transfer_func=transfer_func,\n",
    ")"
   ],
   "id": "65e39cc0fbf0ff6d",
   "outputs": [],
   "execution_count": 4
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:57.475390Z",
     "start_time": "2025-09-07T12:15:57.472632Z"
    }
   },
   "cell_type": "code",
   "source": [
    "n_hidden = 750  # neurons in the hidden layer\n",
    "\n",
    "dim, T = data.get_dim_t()\n",
    "# Only the last H time steps of the network output are compared with the targets.\n",
    "fit_idcs = np.arange(T - H, T, dtype=int)\n",
    "\n",
    "dt = 1.\n",
    "tmax = T * dt\n",
    "\n",
    "# Random kernel hyper-parameters in units of dt (draw order matters for the torch seed).\n",
    "center_span = 35 * dt\n",
    "width_span = 35 * dt\n",
    "width_floor = 5 * dt\n",
    "k_centers = torch.rand(size=(n_hidden,)) * center_span               # uniform in [0, 35*dt)\n",
    "k_widths = torch.rand(size=(n_hidden,)) * width_span + width_floor   # uniform in [5*dt, 40*dt)\n",
    "q_widths = 1 * k_widths  # q-kernel widths copy the k-kernel widths (scale factor 1)"
   ],
   "id": "fc997234f284c495",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:15:57.593936Z",
     "start_time": "2025-09-07T12:15:57.518958Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Hidden layer: Morlet-wavelet k-kernel and decaying-exponential q-kernel,\n",
    "# parameterised by the random centers / widths drawn earlier.\n",
    "hidden_layer = dict(\n",
    "    n_neurons=n_hidden,\n",
    "    n_inputs=dim,\n",
    "    phi_k=MorletWavelet(),\n",
    "    phi_q=DecayingExponentialKernel(),\n",
    "    dt=dt,\n",
    "    k_centers=k_centers,\n",
    "    k_widths=k_widths,\n",
    "    q_widths=q_widths,\n",
    ")\n",
    "# Readout layer: Hat k-kernel, one output neuron per input dimension.\n",
    "readout_layer = dict(\n",
    "    n_neurons=dim,\n",
    "    n_inputs=n_hidden,\n",
    "    phi_k=Hat(),\n",
    "    dt=dt,\n",
    ")\n",
    "layers = [hidden_layer, readout_layer]\n",
    "\n",
    "net = SRMNetwork(layers).to(device)"
   ],
   "id": "311f20abb056c130",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gm/PycharmProjects/sswim/slayer_model/utils/transfer.py:8: VisibleDeprecationWarning: This function is deprecated and will be removed in a future release. Use the cupy.from_dlpack() array constructor instead.\n",
      "  return torch.utils.dlpack.from_dlpack(tensor.toDlpack())\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Warm-start from previously saved weights (the path suggests the SGD experiment run for\n",
    "# this dataset / horizon / seed). The path is relative to the notebook's working directory.\n",
    "weights_path = os.path.join(\n",
    "    \"..\", \"Experiments\", \"sgd\", \"results\", f\"{DATASET}_hat_{H}\", f\"weights_{seed}.pt\"\n",
    ")\n",
    "if not os.path.isfile(weights_path):\n",
    "    raise FileNotFoundError(\n",
    "        f\"Pretrained weights not found at {os.path.abspath(weights_path)} (cwd: {os.getcwd()})\"\n",
    "    )\n",
    "\n",
    "net.load_state_dict(torch.load(weights_path, map_location=device));"
   ],
   "id": "9ca42db7edcc190f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "batch_size = 64\n",
    "initial_lr = 1e-4\n",
    "num_epochs = 500             # maximum number of epochs (early stopping may end sooner)\n",
    "patience = 30                # early stopping tolerance (epochs)\n",
    "min_delta = 1e-6             # minimal decrease in val loss to count as improvement\n",
    "eta_min = 0.0                # final LR (can be >0 if you prefer)\n",
    "\n",
    "# --- Optimizer + Cosine scheduler ---\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)\n",
    "# CosineAnnealingLR: lr_t = eta_min + 0.5*(lr0 - eta_min)*(1 + cos(pi * t / T_max))\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=eta_min)\n",
    "\n",
    "# --- Early stopping bookkeeping ---\n",
    "best_val_loss = float(\"inf\")\n",
    "epochs_no_improve = 0\n",
    "best_state = None\n",
    "\n",
    "# Batches are moved to the device the network's parameters live on.\n",
    "device = next(net.parameters()).device\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    net.train()\n",
    "    train_loss = 0.0\n",
    "    n_batches = 0  # counted while iterating, so the average matches the batches actually seen\n",
    "    print(f\"Epoch {epoch+1}/{num_epochs} — lr: {optimizer.param_groups[0]['lr']:.3e}\")\n",
    "\n",
    "    for x_batch, y_batch in data.iterate(batch_size=batch_size, target=\"train\", transfer_func=transfer_func):\n",
    "        if isinstance(x_batch, torch.Tensor): x_batch = x_batch.to(device)\n",
    "        if isinstance(y_batch, torch.Tensor): y_batch = y_batch.to(device)\n",
    "\n",
    "        # Forward pass; only the last H time steps (fit_idcs) are scored.\n",
    "        output = net(x_batch)[:, :, fit_idcs]\n",
    "        loss = mse(output, y_batch)\n",
    "        # One .item() per batch; per-batch losses are not printed (they flood the output).\n",
    "        train_loss += loss.item()\n",
    "        n_batches += 1\n",
    "\n",
    "        # Backward + step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    avg_train_loss = train_loss / max(1, n_batches)\n",
    "    print(f\"Train Loss: {avg_train_loss:.6f}\")\n",
    "\n",
    "    # --- Validation (batch_size=-1: presumably the whole validation split at once) ---\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        x_val, y_val = data.get_first_batch(batch_size=-1, target=\"val\", transfer_func=transfer_func)\n",
    "        if isinstance(x_val, torch.Tensor): x_val = x_val.to(device)\n",
    "        if isinstance(y_val, torch.Tensor): y_val = y_val.to(device)\n",
    "        val_output = net(x_val)[:, :, fit_idcs]\n",
    "        val_loss = mse(val_output, y_val).item()\n",
    "    print(f\"Val Loss: {val_loss:.6f}\")\n",
    "\n",
    "    # --- Early stopping logic ---\n",
    "    if val_loss + min_delta < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        epochs_no_improve = 0\n",
    "        best_state = {k: v.cpu().clone() for k, v in net.state_dict().items()}  # in-memory CPU snapshot\n",
    "        print(f\"Validation loss improved; saving best model (val_loss={best_val_loss:.6f}).\")\n",
    "    else:\n",
    "        epochs_no_improve += 1\n",
    "        print(f\"No improvement for {epochs_no_improve} epoch(s).\")\n",
    "\n",
    "    if epochs_no_improve >= patience:\n",
    "        print(f\"Early stopping triggered after {epoch+1} epochs (no improvement for {patience} epochs).\")\n",
    "        break\n",
    "\n",
    "    # Step scheduler at epoch end\n",
    "    scheduler.step()\n",
    "\n",
    "# --- Restore best model weights (if any) ---\n",
    "if best_state is not None:\n",
    "    net.load_state_dict(best_state)\n",
    "    print(f\"Best model restored (val_loss={best_val_loss:.6f}).\")\n",
    "else:\n",
    "    print(\"No improvement observed during training; final model kept as-is.\")\n"
   ],
   "id": "bbc2f689ba3eec2d",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Legacy training loop (no LR schedule or early stopping), superseded by the cell above.\n",
    "# Kept as comments for reference only (a bare string literal would be echoed by Jupyter\n",
    "# as cell output). Note that `its` is not defined anywhere in this notebook.\n",
    "#\n",
    "# for i in range(its):\n",
    "#     train_loss = 0\n",
    "#     print(f\"Iter {i}:\")\n",
    "#     for x,y in data.iterate(batch_size=batch_size, target=\"train\", transfer_func=transfer_func):\n",
    "#         # Forward pass\n",
    "#         output = net(x)[:, :, fit_idcs]\n",
    "#\n",
    "#         # Compute loss\n",
    "#         loss = mse(output, y)\n",
    "#         train_loss += loss.item()\n",
    "#\n",
    "#         # Backward pass\n",
    "#         optimizer.zero_grad()\n",
    "#         loss.backward()\n",
    "#         optimizer.step()\n",
    "#     print(f\"Train Loss: {train_loss/(data.get_n_batches(batch_size=batch_size, target='train'))}\")\n",
    "#     with torch.no_grad():\n",
    "#         x,y = data.get_first_batch(batch_size=-1, target=\"val\", transfer_func=transfer_func)\n",
    "#         output = net(x)[:, :, fit_idcs]\n",
    "#         loss = mse(output, y)\n",
    "#         print(f\"Val Loss: {loss.item():.4f}\")"
   ],
   "id": "138834af7d006b1c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-07T12:17:02.737305Z",
     "start_time": "2025-09-07T12:17:02.735277Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Metric accumulators, fed batch by batch during test evaluation and reduced at the end.\n",
    "r2a, rsea = R2Accumulator(), RSEAccumulator()"
   ],
   "id": "a9c6d0b35f43480",
   "outputs": [],
   "execution_count": 10
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def _to_cupy(tensor):\n",
    "    \"\"\"Convert a torch tensor to a CuPy array.\n",
    "\n",
    "    `torch_to_cupy` goes through DLPack, which CuPy rejects for CPU tensors\n",
    "    (TypeError when `device = 'cpu'`), so CPU tensors are staged through NumPy.\n",
    "    \"\"\"\n",
    "    if tensor.is_cuda:\n",
    "        return torch_to_cupy(tensor)\n",
    "    return cp.asarray(tensor.detach().cpu().numpy())\n",
    "\n",
    "\n",
    "# NOTE: r2a / rsea keep accumulating across runs; re-create them (previous cell) before\n",
    "# re-running this one, otherwise batches are presumably counted twice.\n",
    "net.eval()\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    data_test = data.iterate(batch_size=500, target=\"test\", transfer_func=transfer_func)\n",
    "    for x_test, y_test in data_test:\n",
    "        v_fit = net(x_test)[:, :, fit_idcs]\n",
    "        y_true = _to_cupy(y_test)\n",
    "        y_pred = _to_cupy(v_fit)\n",
    "        r2a.accumulate(y_true, y_pred)\n",
    "        rsea.accumulate(y_true, y_pred)"
   ],
   "id": "6d06f1ab0ede5ffb",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Final test-set metrics, reduced from the accumulators filled above.\n",
    "for label, accumulator in ((\"R2\", r2a), (\"RSE\", rsea)):\n",
    "    print(f\"{label} test: {accumulator.reduce()}\")"
   ],
   "id": "86300ffd5cb66f05",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
