{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e7acdca6",
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "index 1104 is out of bounds for dimension 0 with size 1104",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_160345/2114188526.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     60\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     61\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m         \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0medge_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     63\u001b[0m         \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     64\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1499\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1500\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1502\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1503\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_160345/2114188526.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, edge_index)\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunctional\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1499\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1500\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1502\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1503\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, edge_index, edge_weight)\u001b[0m\n\u001b[1;32m    208\u001b[0m                 \u001b[0mcache\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cached_edge_index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    209\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mcache\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 210\u001b[0;31m                     edge_index, edge_weight = gcn_norm(  # yapf: disable\n\u001b[0m\u001b[1;32m    211\u001b[0m                         \u001b[0medge_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnode_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    212\u001b[0m                         self.improved, self.add_self_loops, self.flow, x.dtype)\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch_geometric/nn/conv/gcn_conv.py\u001b[0m in \u001b[0;36mgcn_norm\u001b[0;34m(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)\u001b[0m\n\u001b[1;32m     98\u001b[0m     \u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0medge_index\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0medge_index\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     99\u001b[0m     \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mflow\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'source_to_target'\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m     \u001b[0mdeg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0medge_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_nodes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sum'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    101\u001b[0m     \u001b[0mdeg_inv_sqrt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdeg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpow_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m     \u001b[0mdeg_inv_sqrt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmasked_fill_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdeg_inv_sqrt\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch_geometric/utils/scatter.py\u001b[0m in \u001b[0;36mscatter\u001b[0;34m(src, index, dim, dim_size, reduce)\u001b[0m\n\u001b[1;32m     72\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'sum'\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'add'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     73\u001b[0m             \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbroadcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter_add_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msrc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: index 1104 is out of bounds for dimension 0 with size 1104"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import networkx as nx\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import torch_geometric\n",
    "from torch_geometric.utils import to_networkx\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "cyclone_data = pd.read_csv('IO_out.csv')\n",
    "\n",
    "G = nx.Graph()\n",
    "\n",
    "for i in range(len(cyclone_data)):\n",
    "    G.add_node(i, vmax=cyclone_data['VMAX'][i])\n",
    "\n",
    "for i in range(0, len(cyclone_data)-1, 6):\n",
    "    for j in range(i+6, len(cyclone_data), 6):\n",
    "        G.add_edge(i, j, time=j-i)\n",
    "\n",
    "edge_index = edge_index = torch.tensor(list(G.edges())).t()\n",
    "\n",
    "class CycloneGNN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CycloneGNN, self).__init__()\n",
    "        self.conv1 = torch_geometric.nn.GCNConv(1, 16)\n",
    "        self.conv2 = torch_geometric.nn.GCNConv(16, 1)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = nn.functional.dropout(x, training=self.training)\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return x\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Cross-validation\n",
    "kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "for fold, (train_indices, val_indices) in enumerate(kf.split(data.x)):\n",
    "    # Create new Data objects for train and validation sets\n",
    "    train_data = torch_geometric.data.Data(\n",
    "        x=data.x[train_indices],\n",
    "        edge_index=data.edge_index,\n",
    "    )\n",
    "    val_data = torch_geometric.data.Data(\n",
    "        x=data.x[val_indices],\n",
    "        edge_index=data.edge_index,\n",
    "    )\n",
    "    \n",
    "    # Initialize the model and optimizer for each fold\n",
    "    model = CycloneGNN().to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
    "    criterion = nn.MSELoss()\n",
    "    \n",
    "    # Training loop\n",
    "    for epoch in range(200):\n",
    "        model.train()\n",
    "        optimizer.zero_grad()\n",
    "        out = model(train_data.x.to(device), train_data.edge_index.to(device))\n",
    "        loss = criterion(out, train_data.x.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        print(f'Fold: {fold+1}, Epoch: {epoch+1}, Loss: {loss.item():.4f}')\n",
    "    \n",
    "    # Validation\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        val_out = model(val_data.x.to(device), val_data.edge_index.to(device))\n",
    "        val_loss = criterion(val_out, val_data.x.to(device))\n",
    "        print(f'Fold: {fold+1}, Validation Loss: {val_loss.item():.4f}')\n",
    "\n",
    "print(\"Cross-validation completed.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d67f11",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
