{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch version:  1.13.0+cpu\n",
      "pyg version:  2.3.0.dev20230306\n",
      "batch.x =  tensor(indices=tensor([[0, 1, 1, 2, 3, 4],\n",
      "                       [1, 0, 1, 1, 0, 1]]),\n",
      "       values=tensor([1, 2, 3, 4, 5, 6]),\n",
      "       size=(5, 2), nnz=6, layout=torch.sparse_coo)\n",
      "Data(x=[2, 2])\n",
      "Data(x=[2, 2])\n",
      "[Data(x=[2, 2]), Data(x=[3, 2])]\n",
      "[tensor(indices=tensor([[0, 1, 1],\n",
      "                       [1, 0, 1]]),\n",
      "       values=tensor([1, 2, 3]),\n",
      "       size=(2, 2), nnz=3, layout=torch.sparse_coo), tensor(indices=tensor([[0, 1, 2],\n",
      "                       [1, 0, 1]]),\n",
      "       values=tensor([4, 5, 6]),\n",
      "       size=(3, 2), nnz=3, layout=torch.sparse_coo)]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data, Batch\n",
    "\n",
    "data1 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [1, 0, 1]]), torch.tensor([1, 2, 3]), (2, 2)))\n",
    "data2 = Data(x=torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [1, 0, 1]]), torch.tensor([4, 5, 6]), (3, 2)))\n",
    "batch = Batch.from_data_list([data1, data2])\n",
    "\n",
    "print(\"torch version: \", torch.__version__)\n",
    "print(\"pyg version: \", torch_geometric.__version__)\n",
    "print(\"batch.x = \", batch.x) # WORKS\n",
    "print(batch.get_example(0)) # FAILS\n",
    "print(batch[0]) # FAILS\n",
    "print(batch.to_data_list())\n",
    "print([b.x for b in batch.to_data_list()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(indices=tensor([[0, 1, 1, 2],\n",
       "                       [1, 0, 1, 0]]),\n",
       "       values=tensor([1, 2, 3, 5]),\n",
       "       size=(3, 2), nnz=4, layout=torch.sparse_coo)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.x.index_select(dim=0, index=torch.tensor([0, 1, 3]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphium_ipu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
